"docs/vscode:/vscode.git/clone" did not exist on "fcd9637c06034ae5ba9769b09f2e422729f7ce56"
moe.cpp 4.84 KB
Newer Older
Jiezhong Qiu's avatar
can run  
Jiezhong Qiu committed
1
2
3
4
5
6
#include <torch/extension.h>

#include <cstdio>
#include <iostream>
#include <vector>

7
#include "moe_cuda_kernel.h"
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
8
9
10
11
12
13

// Input validation helpers for tensors handed in from Python.
//
// TORCH_CHECK is used instead of the deprecated AT_ASSERTM so that a bad
// argument is reported as a user error rather than an internal assertion
// failure, and Tensor::is_cuda() replaces the deprecated
// Tensor::type().is_cuda().
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so both checks stay together (and the macro still
// parses) when it is used as the body of an unbraced if/else.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

Rick Ho's avatar
Rick Ho committed
14
std::vector<torch::Tensor> moe_expert_count(
15
16
		torch::Tensor gate, 
		size_t num_expert) {
Rick Ho's avatar
Rick Ho committed
17
	CHECK_INPUT(gate);
18
	return moe_cuda_expert_count(gate, num_expert);
Rick Ho's avatar
Rick Ho committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
}

/**
 * Python-facing wrapper around moe_cuda_local_scatter.
 *
 * @param input  tensor to scatter; must be a contiguous CUDA tensor.
 * @param pos    forwarded unchanged (not validated here).
 * @return the tensors produced by moe_cuda_local_scatter.
 */
std::vector<torch::Tensor> moe_local_scatter(
        torch::Tensor input,
        torch::Tensor pos)
{
    CHECK_INPUT(input);
    std::vector<torch::Tensor> scattered = moe_cuda_local_scatter(input, pos);
    return scattered;
}

/**
 * Python-facing wrapper around moe_cuda_local_gather.
 *
 * @param output_buf  buffer to gather from; must be a contiguous CUDA tensor.
 * @param pos         forwarded unchanged (not validated here).
 * @return the tensors produced by moe_cuda_local_gather.
 */
std::vector<torch::Tensor> moe_local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos)
{
    CHECK_INPUT(output_buf);
    std::vector<torch::Tensor> gathered = moe_cuda_local_gather(output_buf, pos);
    return gathered;
}


Jiezhong Qiu's avatar
Jiezhong Qiu committed
36
std::vector<torch::Tensor> moe_forward(
Rick Ho's avatar
Rick Ho committed
37
        torch::Tensor input_buf,     // [batch_size x in_feat]
Rick Ho's avatar
Rick Ho committed
38
        torch::Tensor weight,        // [num_expert x out_feat x in_feat]
Rick Ho's avatar
Rick Ho committed
39
        torch::Tensor expert_count   // [batch_size]
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
40
        ) {
Rick Ho's avatar
Rick Ho committed
41
    CHECK_INPUT(input_buf);
Rick Ho's avatar
Rick Ho committed
42
    CHECK_INPUT(weight);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
43
44
45
46
47
    /*
        The bias term should have been merged into weight. Note the following fact that 
        Wx+b = [W b] [x]
                     [1]  
    */
Rick Ho's avatar
Rick Ho committed
48
    return moe_cuda_forward(input_buf, weight, expert_count);
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
49
50
}

Jiezhong Qiu's avatar
Jiezhong Qiu committed
51
/**
 * Python-facing wrapper around moe_cuda_backward.
 *
 * Any bias term is expected to be folded into `weight` by the caller, using
 *     W x + b = [W b] [x; 1]
 *
 * @param grad_output_buf  [batch_size x out_feat]; must be a contiguous CUDA
 *                         tensor.
 * @param input_buf        must be a contiguous CUDA tensor.
 *                         NOTE(review): previously annotated as
 *                         [batch_size x out_feat], but moe_forward documents
 *                         its input_buf as [batch_size x in_feat]; confirm.
 * @param weight           [num_expert x out_feat x in_feat]; must be a
 *                         contiguous CUDA tensor.
 * @param expert_count     forwarded unchanged (not validated here).
 * @return the tensors produced by moe_cuda_backward.
 */
std::vector<torch::Tensor> moe_backward(
        torch::Tensor grad_output_buf,
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor expert_count)
{
    CHECK_INPUT(grad_output_buf);
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);

    std::vector<torch::Tensor> grads =
        moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
    return grads;
}

68
69
#ifdef MOE_USE_NCCL

Rick Ho's avatar
Rick Ho committed
70
71
72
73
74
75
/**
 * Python-facing wrapper around moe_cuda_expert_exchange.
 *
 * All arguments are forwarded unchanged; no tensor validation is done here.
 *
 * @return the tensors produced by moe_cuda_expert_exchange.
 */
std::vector<torch::Tensor> moe_expert_exchange(
        torch::Tensor local_expert_count,
        size_t num_expert,
        size_t n_workers)
{
    std::vector<torch::Tensor> exchanged =
        moe_cuda_expert_exchange(local_expert_count, num_expert, n_workers);
    return exchanged;
}

76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/**
 * Python-facing wrapper around moe_cuda_global_scatter.
 *
 * @param input_buf            must be a contiguous CUDA tensor.
 * @param local_expert_count   forwarded unchanged (not validated here).
 * @param global_expert_count  forwarded unchanged (not validated here).
 * @param batch_size           forwarded unchanged.
 * @param n_workers            forwarded unchanged.
 * @return the tensors produced by moe_cuda_global_scatter.
 */
