smart_schedule.cpp 4.18 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#ifdef FMOE_USE_NCCL

#include <cstdlib>
#include <vector>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include "smart_schedule.h"

long pipeline_gran = -1;

torch::Tensor _smart_sch_forward(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long global_batch_size,
        long n_workers,
        py::function forward_fn) {

    if (pipeline_gran == -1) {
        char* p = getenv("FMOE_FASTER_GROUP_SIZE");
        if (p) {
            pipeline_gran = atoi(p);
        } else {
            pipeline_gran = 4;
        }
    }

    auto smgr = getCudaStreamManager(input_buf.device().index());
    int rank;
    NCCL_SAFE_CALL(ncclCommUserRank(smgr->ncclcomm, &rank));

    const auto num_expert = local_expert_count.size(0) / n_workers;
    const auto d_model = input_buf.size(1);

    auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
    
    auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), 
            "fmoe_cuda_fused_forward", ([&] {
        fmoe_cuda_fused_forward_impl(
            forward_fn,
            input_buf.device(),

            input_buf.data_ptr<scalar_t>(),
            global_input_buf.data_ptr<scalar_t>(),
            global_output_buf.data_ptr<scalar_t>(),
            output_buf.data_ptr<scalar_t>(),

            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
            d_model, num_expert, rank, n_workers,
            pipeline_gran, smgr);
    }));
    return output_buf;
}

/*
std::vector<torch::Tensor> _fused_backward(
        torch::Tensor input_buf,
        std::vector<std::vector<std::vector<torch::Tensor>>> params,
        torch::Tensor middle_buf,
        torch::Tensor output_buf,
        torch::Tensor grad_out,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor inp,
        torch::Tensor stored_models,
        
        long global_batch_size,
        long buf_batch_size,
        long n_workers, bool has_bias) {
    const auto num_expert = local_expert_count.size(0) / n_workers;
    
    auto smgr = getCudaStreamManager(input_buf.device().index());
    int rank;
    ncclCommUserRank(smgr->ncclcomm, &rank);
    
    const auto d_hidden = params[rank][0][0].size(1);
    const auto d_model = params[rank][0][0].size(2);


    auto global_grad_out = input_buf.new_zeros({global_batch_size, d_model});
    auto grad_middle = input_buf.new_zeros({global_batch_size, d_hidden});
    auto global_grad_in = input_buf.new_zeros({global_batch_size, d_model});
    
    auto grad_in = input_buf.new_zeros({buf_batch_size, d_model});
    
    for (auto node : params)
        for (auto expert : node)
            for (int i = 0; i < expert.size(); i++) {
                // create the respective gradient of each tensor
                CHECK_INPUT(expert[i]);
                if (expert[i].grad().defined()) {
                    CHECK_INPUT(expert[i].grad());
                    continue;
                }

                expert[i].mutable_grad() = input_buf.new_zeros(expert[i].sizes());
            }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), 
            "fmoe_cuda_fused_backward", ([&] {
        fmoe_cuda_fused_backward_impl(
            input_buf.data_ptr<scalar_t>(),
            inp.data_ptr<scalar_t>(),
            params,

            middle_buf.data_ptr<scalar_t>(),
            output_buf.data_ptr<scalar_t>(),
            grad_out.data_ptr<scalar_t>(),

            global_grad_out.data_ptr<scalar_t>(),
            global_grad_in.data_ptr<scalar_t>(),

            grad_middle.data_ptr<scalar_t>(),
            grad_in.data_ptr<scalar_t>(),

            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
            d_model, d_hidden, num_expert, rank, n_workers, has_bias,
            pipeline_gran, smgr);
    }));
    return {grad_in,};
}
*/
#endif