global_exchange.cpp 5.21 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
#include "global_exchange.h"
#include "utils/fmoe_utils.h"
#include <torch/extension.h>

#ifdef FMOE_USE_NCCL
#include <nccl.h>

Rick Ho's avatar
Rick Ho committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21

void fmoe_cuda_expert_exchange_impl(
        const long* local_expert_count,
        long* global_expert_count,
        int n_expert, int world_size,
        CudaStreamManager* smgr) {
    NCCL_SAFE_CALL(ncclGroupStart());
    for (int i = 0; i < world_size; ++i) {
        NCCL_SAFE_CALL(ncclSend(
                local_expert_count + n_expert * i,
                n_expert,
                ncclInt64,
                i,
                smgr->ncclcomm,
Rick Ho's avatar
Rick Ho committed
22
                smgr->torchStream()));
Rick Ho's avatar
Rick Ho committed
23
24
25
26
27
28
        NCCL_SAFE_CALL(ncclRecv(
                global_expert_count + n_expert * i,
                n_expert,
                ncclInt64,
                i,
                smgr->ncclcomm,
Rick Ho's avatar
Rick Ho committed
29
                smgr->torchStream()));
Rick Ho's avatar
Rick Ho committed
30
31
32
33
    }
    NCCL_SAFE_CALL(ncclGroupEnd());
}

34
// Python-facing entry point for the all-to-all exchange of expert counts.
//
// local_expert_count: int64 tensor with n_expert * n_workers elements. Chunk
//     p (of n_expert counts) is sent to rank p.
// n_expert:  number of experts per worker.
// n_workers: number of ranks taking part in the exchange.
//
// Returns a new tensor shaped like local_expert_count. Chunk p holds the
// n_expert counts received from rank p.
torch::Tensor _expert_exchange(
        torch::Tensor local_expert_count,
        long n_expert, long n_workers) {
    // NCCL reads this buffer directly through data_ptr(). Validate it the
    // same way the scatter/gather entry points validate their input buffers.
    CHECK_INPUT(local_expert_count);
    // Every peer is sent and received exactly n_expert counts. A smaller
    // tensor would let NCCL read/write past the end of the device buffers.
    TORCH_CHECK(local_expert_count.numel() == n_expert * n_workers,
            "_expert_exchange: local_expert_count must have n_expert * n_workers = ",
            n_expert * n_workers, " elements, got ",
            local_expert_count.numel());

    auto global_expert_count = torch::empty_like(local_expert_count);
    auto smgr = getCudaStreamManager(local_expert_count.device().index());

    fmoe_cuda_expert_exchange_impl(
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            n_expert, n_workers,
            smgr);
    return global_expert_count;
}

48
// Python-facing wrapper around fmoe_cuda_global_scatter_impl (defined
// elsewhere).
//
// Allocates a (batch_size, input_buf.size(1)) tensor with input_buf's dtype
// and device. It then dispatches on input_buf's dtype (float, double, half,
// bfloat16) and lets the impl fill that tensor.
//
// Only input_buf is validated here (CHECK_INPUT). The two count tensors are
// passed as raw int64 pointers without checks; their expected device is
// decided by the impl (NOTE(review): presumably host memory — confirm).
// The per-worker expert count is derived as
// local_expert_count.size(0) / n_workers.
torch::Tensor _global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    CHECK_INPUT(input_buf);

    const auto experts_per_worker = local_expert_count.size(0) / n_workers;
    const auto feature_dim = input_buf.size(1);
    auto scattered = input_buf.new_empty({batch_size, feature_dim});
    auto stream_manager = getCudaStreamManager(input_buf.device().index());

    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            input_buf.scalar_type(), "fmoe_cuda_global_scatter", ([&] {
        auto* source_rows = input_buf.data_ptr<scalar_t>();
        auto* target_rows = scattered.data_ptr<scalar_t>();
        fmoe_cuda_global_scatter_impl<scalar_t>(
            source_rows,
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            target_rows,
            feature_dim, experts_per_worker, n_workers,
            stream_manager
        );
    }));

    return scattered;
}

