global_exchange.cpp 5.22 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
#include "global_exchange.h"
#include "utils/fmoe_utils.h"
#include <torch/extension.h>

#ifdef FMOE_USE_NCCL
#include <nccl.h>

Rick Ho's avatar
Rick Ho committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

/// Exchanges per-expert counts with every rank in smgr->ncclcomm.
///
/// Both buffers are treated as world_size consecutive slices of n_expert
/// ncclInt64 values. Slice `peer` of local_expert_count is sent to rank
/// `peer`, and the slice received from rank `peer` is stored in slice `peer`
/// of global_expert_count. All point-to-point calls are batched inside one
/// NCCL group on smgr->stream(0), after which smgr->sync(1) is called.
void fmoe_cuda_expert_exchange_impl(
        const long* local_expert_count,
        long* global_expert_count,
        int n_expert, int world_size,
        CudaStreamManager* smgr) {
    NCCL_SAFE_CALL(ncclGroupStart());
    for (int peer = 0; peer < world_size; ++peer) {
        // Offset of this peer's slice within both buffers.
        const long* outgoing = local_expert_count + n_expert * peer;
        long* incoming = global_expert_count + n_expert * peer;

        NCCL_SAFE_CALL(ncclSend(outgoing, n_expert, ncclInt64, peer,
                smgr->ncclcomm, smgr->stream(0)));
        NCCL_SAFE_CALL(ncclRecv(incoming, n_expert, ncclInt64, peer,
                smgr->ncclcomm, smgr->stream(0)));
    }
    NCCL_SAFE_CALL(ncclGroupEnd());
    smgr->sync(1);
}

35
/// Exchanges per-expert token counts across all workers.
///
/// @param local_expert_count  n_workers consecutive slices of n_expert int64
///                            counts. Slice i is sent to worker i. It must
///                            hold exactly n_expert * n_workers elements.
/// @param n_expert            number of experts per worker (slice length).
/// @param n_workers           number of participating workers.
/// @return a tensor shaped like local_expert_count whose slice i holds the
///         counts received from worker i.
torch::Tensor _expert_exchange(
        torch::Tensor local_expert_count,
        long n_expert, long n_workers) {
    // The raw data pointer is handed straight to NCCL, so apply the same
    // CHECK_INPUT guard that the scatter/gather entry points use on the
    // buffers they pass to device code.
    CHECK_INPUT(local_expert_count);
    // The impl indexes n_workers slices of n_expert entries. A tensor of any
    // other size would make NCCL read/write past the end of the buffers.
    TORCH_CHECK(local_expert_count.numel() == n_expert * n_workers,
            "expert_exchange: local_expert_count has ",
            local_expert_count.numel(),
            " elements, expected n_expert * n_workers = ",
            n_expert * n_workers);

    auto global_expert_count = torch::empty_like(local_expert_count);
    auto smgr = getCudaStreamManager(local_expert_count.device().index());

    fmoe_cuda_expert_exchange_impl(
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            n_expert, n_workers,
            smgr);
    return global_expert_count;
}

49
/// Scatters rows of input_buf to the workers that own the target experts.
///
/// @param input_buf            2-D buffer whose rows are sent out. It is
///                             passed through CHECK_INPUT.
/// @param local_expert_count   counts laid out as n_workers slices of
///                             n_expert entries (read as int64).
/// @param global_expert_count  counts with the same layout (read as int64).
/// @param batch_size           number of rows in the returned buffer.
/// @param n_workers            number of participating workers.
/// @return a new [batch_size, input_buf.size(1)] tensor with input_buf's
///         dtype and device, filled by fmoe_cuda_global_scatter_impl.
torch::Tensor _global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    CHECK_INPUT(input_buf);

    // Experts per worker: the count vector holds one slice per worker.
    const auto n_expert = local_expert_count.size(0) / n_workers;
    const auto in_feat = input_buf.size(1);

    auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
    auto smgr = getCudaStreamManager(input_buf.device().index());

    // Instantiate the impl for float, double, Half and BFloat16.
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            input_buf.scalar_type(), "fmoe_cuda_global_scatter", ([&] {
        auto* src = input_buf.data_ptr<scalar_t>();
        auto* local_counts = local_expert_count.data_ptr<long>();
        auto* global_counts = global_expert_count.data_ptr<long>();
        auto* dst = global_input_buf.data_ptr<scalar_t>();
        fmoe_cuda_global_scatter_impl<scalar_t>(
            src, local_counts, global_counts, dst,
            in_feat, n_expert, n_workers,
            smgr);
    }));
    return global_input_buf;
}

