fmoe_cuda.cpp 4.41 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
#include <iostream>
#include <vector>
Rick Ho's avatar
Rick Ho committed
3
#include <torch/csrc/autograd/custom_function.h>
Rick Ho's avatar
Rick Ho committed
4
5
#include <torch/extension.h>

Rick Ho's avatar
Rick Ho committed
6
// global_exchange
//
// Cross-worker communication entry points. They are declared (and bound in
// PYBIND11_MODULE below) only when the extension is built with
// FMOE_USE_NCCL; this file contains no definitions for them, so they are
// expected to be defined in another source file of the extension.
#ifdef FMOE_USE_NCCL

// Select the c10d header location by PyTorch version: 1.13 and newer are
// included from torch/csrc/distributed/c10d/, older versions from <c10d/...>.
// These headers are needed only for the process-group type in _ensure_nccl.
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
        (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13))
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#else
#include <c10d/ProcessGroupNCCL.hpp>
#endif

// Bound as `expert_exchange`.
// NOTE(review): presumably exchanges per-expert token counts among
// n_workers workers (n_expert experts each) and returns the global counts;
// confirm against the definition.
torch::Tensor _expert_exchange(
        torch::Tensor local_expert_count,
        long n_expert, long n_workers);

// Bound as `global_scatter`. Returns a new tensor; argument tensors are
// passed by value (shared storage handles).
torch::Tensor _global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers);

// Bound as `global_gather`; mirrors the signature of _global_scatter.
torch::Tensor _global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers);

// Bound as `ensure_nccl`. From PyTorch 2 on, the process group is taken
// as the generic c10d::ProcessGroup; earlier versions take
// c10d::ProcessGroupNCCL directly.
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t);
#else
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
#endif  // TORCH_VERSION

#endif  // FMOE_USE_NCCL

// local_exchange
39
void _assign_pos(
40
41
42
        torch::Tensor cum_count,
        torch::Tensor gate,
        torch::Tensor pos);
43
44
45
void _expert_count(
        torch::Tensor gate_idx,
        torch::Tensor expert_count);
Rick Ho's avatar
Rick Ho committed
46

Rick Ho's avatar
Rick Ho committed
47
// parallel_linear
48
torch::Tensor _linear_forward(
Rick Ho's avatar
Rick Ho committed
49
        torch::Tensor input_buf,
50
        torch::Tensor expert_count,
Rick Ho's avatar
Rick Ho committed
51
        torch::Tensor weight,
52
53
        at::optional<torch::Tensor> bias
        );
Rick Ho's avatar
Rick Ho committed
54
std::vector<torch::Tensor> _linear_backward(
55
56
57
58
59
60
        torch::Tensor grad_output_buf,
        torch::Tensor input_buf,
        torch::Tensor expert_count,
        torch::Tensor weight,
        at::optional<torch::Tensor> bias
        );
Rick Ho's avatar
Rick Ho committed
61

Rick Ho's avatar
Rick Ho committed
62
// balancing
63
torch::Tensor _limit_by_capacity(
Rick Ho's avatar
Rick Ho committed
64
65
        torch::Tensor expert_count, torch::Tensor capacity,
        long n_expert, long n_experts);
66
torch::Tensor _prune_gate_by_capacity(
Rick Ho's avatar
Rick Ho committed
67
68
        torch::Tensor gate_idx, torch::Tensor expert_count,
        long n_expert, long n_worker);
Rick Ho's avatar
Rick Ho committed
69
70
71
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity_tensor,
        long n_expert, long n_worker, long bias);
Rick Ho's avatar
Rick Ho committed
72

Rick Ho's avatar
Rick Ho committed
73
// smart scheduling
Rick Ho's avatar
Rick Ho committed
74
std::vector<torch::Tensor> _smart_sch_forward(
Rick Ho's avatar
Rick Ho committed
75
76
77
78
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
Rick Ho's avatar
Rick Ho committed
79
80
81
82
83
84
85
        long global_batch_size,
        long expert_size,
        long n_workers,
        py::function forward_fn,
        py::function get_param_fn,
        py::function stash_fn,
        py::function pop_fn);
Rick Ho's avatar
Rick Ho committed
86
87
88
89
90
91
92
93
torch::Tensor _smart_sch_backward(
        torch::Tensor grad_out,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        torch::Tensor stored_models,
        long buf_batch_size,
        long global_batch_size,
        long n_workers,
Rick Ho's avatar
Rick Ho committed
94
        py::function backward_fn,
Rick Ho's avatar
Rick Ho committed
95
96
        py::function stash_fn,
        py::function pop_fn,
Rick Ho's avatar
Rick Ho committed
97
98
        py::function collect_fn,
        py::function set_grad_fn);
Rick Ho's avatar
Rick Ho committed
99
100
101
102
void _reduce_grad(
        torch::Tensor t,
        long root,
        long expert_size);
Rick Ho's avatar
Rick Ho committed
103

Rick Ho's avatar
Rick Ho committed
104
105
106
107
108
109
// Python bindings for the extension module.
// Communication-dependent entry points (exchange/scatter/gather, NCCL setup,
// SWIPE balancing, smart scheduling, gradient reduction) are registered only
// when built with FMOE_USE_NCCL. The local routing, linear-layer and
// capacity helpers are always registered.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
    m.def("expert_exchange", &_expert_exchange, "FastMoE expert exchange (CUDA)");
    m.def("global_scatter", &_global_scatter, "FastMoE global scatter (CUDA)");
    m.def("global_gather", &_global_gather, "FastMoE global gather (CUDA)");
    m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
    m.def("swipe_once", &_swipe_once, "SWIPE balance strategy(CUDA)");

    m.def("smart_sch_forward", &_smart_sch_forward, "E2E MoE layer forward with smart scheduling");
    m.def("smart_sch_backward", &_smart_sch_backward, "E2E MoE layer backward with smart scheduling");
    m.def("reduce_grad", &_reduce_grad, "Reduce gradients over FastMoE's communication stream");
#endif

    // local_exchange
    m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
    m.def("assign_pos", &_assign_pos, "FastMoE assign pos by gate (CUDA)");

    // parallel_linear
    m.def("linear_forward", &_linear_forward, "FastMoE forward (CUDA)");
    m.def("linear_backward", &_linear_backward, "FastMoE backward (CUDA)");

    // balancing
    m.def("limit_by_capacity", &_limit_by_capacity, "FastMoE limit experts by capacity(CUDA)");
    m.def("prune_gate_by_capacity", &_prune_gate_by_capacity, "FastMoE prune gate by capacity(CUDA)");
}