// smart_schedule.cpp -- smart-scheduled MoE forward/backward dispatch (NCCL)
#ifdef FMOE_USE_NCCL

#include <cstdlib>
#include <vector>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include "smart_schedule.h"
9
#include "status.h"

// Pipeline granularity (group size) for smart scheduling.
// -1 means "uninitialized"; it is set lazily on the first call to
// _smart_sch_forward from the FMOE_FASTER_GROUP_SIZE environment
// variable, defaulting to 4.
long pipeline_gran = -1;

// Global toggle recording whether the smart-scheduling path is active.
int smart_sch_enabled = 0;

// Returns the current smart-scheduling flag (non-zero means enabled).
int isSmartSchEnabled() { return smart_sch_enabled; }

// Sets the smart-scheduling flag (set to 1 once pipeline_gran is chosen).
void setSmartSchEnabled(int s) { smart_sch_enabled = s; }


// Map a torch scalar type to the corresponding NCCL data type.
//
// Fix: the original silently returned ncclChar for any unsupported type,
// which would make NCCL transfer the wrong number of bytes per element and
// corrupt the communicated buffer. Unsupported types now fail loudly.
inline ncclDataType_t getNcclDataType(at::ScalarType t) {
    switch (t) {
        case at::kChar: return ncclInt8;
        case at::kByte: return ncclUint8;
        case at::kFloat: return ncclFloat;
        case at::kDouble: return ncclDouble;
        case at::kInt: return ncclInt32;
        case at::kLong: return ncclInt64;
        case at::kHalf: return ncclHalf;
        // NCCL has no boolean type; transport bools as raw bytes.
        case at::kBool: return ncclUint8;
#if defined(ENABLE_NCCL_BF16_DATATYPE)
        case at::kBFloat16: return ncclBfloat16;
#endif
        default:
            TORCH_CHECK(false, "fastmoe: unsupported scalar type for NCCL: ", t);
            return ncclChar;  // unreachable; silences missing-return warnings
    }
}


/*
 * Sum-reduce an expert's (flat) gradient buffer across all workers onto the
 * owning worker.
 *
 * t           flat gradient tensor on the current CUDA device
 * root        NCCL rank that receives the summed gradient
 * expert_size number of elements of t to reduce
 *
 * The reduction is issued on the stream manager's communication stream
 * (smgr->stream(0)), not on torch's current stream; an event hand-off makes
 * the NCCL reduce wait for all work already queued on the torch stream, so
 * the gradient is fully written before being reduced.
 */
void _reduce_grad(
        torch::Tensor t,
        long root,
        long expert_size) {
    auto smgr = getCudaStreamManager(t.device().index());

    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
    // Record progress of the torch stream and make the comm stream wait on
    // it. FMOE_SWE presumably wraps cudaStreamWaitEvent (defined in
    // smart_schedule.h) -- TODO confirm.
    cudaEvent_t evt_stash;
    cudaEventCreate(&evt_stash);
    cudaEventRecord(evt_stash, torch_stream);
    FMOE_SWE(smgr->stream(0), evt_stash);
    // Destroying immediately is safe: the stream-wait is already enqueued.
    cudaEventDestroy(evt_stash);

    auto dtype = getNcclDataType(t.scalar_type());
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            t.scalar_type(), "fmoe_cuda_reduce_grad", ([&] {
            // In-place reduce: every rank contributes its local gradient and
            // the sum lands in the same buffer on `root`.
            void* buf = (void*)t.data_ptr<scalar_t>();
            NCCL_SAFE_CALL(ncclReduce(buf, buf, expert_size,
                        dtype,
                        ncclSum, root,
                        smgr->ncclcomm, smgr->stream(0)));
        })
    );
}


/*
 * Smart-scheduled MoE forward pass.
 *
 * input_buf           (local_batch, d_model) tokens to dispatch to experts
 * local_expert_count  per-(worker, expert) send counts; length is
 *                     num_expert * n_workers
 * global_expert_count per-(worker, expert) receive counts
 * stored_models       bool flags marking experts whose parameters are
 *                     replicated ("shadowed") on other workers
 * global_batch_size   number of rows received after the global exchange
 * expert_size         flattened parameter count of one expert
 * n_workers           number of NCCL ranks participating
 * forward_fn / get_param_fn / stash_fn / pop_fn
 *                     Python callbacks used by the fused pipeline to run the
 *                     expert, export its parameters, and stash/restore state
 *
 * Returns {output_buf, global_input_buf}; the latter is kept for backward.
 */
