smart_schedule.cpp
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifdef FMOE_USE_NCCL

#include <cstdlib>
#include <vector>
#include <torch/extension.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>

#include "../../hip/fastermoe/smart_schedule.h"
#include "../../hip/fastermoe/status.h"

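// Pipeline granularity (group size) used by the FasterMoE smart schedule;
// -1 means uninitialized. Set lazily in _smart_sch_forward from the
// FMOE_FASTER_GROUP_SIZE environment variable (default 4).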
long pipeline_gran = -1;

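// Set to 1 once the smart schedule has been initialized; exposed to the
// Python side through the getter/setter below.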
int smart_sch_enabled = 0;

int isSmartSchEnabled() {
    return smart_sch_enabled;
}
void setSmartSchEnabled(int s) {
    smart_sch_enabled = s;
}


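// Map an ATen scalar type to the matching NCCL data type; types without a
// direct NCCL counterpart fall back to ncclChar.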
inline ncclDataType_t getNcclDataType(at::ScalarType t) {
    switch (t) {
        case at::kChar: return ncclInt8;
        case at::kByte: return ncclUint8;
        case at::kFloat: return ncclFloat;
        case at::kDouble: return ncclDouble;
        case at::kInt: return ncclInt32;
        case at::kLong: return ncclInt64;
        case at::kHalf: return ncclHalf;
        case at::kBool: return ncclUint8;
#if defined(ENABLE_NCCL_BF16_DATATYPE)
        case at::kBFloat16: return ncclBfloat16;
#endif
        default: return ncclChar;
    }
}


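// Sum-reduce an expert's flattened gradient onto the worker identified by
// `root` over NCCL.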
void _reduce_grad(
        torch::Tensor t,
        long root,
        long expert_size) {
    auto smgr = getCudaStreamManager(t.device().index());

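    // Record an event on the Torch compute stream and make NCCL stream 0
    // wait on it, so the reduction only starts once the gradient is ready.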
    hipEvent_t evt_stash;
    hipEventCreate(&evt_stash);
    hipEventRecord(evt_stash, smgr->torchStream());
    FMOE_SWE(smgr->stream(0), evt_stash);
    hipEventDestroy(evt_stash);

    auto dtype = getNcclDataType(t.scalar_type());
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            t.scalar_type(), "fmoe_cuda_reduce_grad", ([&] {
            void* buf = (void*)t.data_ptr<scalar_t>();
            NCCL_SAFE_CALL(ncclReduce(buf, buf, expert_size,
                        dtype,
                        ncclSum, root,
                        smgr->ncclcomm, smgr->stream(0)));
        })
    );
}


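// Forward pass with the FasterMoE smart schedule. Exchanges input tokens
// (and, for shadowed experts, their parameters) between workers and runs the
// fused, pipelined expert computation. Returns the local output buffer
// together with the gathered global input buffer.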
std::vector<torch::Tensor> _smart_sch_forward(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long global_batch_size,
        long expert_size,
        long n_workers,
        py::function forward_fn,
        py::function get_param_fn,
        py::function stash_fn,
        py::function pop_fn) {
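    // Lazily pick the pipeline group size from FMOE_FASTER_GROUP_SIZE
    // (default 4) and mark smart scheduling as enabled.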
    if (pipeline_gran == -1) {
        char* p = getenv("FMOE_FASTER_GROUP_SIZE");
        if (p) {
            pipeline_gran = atoi(p);
        } else {
            pipeline_gran = 4;
        }
        setSmartSchEnabled(1);
    }

    auto smgr = getCudaStreamManager(input_buf.device().index());
    int rank;
    NCCL_SAFE_CALL(ncclCommUserRank(smgr->ncclcomm, &rank));

    const auto num_expert = local_expert_count.size(0) / n_workers;
    const auto d_model = input_buf.size(1);

    // TODO: maybe empty is faster
    auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});

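    // For every expert marked in stored_models, allocate a flat buffer of
    // expert_size elements; only the worker that owns the expert fills it via
    // get_param_fn. The buffers are handed to the fused forward below.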
    std::vector<torch::Tensor> params;
    auto stored_models_ = stored_models.data_ptr<bool>();
    for (long i = 0; i < num_expert * n_workers; ++i) {
        if (stored_models_[i]) {
            torch::Tensor t = input_buf.new_empty({expert_size});
            if (i / num_expert == rank) {
                get_param_fn(t, i % num_expert);
            }
            params.push_back(t);
        }
    }

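    // Dispatch on the input dtype and run the fused, pipelined forward
    // implementation.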
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
            input_buf.scalar_type(), "fmoe_cuda_smart_sch_forward", ([&] {
        fmoe_cuda_fused_forward_impl(
            forward_fn,
            stash_fn,
            pop_fn,
            input_buf.device(),
            params,

            input_buf.data_ptr<scalar_t>(),
            global_input_buf.data_ptr<scalar_t>(),
            global_output_buf.data_ptr<scalar_t>(),
            output_buf.data_ptr<scalar_t>(),

            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
            d_model, num_expert, rank, n_workers, expert_size,
            pipeline_gran, smgr);
    }));
    return {output_buf, global_input_buf};
}

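// Backward counterpart of _smart_sch_forward: exchanges gradients between
// workers, runs the fused, pipelined backward, and returns the gradient with
// respect to the local input buffer.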
torch::Tensor _smart_sch_backward(
        torch::Tensor grad_out,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long buf_batch_size,
        long global_batch_size,
        long n_workers,
        py::function backward_fn,
        py::function stash_fn,
        py::function pop_fn,
        py::function collect_fn,
        py::function set_grad_fn) {
    const auto num_expert = local_expert_count.size(0) / n_workers;
    auto smgr = getCudaStreamManager(grad_out.device().index());
    int rank;
    ncclCommUserRank(smgr->ncclcomm, &rank);
    const auto d_model = grad_out.size(1);
    auto global_grad_out = grad_out.new_zeros({global_batch_size, d_model});
    auto global_grad_in = grad_out.new_zeros({global_batch_size, d_model});
    auto grad_in = grad_out.new_zeros({buf_batch_size, d_model});

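    // Note: unlike the forward pass, this dispatch covers float, double and
    // half only (no bfloat16).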
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_out.scalar_type(),
            "fmoe_cuda_smartsch_backward", ([&] {
        fmoe_cuda_fused_backward_impl(
            backward_fn,
            stash_fn,
            pop_fn,
            collect_fn,
            set_grad_fn,
            grad_out.device(),

            grad_out.data_ptr<scalar_t>(),
            global_grad_out.data_ptr<scalar_t>(),
            global_grad_in.data_ptr<scalar_t>(),
            grad_in.data_ptr<scalar_t>(),

            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
            d_model, num_expert, rank, n_workers,
            pipeline_gran, smgr);
    }));
    return grad_in;
}
#endif