// smart_schedule.cpp

#ifdef FMOE_USE_NCCL

#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>  // cudaEvent_* / stream APIs used below
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include "smart_schedule.h"
#include "status.h"

// Number of workers per pipeline group in the smart schedule. Initialized
// lazily from the FMOE_FASTER_GROUP_SIZE environment variable (default 4)
// on the first forward pass; -1 means "not yet initialized".
long pipeline_gran = -1;

// Whether the smart schedule has been activated; exposed through the two
// accessors below so other modules can query and set it.
int smart_sch_enabled = 0;

int isSmartSchEnabled() {
    return smart_sch_enabled;
}
void setSmartSchEnabled(int s) {
    smart_sch_enabled = s;
}


// Map an ATen scalar type to the corresponding NCCL data type. Booleans are
// transferred as ncclUint8; unhandled types fall back to ncclChar.
inline ncclDataType_t getNcclDataType(at::ScalarType t) {
    switch (t) {
        case at::kChar: return ncclInt8;
        case at::kByte: return ncclUint8;
        case at::kFloat: return ncclFloat;
        case at::kDouble: return ncclDouble;
        case at::kInt: return ncclInt32;
        case at::kLong: return ncclInt64;
        case at::kHalf: return ncclHalf;
        case at::kBool: return ncclUint8;
#if defined(ENABLE_NCCL_BF16_DATATYPE)
        case at::kBFloat16: return ncclBfloat16;
#endif
        default: return ncclChar;
    }
}


// Reduce the gradient of a shadowed expert's parameters onto its owner rank
// with ncclReduce, running on communication stream 0.
void _reduce_grad(
        torch::Tensor t,
        long root,
        long expert_size) {
    auto smgr = getCudaStreamManager(t.device().index());

    // Make the comm stream wait until the gradient buffer has been fully
    // produced on the current PyTorch compute stream.
    cudaEvent_t evt_stash;
    cudaEventCreate(&evt_stash);
    cudaEventRecord(evt_stash, smgr->torchStream());
    FMOE_SWE(smgr->stream(0), evt_stash);
    // Destroying the event here is safe: the stream-wait is already
    // enqueued, and CUDA releases the event only after it completes.
    cudaEventDestroy(evt_stash);

    auto dtype = getNcclDataType(t.scalar_type());
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            t.scalar_type(), "fmoe_cuda_reduce_grad", ([&] {
            void* buf = (void*)t.data_ptr<scalar_t>();
            NCCL_SAFE_CALL(ncclReduce(buf, buf, expert_size,
                        dtype,
                        ncclSum, root,
                        smgr->ncclcomm, smgr->stream(0)));
    }));
}
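
// A hypothetical call site, for illustration only: reduce an expert's
// flattened, contiguous gradient buffer `grad` onto rank 0:
//
//   _reduce_grad(grad, /*root=*/0, /*expert_size=*/grad.numel());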


/*
 * Forward pass of the smart schedule: scatter input rows to the workers that
 * own their experts, run the local experts through `forward_fn`, and gather
 * the results back, overlapping NCCL communication with expert computation
 * in groups of `pipeline_gran` workers. Parameters of shadowed experts
 * (marked in `stored_models`) are fetched so those experts can be evaluated
 * locally.
 */
std::vector<torch::Tensor> _smart_sch_forward(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long global_batch_size,
        long expert_size,
        long n_workers,
        py::function forward_fn,
        py::function get_param_fn,
        py::function stash_fn,
        py::function pop_fn) {
    // On the first call, read the pipeline granularity from the environment
    // and mark smart scheduling as enabled.
    if (pipeline_gran == -1) {
        char* p = getenv("FMOE_FASTER_GROUP_SIZE");
        if (p) {
            pipeline_gran = atoi(p);
        } else {
            pipeline_gran = 4;
        }
        setSmartSchEnabled(1);
    }

    auto smgr = getCudaStreamManager(input_buf.device().index());
    int rank;
    NCCL_SAFE_CALL(ncclCommUserRank(smgr->ncclcomm, &rank));

    const auto num_expert = local_expert_count.size(0) / n_workers;
    const auto d_model = input_buf.size(1);
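    // Note: local_expert_count has n_workers * num_expert entries, one per
    // (worker, expert) pair.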

    // TODO: maybe empty is faster
    auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});
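    // The three buffers above are zero-filled and then overwritten by the
    // exchange; per the TODO, uninitialized (empty) allocation may suffice.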

    // Collect parameter buffers for shadowed experts: the owner rank fills
    // its expert's parameters via get_param_fn; every other rank allocates
    // an equally-sized buffer to receive them.
    std::vector<torch::Tensor> params;
    auto stored_models_ = stored_models.data_ptr<bool>();
    for (long i = 0; i < num_expert * n_workers; ++i) {
        if (stored_models_[i]) {
            torch::Tensor t = input_buf.new_empty({expert_size});
            if (i / num_expert == rank) {
                get_param_fn(t, i % num_expert);
            }
            params.push_back(t);
        }
    }

    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            input_buf.scalar_type(), "fmoe_cuda_smart_sch_forward", ([&] {
        fmoe_cuda_fused_forward_impl(
            forward_fn,
            stash_fn,
            pop_fn,
            input_buf.device(),
            params,

            input_buf.data_ptr<scalar_t>(),
            global_input_buf.data_ptr<scalar_t>(),
            global_output_buf.data_ptr<scalar_t>(),
            output_buf.data_ptr<scalar_t>(),

            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
            d_model, num_expert, rank, n_workers, expert_size,
            pipeline_gran, smgr);
    }));
    return {output_buf, global_input_buf};
}

/*
 * Backward pass of the smart schedule, mirroring _smart_sch_forward:
 * gradients w.r.t. the gathered outputs are scattered to the expert owners,
 * `backward_fn` runs the local expert backward, and the input gradients are
 * gathered back; shadowed experts' parameter gradients are handled through
 * the collect_fn / set_grad_fn callbacks.
 */
torch::Tensor _smart_sch_backward(
        torch::Tensor grad_out,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long buf_batch_size,
        long global_batch_size,
        long n_workers,
        py::function backward_fn,
        py::function stash_fn,
        py::function pop_fn,
        py::function collect_fn,
        py::function set_grad_fn) {
    const auto num_expert = local_expert_count.size(0) / n_workers;
    auto smgr = getCudaStreamManager(grad_out.device().index());
    int rank;
    // Wrapped in NCCL_SAFE_CALL for consistency with the forward path.
    NCCL_SAFE_CALL(ncclCommUserRank(smgr->ncclcomm, &rank));
    const auto d_model = grad_out.size(1);
    auto global_grad_out = grad_out.new_zeros({global_batch_size, d_model});
    auto global_grad_in = grad_out.new_zeros({global_batch_size, d_model});
    auto grad_in = grad_out.new_zeros({buf_batch_size, d_model});
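    // Zero-filled buffers mirroring the forward pass: gathered gradients
    // w.r.t. expert outputs/inputs, plus the local input gradient.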

    // Dispatch over the same scalar types as the forward pass so that
    // BFloat16 gradients are supported as well.
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            grad_out.scalar_type(), "fmoe_cuda_smartsch_backward", ([&] {
        fmoe_cuda_fused_backward_impl(
            backward_fn,
            stash_fn,
            pop_fn,
            collect_fn,
            set_grad_fn,
            grad_out.device(),

            grad_out.data_ptr<scalar_t>(),
            global_grad_out.data_ptr<scalar_t>(),
            global_grad_in.data_ptr<scalar_t>(),
            grad_in.data_ptr<scalar_t>(),

            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
            d_model, num_expert, rank, n_workers,
            pipeline_gran, smgr);
    }));
    return grad_in;
}
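
// A hypothetical pybind11 registration for these entry points, shown only
// for illustration; the actual module definition lives in a separate
// binding file:
//
//   m.def("smart_sch_forward", &_smart_sch_forward);
//   m.def("smart_sch_backward", &_smart_sch_backward);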
#endif  // FMOE_USE_NCCL