// global_exchange.h
#include <vector>

#include "stream_manager.h"
#ifdef FMOE_USE_NCCL

void fmoe_cuda_expert_exchange_impl(
        const long* local_expert_count, 
        long* global_expert_count, 
Rick Ho's avatar
Rick Ho committed
7
        int n_expert, int world_size,
Rick Ho's avatar
Rick Ho committed
8
9
10
11
        CudaStreamManager* smgr) {
    NCCL_SAFE_CALL(ncclGroupStart());
    for (int i = 0; i < world_size; ++i) {
        NCCL_SAFE_CALL(ncclSend(
Rick Ho's avatar
Rick Ho committed
12
13
                local_expert_count + n_expert * i,
                n_expert,
Rick Ho's avatar
Rick Ho committed
14
15
16
17
18
                ncclInt64,
                i,
                smgr->ncclcomm,
                smgr->stream(0)));
        NCCL_SAFE_CALL(ncclRecv(
Rick Ho's avatar
Rick Ho committed
19
20
                global_expert_count + n_expert * i,
                n_expert,
Rick Ho's avatar
Rick Ho committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
                ncclInt64,
                i,
                smgr->ncclcomm,
                smgr->stream(0)));
    }
    NCCL_SAFE_CALL(ncclGroupEnd());
    smgr->sync(1);
}

template<typename scalar_t>
void fmoe_cuda_global_scatter_impl(
    const scalar_t* local_input_buf,
    const long* local_expert_count,
    const long* global_expert_count,
    scalar_t* input_buf,
Rick Ho's avatar
Rick Ho committed
36
    size_t in_feat, size_t n_expert, size_t world_size,
Rick Ho's avatar
Rick Ho committed
37
38
39
40
    CudaStreamManager* smgr) {
    // assert world_size > 1
    int recv_ptr = 0;
    /* TODO: may save for backward */
Rick Ho's avatar
Rick Ho committed
41
    long*expert_ptr = new long[n_expert * world_size];
Rick Ho's avatar
Rick Ho committed
42
    expert_ptr[0] = 0;
Rick Ho's avatar
Rick Ho committed
43
    for (size_t i = 1; i < n_expert * world_size; ++i) {
Rick Ho's avatar
Rick Ho committed
44
45
46
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }

Rick Ho's avatar
Rick Ho committed
47
    for (size_t i = 0; i < n_expert; ++i) {
Rick Ho's avatar
Rick Ho committed
48
49
        NCCL_SAFE_CALL(ncclGroupStart());
        for (size_t j = 0; j < world_size; ++j) {
Rick Ho's avatar
Rick Ho committed
50
            int idx = i + j * n_expert;
Rick Ho's avatar
Rick Ho committed
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        local_input_buf + expert_ptr[idx] * in_feat, 
                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar, 
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
            }
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        input_buf + recv_ptr * in_feat,
                        global_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
                recv_ptr += global_expert_count[idx];
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(1);
}

template<typename scalar_t>
void fmoe_cuda_global_gather_impl(
    const scalar_t* output_buf,
    const long* local_expert_count,
    const long* global_expert_count,
    scalar_t* local_output_buf,
Rick Ho's avatar
Rick Ho committed
83
    size_t out_feat, size_t n_expert, size_t world_size,
Rick Ho's avatar
Rick Ho committed
84
85
86
    CudaStreamManager* smgr) {
    long send_ptr = 0;
    /* TODO: may save for backward */
Rick Ho's avatar
Rick Ho committed
87
    long *expert_ptr = new long[n_expert * world_size];
Rick Ho's avatar
Rick Ho committed
88
    expert_ptr[0] = 0;
Rick Ho's avatar
Rick Ho committed
89
    for (size_t i = 1; i < n_expert * world_size; ++i) {
Rick Ho's avatar
Rick Ho committed
90
91
92
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }

Rick Ho's avatar
Rick Ho committed
93
    for (size_t i = 0; i < n_expert; ++i) {
Rick Ho's avatar
Rick Ho committed
94
95
        NCCL_SAFE_CALL(ncclGroupStart());
        for (size_t j = 0; j < world_size; ++j) {
Rick Ho's avatar
Rick Ho committed
96
            int idx = i + j * n_expert;
Rick Ho's avatar
Rick Ho committed
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        output_buf + send_ptr * out_feat,
                        global_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
                send_ptr += global_expert_count[idx];
            }
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        local_output_buf + expert_ptr[idx] * out_feat, 
                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar, 
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(1);
}


#endif  // FMOE_USE_NCCL