// global_exchange.cpp -- NCCL-based cross-worker exchange routines for FastMoE.
#include "global_exchange.h"
#include "utils/fmoe_utils.h"
#include <torch/extension.h>
#include <iostream>

#ifdef FMOE_USE_NCCL
#include <nccl.h>

/*
 * Exchange per-expert token counts with the other workers (presumably an
 * all-to-all over counts -- the communication itself lives in
 * fmoe_cuda_expert_exchange_impl).
 *
 * local_expert_count: long tensor of counts produced by this worker.
 * n_expert, n_workers: number of local experts / world size.
 * Returns a one-element vector holding the received counts, with the same
 * shape, dtype and device as `local_expert_count`.
 */
std::vector<torch::Tensor> _expert_exchange(
        torch::Tensor local_expert_count,
        long n_expert, long n_workers) {
    // Stream manager for the device that holds the counts; it carries the
    // NCCL/CUDA state the impl needs.
    auto stream_mgr = getCudaStreamManager(local_expert_count.device().index());
    // Receive buffer mirroring the send buffer's shape/dtype/device.
    auto global_expert_count = torch::empty_like(local_expert_count);

    fmoe_cuda_expert_exchange_impl(
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            n_expert, n_workers,
            stream_mgr);

    return {global_expert_count};
}

/*
 * Send rows of `input_buf` to their destination workers and collect the rows
 * destined for this worker (communication done by
 * fmoe_cuda_global_scatter_impl).
 *
 * input_buf:           2-D [rows, in_feat] tensor of rows to scatter
 *                      (must be a contiguous CUDA tensor -- CHECK_INPUT).
 * local_expert_count:  flat long tensor of length n_expert * n_workers
 *                      (n_expert is recovered from it below).
 * global_expert_count: matching long tensor of counts to receive.
 * batch_size:          number of rows this worker receives.
 * n_workers:           world size.
 * Returns a one-element vector holding the received buffer.
 */
std::vector<torch::Tensor> _global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    CHECK_INPUT(input_buf);

    // Number of local experts is implied by the flattened count tensor.
    const auto n_expert = local_expert_count.size(0) / n_workers;
    const auto feat_dim = input_buf.size(1);

    auto stream_mgr = getCudaStreamManager(input_buf.device().index());
    // Receive buffer: same dtype/device as the rows being sent.
    auto global_input_buf = input_buf.new_empty({batch_size, feat_dim});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "fmoe_cuda_global_scatter", ([&] {
        fmoe_cuda_global_scatter_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            global_input_buf.data_ptr<scalar_t>(),
            feat_dim, n_expert, n_workers,
            stream_mgr
        );
    }));

    return {global_input_buf};
}

/*
 * Inverse of _global_scatter: return rows in `output_buf` to the workers
 * they originally came from (communication done by
 * fmoe_cuda_global_gather_impl).
 *
 * output_buf:          2-D [rows, out_feat] tensor of rows to send back
 *                      (must be a contiguous CUDA tensor -- CHECK_INPUT).
 * local_expert_count:  flat long tensor of length n_expert * n_workers
 *                      (n_expert is recovered from it below).
 * global_expert_count: matching long tensor of counts.
 * batch_size:          number of rows this worker receives back.
 * n_workers:           world size.
 * Returns a one-element vector holding the locally reassembled buffer.
 */
std::vector<torch::Tensor> _global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    CHECK_INPUT(output_buf);

    // Number of local experts is implied by the flattened count tensor.
    const auto n_expert = local_expert_count.size(0) / n_workers;
    const auto feat_dim = output_buf.size(1);

    auto stream_mgr = getCudaStreamManager(output_buf.device().index());
    // Receive buffer: same dtype/device as the rows being sent.
    auto local_output_buf = output_buf.new_empty({batch_size, feat_dim});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
            "fmoe_cuda_global_gather", ([&] {
        fmoe_cuda_global_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            local_output_buf.data_ptr<scalar_t>(),
            feat_dim, n_expert, n_workers,
            stream_mgr
        );
    }));

    return {local_output_buf};
}

#include <c10d/ProcessGroupNCCL.hpp>

// Subclass that exists only to reach non-public members of
// c10d::ProcessGroupNCCL (getRank / getSize / broadcastUniqueNCCLID --
// presumably protected there; TODO confirm against the installed torch
// version, this API has shifted across releases). The class adds no data
// members of its own, which is what _ensure_nccl relies on when it casts a
// plain ProcessGroupNCCL reference to this type.
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
    // Build a fresh NCCL communicator covering the whole process group.
    // `dev` is currently unused.
    // NOTE(review): the ncclResult_t returned by ncclGetUniqueId and
    // ncclCommInitRank is not checked; on failure `comm` may be left
    // uninitialized -- the caller (_ensure_nccl) only compares the handle
    // against 0.
    ncclComm_t getcomm(at::Device dev) {
        ncclUniqueId ncclID;
        int rank = getRank();
        // Only rank 0 generates the unique id; the broadcast below
        // distributes it to every other rank.
        if (rank == 0) {
            ncclGetUniqueId(&ncclID);
        }
        broadcastUniqueNCCLID(&ncclID,
                c10d::OpType::SEND,
                "fastmoe_nccl_comm",
                rank);
        ncclComm_t comm;
        // Collective call: every rank must reach this with the same id.
        ncclCommInitRank(&comm, getSize(), ncclID, rank);
        return comm;
    }
};

/*
 * Lazily create the NCCL communicator used by this device's stream manager.
 *
 * p: the torch distributed NCCL process group to piggyback on.
 * t: any tensor on the target device (only its device index is used).
 *
 * Idempotent: returns immediately once `ncclgood` has been set. On failure
 * it logs to stderr and leaves `ncclgood` unset so a later call can retry.
 */
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
    auto smgr = getCudaStreamManager(t.device().index());
    if (smgr->ncclgood) {
        return;  // communicator already established for this device
    }
    // Reinterpret the process group as HackNCCLGroup to reach its non-public
    // members. HackNCCLGroup derives from ProcessGroupNCCL and adds no data
    // members, so the object layout is unchanged. A named cast replaces the
    // original C-style `(HackNCCLGroup*)(void*)` detour, which silently hid
    // this downcast.
    auto* h = reinterpret_cast<HackNCCLGroup*>(&p);
    smgr->ncclcomm = h->getcomm(t.device());
    if (smgr->ncclcomm != 0) {
        smgr->ncclgood = 1;
    } else {
        std::cerr << "Nccl initialization failed\n";
    }
}

#endif  // FMOE_USE_NCCL