74
// Python-facing wrapper around fmoe_cuda_global_gather_impl (defined
// elsewhere). It mirrors _global_scatter.
//
// Allocates a (batch_size, output_buf.size(1)) tensor with output_buf's
// dtype and device. It then dispatches on output_buf's dtype (float, double,
// half, bfloat16) and lets the impl fill that tensor.
//
// Only output_buf is validated here (CHECK_INPUT). The two count tensors are
// passed as raw int64 pointers without checks (NOTE(review): presumably host
// memory, as required by the impl — confirm). The per-worker expert count is
// derived as local_expert_count.size(0) / n_workers.
torch::Tensor _global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    CHECK_INPUT(output_buf);

    const auto experts_per_worker = local_expert_count.size(0) / n_workers;
    const auto feature_dim = output_buf.size(1);
    auto gathered = output_buf.new_empty({batch_size, feature_dim});
    auto stream_manager = getCudaStreamManager(output_buf.device().index());

    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            output_buf.scalar_type(), "fmoe_cuda_global_gather", ([&] {
        auto* source_rows = output_buf.data_ptr<scalar_t>();
        auto* target_rows = gathered.data_ptr<scalar_t>();
        fmoe_cuda_global_gather_impl<scalar_t>(
            source_rows,
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            target_rows,
            feature_dim, experts_per_worker, n_workers,
            stream_manager
        );
    }));

    return gathered;
}

Rick Ho's avatar
Rick Ho committed
100
101
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
        (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13))
102
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
Rick Ho's avatar
Rick Ho committed
103
104
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#else
Rick Ho's avatar
Rick Ho committed
105
#include <c10d/ProcessGroupNCCL.hpp>
Rick Ho's avatar
Rick Ho committed
106
#endif
Rick Ho's avatar
Rick Ho committed
107
108
109
110
111
112
113
114
115

// Thin subclass of c10d::ProcessGroupNCCL. It adds no data members, only a
// method. It exists so code outside the process group can use the group's
// rank/size and its broadcastUniqueNCCLID helper to build a separate NCCL
// communicator for fastmoe. NOTE(review): this presumably works because
// those members are not public on ProcessGroupNCCL — confirm against the
// torch version in use.
//
// Callers (_ensure_nccl) reinterpret an existing ProcessGroupNCCL object as
// this type rather than constructing one. That is why the class must stay
// free of state and virtual overrides.
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
    // Creates a new NCCL communicator spanning all getSize() ranks of this
    // process group and returns it. Every rank of the group must call this.
    //
    // Rank 0 generates a fresh ncclUniqueId. broadcastUniqueNCCLID then shares
    // it with the other ranks, using "fastmoe_nccl_comm" as the key on torch
    // >= 1.8. Each rank then joins with ncclCommInitRank. The signature of
    // broadcastUniqueNCCLID differs across torch versions, hence the
    // preprocessor branches.
    //
    // `dev` is currently unused.
    // Any NCCL failure aborts through NCCL_SAFE_CALL. The returned handle is
    // therefore a successfully initialized communicator.
    ncclComm_t getcomm(at::Device dev) {
        ncclUniqueId ncclID;
        int rank = getRank();
        if (rank == 0) {
            // Previously unchecked: a failure here would broadcast an
            // uninitialized ID to every rank.
            NCCL_SAFE_CALL(ncclGetUniqueId(&ncclID));
        }
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
        (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 12))
        broadcastUniqueNCCLID(&ncclID,
                false,
                "fastmoe_nccl_comm",
                rank);
#elif defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
        (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 8))
        broadcastUniqueNCCLID(&ncclID,
                c10d::OpType::SEND,
                "fastmoe_nccl_comm",
                rank);
#else
        broadcastUniqueNCCLID(&ncclID);
#endif
        ncclComm_t comm = nullptr;
        NCCL_SAFE_CALL(ncclCommInitRank(&comm, getSize(), ncclID, rank));
        return comm;
    }
};

137
138
139
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t) {
#else
Rick Ho's avatar
Rick Ho committed
140
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
141
#endif  // TORCH_VERSION
Rick Ho's avatar
Rick Ho committed
142
143
144
145
    auto smgr = getCudaStreamManager(t.device().index());
    if (smgr->ncclgood) {
        return;
    }
146
147
148
149
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
    HackNCCLGroup* h = (HackNCCLGroup*)(void*)
        (p.getBackend(c10d::ProcessGroup::NCCL).get());
#else
Rick Ho's avatar
Rick Ho committed
150
    HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
151
#endif  // TORCH_VERSION
Rick Ho's avatar
Rick Ho committed
152
153
154
155
156
157
158
159
160
    smgr->ncclcomm = h->getcomm(t.device());
    if (smgr->ncclcomm != 0) {
        smgr->ncclgood = 1;
    } else {
        std::cerr << "Nccl initialization failed\n";
    }
}

#endif  // FMOE_USE_NCCL