75
/// Gathers rows of output_buf back from the workers that processed them.
///
/// @param output_buf           2-D buffer whose rows are sent back. It is
///                             passed through CHECK_INPUT.
/// @param local_expert_count   counts laid out as n_workers slices of
///                             n_expert entries (read as int64).
/// @param global_expert_count  counts with the same layout (read as int64).
/// @param batch_size           number of rows in the returned buffer.
/// @param n_workers            number of participating workers.
/// @return a new [batch_size, output_buf.size(1)] tensor with output_buf's
///         dtype and device, filled by fmoe_cuda_global_gather_impl.
torch::Tensor _global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    CHECK_INPUT(output_buf);

    // Experts per worker: the count vector holds one slice per worker.
    const auto n_expert = local_expert_count.size(0) / n_workers;
    const auto out_feat = output_buf.size(1);

    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
    auto smgr = getCudaStreamManager(output_buf.device().index());

    // Instantiate the impl for float, double, Half and BFloat16.
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            output_buf.scalar_type(), "fmoe_cuda_global_gather", ([&] {
        auto* src = output_buf.data_ptr<scalar_t>();
        auto* local_counts = local_expert_count.data_ptr<long>();
        auto* global_counts = global_expert_count.data_ptr<long>();
        auto* dst = local_output_buf.data_ptr<scalar_t>();
        fmoe_cuda_global_gather_impl<scalar_t>(
            src, local_counts, global_counts, dst,
            out_feat, n_expert, n_workers,
            smgr);
    }));
    return local_output_buf;
}

Rick Ho's avatar
Rick Ho committed
101
102
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
        (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13))
103
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
Rick Ho's avatar
Rick Ho committed
104
105
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#else
Rick Ho's avatar
Rick Ho committed
106
#include <c10d/ProcessGroupNCCL.hpp>
Rick Ho's avatar
Rick Ho committed
107
#endif
Rick Ho's avatar
Rick Ho committed
108
109
110
111
112
113
114
115
116

// Helper subclass used only so getcomm() can call ProcessGroupNCCL's
// broadcastUniqueNCCLID() from outside the class (presumably because that
// helper is not public). _ensure_nccl reinterprets an existing NCCL process
// group object as this type, so this class must never add data members or
// virtual functions.
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
    /// Creates a new NCCL communicator spanning every rank of this group.
    ///
    /// Rank 0 generates a fresh ncclUniqueId. The id is distributed with the
    /// group's own broadcastUniqueNCCLID (the signature differs between
    /// torch versions), and every rank then joins via ncclCommInitRank.
    ///
    /// @param dev  currently unused.
    /// @return the new communicator. Nothing in this class destroys it.
    ncclComm_t getcomm(at::Device dev) {
        ncclUniqueId ncclID;
        int rank = getRank();
        if (rank == 0) {
            // Checked like every other NCCL call in this file. Ignoring a
            // failure here would broadcast an uninitialized id to all ranks.
            NCCL_SAFE_CALL(ncclGetUniqueId(&ncclID));
        }
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
        (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 12))
        broadcastUniqueNCCLID(&ncclID,
                false,
                "fastmoe_nccl_comm",
                rank);
#elif defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
        (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 8))
        broadcastUniqueNCCLID(&ncclID,
                c10d::OpType::SEND,
                "fastmoe_nccl_comm",
                rank);
#else
        broadcastUniqueNCCLID(&ncclID);
#endif
        // Start from nullptr so the caller's "!= 0" success test never reads
        // an indeterminate value.
        ncclComm_t comm = nullptr;
        NCCL_SAFE_CALL(ncclCommInitRank(&comm, getSize(), ncclID, rank));
        return comm;
    }
};

138
139
140
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t) {
#else
Rick Ho's avatar
Rick Ho committed
141
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
142
#endif  // TORCH_VERSION
Rick Ho's avatar
Rick Ho committed
143
144
145
146
    auto smgr = getCudaStreamManager(t.device().index());
    if (smgr->ncclgood) {
        return;
    }
147
148
149
150
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
    HackNCCLGroup* h = (HackNCCLGroup*)(void*)
        (p.getBackend(c10d::ProcessGroup::NCCL).get());
#else
Rick Ho's avatar
Rick Ho committed
151
    HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
152
#endif  // TORCH_VERSION
Rick Ho's avatar
Rick Ho committed
153
154
155
156
157
158
159
160
161
    smgr->ncclcomm = h->getcomm(t.device());
    if (smgr->ncclcomm != 0) {
        smgr->ncclgood = 1;
    } else {
        std::cerr << "Nccl initialization failed\n";
    }
}

#endif  // FMOE_USE_NCCL