// global_exchange.h (author: Rick Ho)
#include <vector>

#include "stream_manager.h"
#ifdef FMOE_USE_NCCL

void fmoe_cuda_expert_exchange_impl(
Rick Ho's avatar
Rick Ho committed
5
6
        const long* local_expert_count,
        long* global_expert_count,
Rick Ho's avatar
Rick Ho committed
7
        int n_expert, int world_size,
Rick Ho's avatar
Rick Ho committed
8
9
        CudaStreamManager* smgr);

Rick Ho's avatar
Rick Ho committed
10
11
12
13
14
15
16

template<typename scalar_t>
void fmoe_cuda_global_scatter_impl(
    const scalar_t* local_input_buf,
    const long* local_expert_count,
    const long* global_expert_count,
    scalar_t* input_buf,
Rick Ho's avatar
Rick Ho committed
17
    size_t in_feat, size_t n_expert, size_t world_size,
Rick Ho's avatar
Rick Ho committed
18
19
20
21
    CudaStreamManager* smgr) {
    // assert world_size > 1
    int recv_ptr = 0;
    /* TODO: may save for backward */
Rick Ho's avatar
Rick Ho committed
22
    long*expert_ptr = new long[n_expert * world_size];
Rick Ho's avatar
Rick Ho committed
23
    expert_ptr[0] = 0;
Rick Ho's avatar
Rick Ho committed
24
    for (size_t i = 1; i < n_expert * world_size; ++i) {
Rick Ho's avatar
Rick Ho committed
25
26
27
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }

Rick Ho's avatar
Rick Ho committed
28
    for (size_t i = 0; i < n_expert; ++i) {
Rick Ho's avatar
Rick Ho committed
29
30
        NCCL_SAFE_CALL(ncclGroupStart());
        for (size_t j = 0; j < world_size; ++j) {
Rick Ho's avatar
Rick Ho committed
31
            int idx = i + j * n_expert;
Rick Ho's avatar
Rick Ho committed
32
33
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
Rick Ho's avatar
Rick Ho committed
34
                        local_input_buf + expert_ptr[idx] * in_feat,
Rick Ho's avatar
Rick Ho committed
35
                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
Rick Ho's avatar
Rick Ho committed
36
                        ncclChar,
Rick Ho's avatar
Rick Ho committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
            }
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        input_buf + recv_ptr * in_feat,
                        global_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
                recv_ptr += global_expert_count[idx];
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(1);
}

template<typename scalar_t>
void fmoe_cuda_global_gather_impl(
    const scalar_t* output_buf,
    const long* local_expert_count,
    const long* global_expert_count,
    scalar_t* local_output_buf,
Rick Ho's avatar
Rick Ho committed
64
    size_t out_feat, size_t n_expert, size_t world_size,
Rick Ho's avatar
Rick Ho committed
65
66
67
    CudaStreamManager* smgr) {
    long send_ptr = 0;
    /* TODO: may save for backward */
Rick Ho's avatar
Rick Ho committed
68
    long *expert_ptr = new long[n_expert * world_size];
Rick Ho's avatar
Rick Ho committed
69
    expert_ptr[0] = 0;
Rick Ho's avatar
Rick Ho committed
70
    for (size_t i = 1; i < n_expert * world_size; ++i) {
Rick Ho's avatar
Rick Ho committed
71
72
73
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }

Rick Ho's avatar
Rick Ho committed
74
    for (size_t i = 0; i < n_expert; ++i) {
Rick Ho's avatar
Rick Ho committed
75
76
        NCCL_SAFE_CALL(ncclGroupStart());
        for (size_t j = 0; j < world_size; ++j) {
Rick Ho's avatar
Rick Ho committed
77
            int idx = i + j * n_expert;
Rick Ho's avatar
Rick Ho committed
78
79
80
81
82
83
84
85
86
87
88
89
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        output_buf + send_ptr * out_feat,
                        global_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
                send_ptr += global_expert_count[idx];
            }
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
Rick Ho's avatar
Rick Ho committed
90
                        local_output_buf + expert_ptr[idx] * out_feat,
Rick Ho's avatar
Rick Ho committed
91
                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
Rick Ho's avatar
Rick Ho committed
92
                        ncclChar,
Rick Ho's avatar
Rick Ho committed
93
94
95
96
97
98
99
100
101
102
103
104
105
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(1);
}


#endif  // FMOE_USE_NCCL