std::vector<torch::Tensor> _smart_sch_forward(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long global_batch_size,
        long expert_size,
        long n_workers,
        py::function forward_fn,
        py::function get_param_fn,
        py::function stash_fn,
        py::function pop_fn) {
    // Lazily pick the pipeline granularity from the environment on first
    // use, and mark smart scheduling as enabled.
    if (pipeline_gran == -1) {
        char* p = getenv("FMOE_FASTER_GROUP_SIZE");
        if (p) {
            pipeline_gran = atoi(p);
        } else {
            pipeline_gran = 4;
        }
        setSmartSchEnabled(1);
    }

    auto smgr = getCudaStreamManager(input_buf.device().index());
    int rank;
    NCCL_SAFE_CALL(ncclCommUserRank(smgr->ncclcomm, &rank));

    const auto num_expert = local_expert_count.size(0) / n_workers;
    const auto d_model = input_buf.size(1);

    // TODO: maybe empty is faster
    auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});

    // Collect one flat parameter tensor per shadowed expert. Only the owner
    // rank fills its tensor (via get_param_fn); the others allocate an empty
    // buffer that the fused pipeline presumably fills by broadcast -- the
    // transfer happens inside fmoe_cuda_fused_forward_impl.
    std::vector<torch::Tensor> params;
    auto stored_models_ = stored_models.data_ptr<bool>();
    for (long i = 0; i < num_expert * n_workers; ++i) {
        if (stored_models_[i]) {
            torch::Tensor t = input_buf.new_empty({expert_size});
            if (i / num_expert == rank) {
                get_param_fn(t, i % num_expert);
            }
            params.push_back(t);
        }
    }

    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            input_buf.scalar_type(), "fmoe_cuda_smart_sch_forward", ([&] {
        fmoe_cuda_fused_forward_impl(
            forward_fn,
            stash_fn,
            pop_fn,
            input_buf.device(),
            params,

            input_buf.data_ptr<scalar_t>(),
            global_input_buf.data_ptr<scalar_t>(),
            global_output_buf.data_ptr<scalar_t>(),
            output_buf.data_ptr<scalar_t>(),

            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
            d_model, num_expert, rank, n_workers, expert_size,
            pipeline_gran, smgr);
    }));
    return {output_buf, global_input_buf};
}

/*
 * Smart-scheduled MoE backward pass; mirror of _smart_sch_forward.
 *
 * grad_out            (global-order) gradient w.r.t. the forward output
 * local_expert_count / global_expert_count / stored_models
 *                     same exchange metadata as in the forward pass
 * buf_batch_size      number of local rows (size of the returned grad_in)
 * global_batch_size   number of rows in the global exchange buffers
 * n_workers           number of NCCL ranks participating
 * backward_fn / stash_fn / pop_fn / collect_fn / set_grad_fn
 *                     Python callbacks driving the fused backward pipeline
 *
 * Returns grad_in, the gradient w.r.t. the forward input_buf.
 *
 * Fixes vs. the previous revision:
 *  - ncclCommUserRank is now wrapped in NCCL_SAFE_CALL, matching
 *    _smart_sch_forward, so a failed rank query no longer goes unnoticed.
 *  - Dispatch now includes BFloat16 (AT_DISPATCH_FLOATING_TYPES_AND2),
 *    matching the forward dispatch; previously a bf16 forward would fail in
 *    backward with a "not implemented" dispatch error.
 */
torch::Tensor _smart_sch_backward(
        torch::Tensor grad_out,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long buf_batch_size,
        long global_batch_size,
        long n_workers,
        py::function backward_fn,
        py::function stash_fn,
        py::function pop_fn,
        py::function collect_fn,
        py::function set_grad_fn) {
    const auto num_expert = local_expert_count.size(0) / n_workers;
    auto smgr = getCudaStreamManager(grad_out.device().index());
    int rank;
    NCCL_SAFE_CALL(ncclCommUserRank(smgr->ncclcomm, &rank));
    const auto d_model = grad_out.size(1);
    // Zero-initialized scratch buffers for the global exchange and result.
    auto global_grad_out = grad_out.new_zeros({global_batch_size, d_model});
    auto global_grad_in = grad_out.new_zeros({global_batch_size, d_model});
    auto grad_in = grad_out.new_zeros({buf_batch_size, d_model});

    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            grad_out.scalar_type(), "fmoe_cuda_smartsch_backward", ([&] {
        fmoe_cuda_fused_backward_impl(
            backward_fn,
            stash_fn,
            pop_fn,
            collect_fn,
            set_grad_fn,
            grad_out.device(),

            grad_out.data_ptr<scalar_t>(),
            global_grad_out.data_ptr<scalar_t>(),
            global_grad_in.data_ptr<scalar_t>(),
            grad_in.data_ptr<scalar_t>(),

            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
            d_model, num_expert, rank, n_workers,
            pipeline_gran, smgr);
    }));
    return grad_in;
}
#endif