// global_exchange.h
#include <vector>

#include "stream_manager.h"
#ifdef FMOE_USE_NCCL

void fmoe_cuda_expert_exchange_impl(
Rick Ho's avatar
Rick Ho committed
5
6
        const long* local_expert_count,
        long* global_expert_count,
Rick Ho's avatar
Rick Ho committed
7
        int n_expert, int world_size,
Rick Ho's avatar
Rick Ho committed
8
9
        CudaStreamManager* smgr);

// Templated NCCL send/recv routines.

template<typename scalar_t>
void fmoe_cuda_global_scatter_impl(
    const scalar_t* local_input_buf,
    const long* local_expert_count,
    const long* global_expert_count,
    scalar_t* input_buf,
Rick Ho's avatar
Rick Ho committed
17
    size_t in_feat, size_t n_expert, size_t world_size,
Rick Ho's avatar
Rick Ho committed
18
19
20
21
    CudaStreamManager* smgr) {
    // assert world_size > 1
    int recv_ptr = 0;
    /* TODO: may save for backward */
Rick Ho's avatar
Rick Ho committed
22
    long*expert_ptr = new long[n_expert * world_size];
Rick Ho's avatar
Rick Ho committed
23
    expert_ptr[0] = 0;
Rick Ho's avatar
Rick Ho committed
24
    for (size_t i = 1; i < n_expert * world_size; ++i) {
Rick Ho's avatar
Rick Ho committed
25
26
27
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }

Rick Ho's avatar
Rick Ho committed
28
    for (size_t i = 0; i < n_expert; ++i) {
Rick Ho's avatar
Rick Ho committed
29
30
        NCCL_SAFE_CALL(ncclGroupStart());
        for (size_t j = 0; j < world_size; ++j) {
Rick Ho's avatar
Rick Ho committed
31
            int idx = i + j * n_expert;
Rick Ho's avatar
Rick Ho committed
32
33
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
Rick Ho's avatar
Rick Ho committed
34
                        local_input_buf + expert_ptr[idx] * in_feat,
Rick Ho's avatar
Rick Ho committed
35
                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
Rick Ho's avatar
Rick Ho committed
36
                        ncclChar,
Rick Ho's avatar
Rick Ho committed
37
38
                        j,
                        smgr->ncclcomm,
Rick Ho's avatar
Rick Ho committed
39
                        smgr->torchStream()));
Rick Ho's avatar
Rick Ho committed
40
41
42
43
44
45
46
47
            }
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        input_buf + recv_ptr * in_feat,
                        global_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
Rick Ho's avatar
Rick Ho committed
48
                        smgr->torchStream()));
Rick Ho's avatar
Rick Ho committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
                recv_ptr += global_expert_count[idx];
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
}

template<typename scalar_t>
void fmoe_cuda_global_gather_impl(
    const scalar_t* output_buf,
    const long* local_expert_count,
    const long* global_expert_count,
    scalar_t* local_output_buf,
Rick Ho's avatar
Rick Ho committed
63
    size_t out_feat, size_t n_expert, size_t world_size,
Rick Ho's avatar
Rick Ho committed
64
65
66
    CudaStreamManager* smgr) {
    long send_ptr = 0;
    /* TODO: may save for backward */
Rick Ho's avatar
Rick Ho committed
67
    long *expert_ptr = new long[n_expert * world_size];
Rick Ho's avatar
Rick Ho committed
68
    expert_ptr[0] = 0;
Rick Ho's avatar
Rick Ho committed
69
    for (size_t i = 1; i < n_expert * world_size; ++i) {
Rick Ho's avatar
Rick Ho committed
70
71
72
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }

Rick Ho's avatar
Rick Ho committed
73
    for (size_t i = 0; i < n_expert; ++i) {
Rick Ho's avatar
Rick Ho committed
74
75
        NCCL_SAFE_CALL(ncclGroupStart());
        for (size_t j = 0; j < world_size; ++j) {
Rick Ho's avatar
Rick Ho committed
76
            int idx = i + j * n_expert;
Rick Ho's avatar
Rick Ho committed
77
78
79
80
81
82
83
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        output_buf + send_ptr * out_feat,
                        global_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
Rick Ho's avatar
Rick Ho committed
84
                        smgr->torchStream()));
Rick Ho's avatar
Rick Ho committed
85
86
87
88
                send_ptr += global_expert_count[idx];
            }
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
Rick Ho's avatar
Rick Ho committed
89
                        local_output_buf + expert_ptr[idx] * out_feat,
Rick Ho's avatar
Rick Ho committed
90
                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
Rick Ho's avatar
Rick Ho committed
91
                        ncclChar,
Rick Ho's avatar
Rick Ho committed
92
93
                        j,
                        smgr->ncclcomm,
Rick Ho's avatar
Rick Ho committed
94
                        smgr->torchStream()));
Rick Ho's avatar
Rick Ho committed
95
96
97
98
99
100
101
102
103
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
}


#endif  // FMOE_USE_NCCL