smart_schedule.cpp 3.43 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
6
7
8
#ifdef FMOE_USE_NCCL

#include <cstdlib>
#include <vector>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include "smart_schedule.h"
9
#include "status.h"
Rick Ho's avatar
Rick Ho committed
10
11
12

// Pipeline granularity handed to the fused forward/backward implementations.
// -1 is the "not resolved yet" sentinel: _smart_sch_forward replaces it on its
// first call with the value of the FMOE_FASTER_GROUP_SIZE environment variable,
// or with 4 when that variable is not set.
long pipeline_gran = -1;

13
14
15
16
17
18
19
20
21
// Global flag read by isSmartSchEnabled() and written by setSmartSchEnabled().
// Starts at 0; _smart_sch_forward sets it to 1 when it resolves pipeline_gran.
int smart_sch_enabled = 0;

/// Returns the current value of the global `smart_sch_enabled` flag.
int isSmartSchEnabled() { return smart_sch_enabled; }
/// Stores `s` into the global `smart_sch_enabled` flag.
void setSmartSchEnabled(int s) { smart_sch_enabled = s; }

Rick Ho's avatar
Rick Ho committed
22
/**
 * Forward entry point of the smart-scheduled fused MoE path.
 *
 * Allocates the zero-filled work buffers and forwards everything to
 * fmoe_cuda_fused_forward_impl, dispatched on the scalar type of `input_buf`.
 *
 * @param input_buf            2-D tensor; its second dimension is used as
 *                             d_model and its first as the row count of the
 *                             returned output buffer.
 * @param local_expert_count   Accessed as `long`; size(0) / n_workers is used
 *                             as the number of experts.
 * @param global_expert_count  Accessed as `long`.
 * @param stored_models        Accessed as `bool`.
 * @param global_batch_size    Row count of the two global scratch buffers.
 * @param n_workers            Number of workers; must be positive.
 * @param forward_fn           Python callable handed through unchanged to
 *                             fmoe_cuda_fused_forward_impl.
 * @return {output_buf, global_input_buf}.
 * @throws c10::Error (via TORCH_CHECK) when n_workers is not positive.
 */
std::vector<torch::Tensor> _smart_sch_forward(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long global_batch_size,
        long n_workers,
        py::function forward_fn) {
    // Reject n_workers <= 0 up front: it is used as a divisor below, and an
    // integer division by zero would kill the process instead of raising.
    TORCH_CHECK(n_workers > 0,
            "_smart_sch_forward: n_workers must be positive, got ", n_workers);

    if (pipeline_gran == -1) {
        // First call: resolve the pipeline granularity exactly once.
        constexpr long kDefaultPipelineGran = 4;
        long gran = kDefaultPipelineGran;
        if (const char* env = std::getenv("FMOE_FASTER_GROUP_SIZE")) {
            // strtol instead of atoi: atoi is undefined on out-of-range input
            // and cannot distinguish "0" from unparsable text.
            char* end = nullptr;
            const long parsed = std::strtol(env, &end, 10);
            // Only accept a positive value. A non-positive granularity makes
            // no sense as a group size, and -1 in particular would leave the
            // sentinel in place so this branch would run on every call.
            if (end != env && parsed > 0) {
                gran = parsed;
            }
        }
        pipeline_gran = gran;
        setSmartSchEnabled(1);
    }

    auto smgr = getCudaStreamManager(input_buf.device().index());
    int rank;
    NCCL_SAFE_CALL(ncclCommUserRank(smgr->ncclcomm, &rank));

    const auto num_expert = local_expert_count.size(0) / n_workers;
    const auto d_model = input_buf.size(1);

    // TODO: maybe empty is faster
    auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
    auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});

    auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "fmoe_cuda_smart_sch_forward", ([&] {
        fmoe_cuda_fused_forward_impl(
            forward_fn,
            input_buf.device(),

            input_buf.data_ptr<scalar_t>(),
            global_input_buf.data_ptr<scalar_t>(),
            global_output_buf.data_ptr<scalar_t>(),
            output_buf.data_ptr<scalar_t>(),

            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
            d_model, num_expert, rank, n_workers,
            pipeline_gran, smgr);
    }));
    return {output_buf, global_input_buf};
}

Rick Ho's avatar
Rick Ho committed
73
/**
 * Backward entry point of the smart-scheduled fused MoE path.
 *
 * Allocates the zero-filled gradient buffers and forwards everything to
 * fmoe_cuda_fused_backward_impl, dispatched on the scalar type of `grad_out`.
 *
 * @param grad_out             2-D tensor; its second dimension is used as
 *                             d_model.
 * @param local_expert_count   Accessed as `long`; size(0) / n_workers is used
 *                             as the number of experts.
 * @param global_expert_count  Accessed as `long`.
 * @param stored_models        Accessed as `bool`.
 * @param buf_batch_size       Row count of the returned gradient buffer.
 * @param global_batch_size    Row count of the two global scratch buffers.
 * @param n_workers            Number of workers; must be positive.
 * @param backward_fn          Python callable handed through unchanged to
 *                             fmoe_cuda_fused_backward_impl.
 * @return grad_in, a {buf_batch_size, d_model} tensor.
 * @throws c10::Error (via TORCH_CHECK) when n_workers is not positive.
 *
 * NOTE(review): pipeline_gran is only assigned inside _smart_sch_forward in
 * this file; if this function runs before any forward call it is still -1.
 */
torch::Tensor _smart_sch_backward(
        torch::Tensor grad_out,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long buf_batch_size,
        long global_batch_size,
        long n_workers,
        py::function backward_fn) {
    // Reject n_workers <= 0 up front: it is used as a divisor below, and an
    // integer division by zero would kill the process instead of raising.
    TORCH_CHECK(n_workers > 0,
            "_smart_sch_backward: n_workers must be positive, got ", n_workers);

    const auto num_expert = local_expert_count.size(0) / n_workers;
    auto smgr = getCudaStreamManager(grad_out.device().index());
    int rank;
    // Check the NCCL status exactly like the forward pass does; otherwise a
    // failed query would leave `rank` uninitialised and silently propagate it.
    NCCL_SAFE_CALL(ncclCommUserRank(smgr->ncclcomm, &rank));

    const auto d_model = grad_out.size(1);
    auto global_grad_out = grad_out.new_zeros({global_batch_size, d_model});
    auto global_grad_in = grad_out.new_zeros({global_batch_size, d_model});
    auto grad_in = grad_out.new_zeros({buf_batch_size, d_model});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_out.scalar_type(),
            "fmoe_cuda_smartsch_backward", ([&] {
        fmoe_cuda_fused_backward_impl(
            backward_fn,
            grad_out.device(),

            grad_out.data_ptr<scalar_t>(),
            global_grad_out.data_ptr<scalar_t>(),
            global_grad_in.data_ptr<scalar_t>(),
            grad_in.data_ptr<scalar_t>(),

            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            stored_models.data_ptr<bool>(),
            d_model, num_expert, rank, n_workers,
            pipeline_gran, smgr);
    }));
    return grad_in;
}
#endif