std::vector<torch::Tensor> moe_global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        size_t batch_size,
        size_t n_workers)
{
    CHECK_INPUT(input_buf);
    std::vector<torch::Tensor> scattered = moe_cuda_global_scatter(
            input_buf, local_expert_count, global_expert_count,
            batch_size, n_workers);
    return scattered;
}

/**
 * Python-facing wrapper around moe_cuda_global_gather.
 *
 * @param output_buf           must be a contiguous CUDA tensor.
 * @param local_expert_count   forwarded unchanged (not validated here).
 * @param global_expert_count  forwarded unchanged (not validated here).
 * @param batch_size           forwarded unchanged.
 * @param n_workers            forwarded unchanged.
 * @return the tensors produced by moe_cuda_global_gather.
 */
std::vector<torch::Tensor> moe_global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        size_t batch_size,
        size_t n_workers)
{
    CHECK_INPUT(output_buf);
    std::vector<torch::Tensor> gathered = moe_cuda_global_gather(
            output_buf, local_expert_count, global_expert_count,
            batch_size, n_workers);
    return gathered;
}

Rick Ho's avatar
Rick Ho committed
98
99
100
101
102
103
104
105
106
107
108
109
110
111

/**
 * Python-facing wrapper around moe_cuda_global_fused_forward.
 *
 * @param input_buf            must be a contiguous CUDA tensor.
 * @param weight               must be a contiguous CUDA tensor.
 * @param local_expert_count   forwarded unchanged (not validated here).
 * @param global_expert_count  forwarded unchanged (not validated here).
 * @param global_batch_size    forwarded unchanged.
 * @param local_batch_size     forwarded unchanged.
 * @param n_workers            forwarded unchanged.
 * @return the tensors produced by moe_cuda_global_fused_forward.
 */
std::vector<torch::Tensor> moe_global_fused_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long global_batch_size,
        long local_batch_size,
        long n_workers)
{
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);

    std::vector<torch::Tensor> outputs = moe_cuda_global_fused_forward(
            input_buf, weight,
            local_expert_count, global_expert_count,
            global_batch_size, local_batch_size, n_workers);
    return outputs;
}

Rick Ho's avatar
Rick Ho committed
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include <c10d/ProcessGroupNCCL.hpp>
#include "cuda_stream_manager.h"

/**
 * Subclass of c10d::ProcessGroupNCCL whose only purpose is to expose the
 * communicator obtained through getNCCLComm via a public member function.
 */
class HackNCCLGroup : public c10d::ProcessGroupNCCL {
public:
    /**
     * Look up the NCCL communicator for `dev`, keyed by its device index.
     *
     * @return the first communicator returned by getNCCLComm, or a null
     *         handle (after logging to stderr) when none is returned.
     */
    ncclComm_t getcomm(at::Device dev) {
        const auto device_key = std::to_string(dev.index());
        auto comms = getNCCLComm(device_key, {dev}, c10d::OpType::ALLTOALL);
        if (comms.empty()) {
            std::cerr << "PyTorch has nothing\n";
            return nullptr;
        }
        return comms[0]->getNcclComm();
    }
};

/**
 * Make sure the CUDA stream manager for `t`'s device holds an NCCL
 * communicator taken from the given process group.
 *
 * Does nothing when the manager's `ncclgood` flag is already set. Otherwise
 * the communicator is fetched through HackNCCLGroup::getcomm and stored in
 * the manager; the flag is set only if the handle is non-null. A failure is
 * logged to stderr and otherwise ignored.
 *
 * @param p  process group to take the communicator from.
 * @param t  tensor whose device selects the stream manager and communicator.
 */
void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
    auto stream_mgr = getCudaStreamManager(t.device().index());
    if (!stream_mgr->ncclgood) {
        // Reinterpret the group as HackNCCLGroup purely to reach getcomm().
        auto* hacked_group = reinterpret_cast<HackNCCLGroup*>(&p);
        stream_mgr->ncclcomm = hacked_group->getcomm(t.device());
        if (stream_mgr->ncclcomm == nullptr) {
            std::cerr << "Nccl initialization failed\n";
        } else {
            stream_mgr->ncclgood = 1;
        }
    }
}
Rick Ho's avatar
Rick Ho committed
141
142

#endif  // MOE_USE_NCCL
Jiezhong Qiu's avatar
update  
Jiezhong Qiu committed
143
144

// Python bindings. Each entry maps a Python-visible name to one of the
// wrapper functions defined in this file; the third argument is the docstring
// shown by Python's help().
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
  m.def("local_scatter", &moe_local_scatter, "MoE local scatter (CUDA)");
  m.def("local_gather", &moe_local_gather, "MoE local gather (CUDA)");
#ifdef MOE_USE_NCCL
  // Multi-worker entry points, only available when built with NCCL support.
  m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
  m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
  m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
  // Docstring previously duplicated the global_gather text.
  m.def("global_fused_forward", &moe_global_fused_forward,
        "MoE global fused forward (CUDA)");
  m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
#endif
  m.def("forward", &moe_forward, "MoE forward (CUDA)");
  m.def("backward", &moe_backward, "MoE backward (CUDA)");
}