global_exchange.cpp 3.35 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
#include "global_exchange.h"
#include "utils/fmoe_utils.h"
#include <torch/extension.h>

#ifdef FMOE_USE_NCCL
#include <nccl.h>

8
/*
 * Exchange per-expert token counts with the other workers.
 * Allocates a result tensor of the same shape and dtype as
 * `local_expert_count` and delegates the communication to
 * fmoe_cuda_expert_exchange_impl on the tensor's device stream.
 */
torch::Tensor _expert_exchange(
        torch::Tensor local_expert_count,
        long n_expert, long n_workers) {
    auto recv_expert_count = torch::empty_like(local_expert_count);
    auto stream_mgr = getCudaStreamManager(local_expert_count.device().index());

    fmoe_cuda_expert_exchange_impl(
            local_expert_count.data_ptr<long>(),
            recv_expert_count.data_ptr<long>(),
            n_expert, n_workers,
            stream_mgr);
    return recv_expert_count;
}

22
/*
 * Scatter rows of `input_buf` across workers according to the
 * local/global expert-count tables. Allocates the destination buffer
 * ({batch_size, feature_dim}, same dtype/device as `input_buf`) and
 * delegates the data movement to fmoe_cuda_global_scatter_impl.
 */
torch::Tensor _global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    CHECK_INPUT(input_buf);

    // Count table holds one entry per (worker, expert) pair.
    const auto n_expert = local_expert_count.size(0) / n_workers;
    const auto feature_dim = input_buf.size(1);
    auto scattered_buf = input_buf.new_empty({batch_size, feature_dim});
    auto stream_mgr = getCudaStreamManager(input_buf.device().index());

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "fmoe_cuda_global_scatter", ([&] {
        fmoe_cuda_global_scatter_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            scattered_buf.data_ptr<scalar_t>(),
            feature_dim, n_expert, n_workers,
            stream_mgr
        );
    }));
    return scattered_buf;
}

48
/*
 * Inverse of _global_scatter: gather rows of `output_buf` back from the
 * workers according to the local/global expert-count tables. Allocates
 * the local result buffer ({batch_size, feature_dim}, same dtype/device
 * as `output_buf`) and delegates to fmoe_cuda_global_gather_impl.
 */
torch::Tensor _global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    CHECK_INPUT(output_buf);

    // Count table holds one entry per (worker, expert) pair.
    const auto n_expert = local_expert_count.size(0) / n_workers;
    const auto feature_dim = output_buf.size(1);
    auto gathered_buf = output_buf.new_empty({batch_size, feature_dim});
    auto stream_mgr = getCudaStreamManager(output_buf.device().index());

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
            "fmoe_cuda_global_gather", ([&] {
        fmoe_cuda_global_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            gathered_buf.data_ptr<scalar_t>(),
            feature_dim, n_expert, n_workers,
            stream_mgr
        );
    }));
    return gathered_buf;
}

#include <c10d/ProcessGroupNCCL.hpp>

// Subclass of c10d::ProcessGroupNCCL whose only purpose is to expose the
// base class's protected members (getRank/getSize/broadcastUniqueNCCLID).
// It adds no data members, so _ensure_nccl casts an existing
// ProcessGroupNCCL to this type merely to widen the accessible interface.
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
    // Build a dedicated NCCL communicator covering every rank of this
    // process group. Rank 0 generates the unique id; all ranks then call
    // broadcastUniqueNCCLID so everyone holds the same id before
    // ncclCommInitRank. NOTE(review): `dev` is currently unused; the
    // OpType/"fastmoe_nccl_comm" arguments presumably just key the id
    // broadcast inside torch — verify against the torch version in use.
    ncclComm_t getcomm(at::Device dev) {
        ncclUniqueId ncclID;
        int rank = getRank();
        if (rank == 0) {
            ncclGetUniqueId(&ncclID);
        }
        broadcastUniqueNCCLID(&ncclID,
                c10d::OpType::SEND,
                "fastmoe_nccl_comm",
                rank);
        ncclComm_t comm;
        // NCCL_SAFE_CALL aborts/reports on NCCL errors (see fmoe_utils).
        NCCL_SAFE_CALL(ncclCommInitRank(&comm, getSize(), ncclID, rank));
        return comm;
    }
};

/*
 * Lazily create the NCCL communicator for the stream manager of the
 * device `t` lives on, reusing the rank/size information of the torch
 * process group `p`. Idempotent: once a communicator has been created
 * successfully (ncclgood set), subsequent calls return immediately.
 */
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
    auto smgr = getCudaStreamManager(t.device().index());
    if (smgr->ncclgood) {
        return;
    }
    // Downcast to the access-widening subclass to reach getcomm().
    // HackNCCLGroup adds no state, so a checked static_cast is the right
    // tool here; it replaces the old C-style cast through void*, which
    // hid the pointer conversion from the compiler entirely.
    auto* h = static_cast<HackNCCLGroup*>(&p);
    smgr->ncclcomm = h->getcomm(t.device());
    if (smgr->ncclcomm != nullptr) {
        smgr->ncclgood = 1;
    } else {
        std::cerr << "Nccl initialization failed\n";
    }
}

#endif  // FMOE_USE_NCCL