// moe_comm_kernel.cu

#include "moe_cuda_kernel.h"

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h> 
#include <c10/cuda/CUDAGuard.h>

#include "cuda_stream_manager.h"

#ifdef MOE_USE_NCCL
#include <nccl.h>

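// All-to-all exchange of per-expert token counts: peer i receives the
// num_expert counters this worker holds for i's experts and sends back its
// own. Posting every send and recv inside a single NCCL group lets the
// point-to-point calls match up without deadlocking.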
void moe_cuda_expert_exchange_impl(
		const long* local_expert_count, 
		long* global_expert_count, 
		int num_expert, int world_size,
		CudaStreamManager* smgr) {
	NCCL_SAFE_CALL(ncclGroupStart());
	for (int i = 0; i < world_size; ++i) {
		NCCL_SAFE_CALL(ncclSend(
				local_expert_count + num_expert * i,
				num_expert,
				ncclInt64,
				i,
				smgr->ncclcomm,
				smgr->stream(0)));
		NCCL_SAFE_CALL(ncclRecv(
				global_expert_count + num_expert * i,
				num_expert,
				ncclInt64,
				i,
				smgr->ncclcomm,
				smgr->stream(0)));
	}
	NCCL_SAFE_CALL(ncclGroupEnd());
	smgr->sync(1);
}

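// Torch-facing wrapper: local_expert_count is a flat int64 tensor with
// num_expert * n_workers entries; the result holds the counts received from
// every worker in the same layout.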
std::vector<torch::Tensor> moe_cuda_expert_exchange(
		torch::Tensor local_expert_count,
		long num_expert, long n_workers) {
	auto global_expert_count = torch::empty_like(local_expert_count);
	auto smgr = getCudaStreamManager(local_expert_count.device().index());

	moe_cuda_expert_exchange_impl(
			local_expert_count.data_ptr<long>(),
			global_expert_count.data_ptr<long>(),
			num_expert, n_workers,
			smgr);
	return {global_expert_count};
}

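// Scatter token rows to the workers that host their target experts.
// local_input_buf is ordered by destination worker, then by expert within
// that worker; the prefix sum below turns those counts into send offsets.
// Received rows land in input_buf ordered by local expert, then by source
// worker. Payloads move as raw bytes (ncclChar), so one code path serves
// every scalar type.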
template<typename scalar_t>
void moe_cuda_global_scatter_impl(
	const scalar_t* local_input_buf,
	const long* local_expert_count,
	const long* global_expert_count,
	scalar_t* input_buf,
	size_t in_feat, size_t num_expert, size_t world_size,
	CudaStreamManager* smgr) {
	// assert world_size > 1
	long recv_ptr = 0;
	/* TODO: may save for backward */
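	// Exclusive prefix sum: expert_ptr[idx] is the first row of
	// (worker, expert) block idx inside local_input_buf.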
	long *expert_ptr = new long[num_expert * world_size];
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert * world_size; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
	}

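	// One NCCL group per local expert slot: exchange that expert's rows
	// with every worker, pairing each send with its matching recv inside
	// the group to avoid deadlock.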
	for (int i = 0; i < num_expert; ++i) {
		NCCL_SAFE_CALL(ncclGroupStart());
		for (int j = 0; j < world_size; ++j) {
			int idx = i + j * num_expert;
			if (local_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclSend(
						local_input_buf + expert_ptr[idx] * in_feat, 
						local_expert_count[idx] * in_feat * sizeof(scalar_t),
						ncclChar, 
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
			}
			if (global_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclRecv(
						input_buf + recv_ptr * in_feat,
						global_expert_count[idx] * in_feat * sizeof(scalar_t),
						ncclChar,
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
				recv_ptr += global_expert_count[idx];
			}
		}
		NCCL_SAFE_CALL(ncclGroupEnd());
	}
	delete [] expert_ptr;
	smgr->sync(1);
}

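// Torch-facing wrapper: batch_size is the total number of rows this worker
// will receive; num_expert is recovered from the flattened count tensor.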
std::vector<torch::Tensor> moe_cuda_global_scatter(
		torch::Tensor input_buf,
		torch::Tensor local_expert_count,
		torch::Tensor global_expert_count,
		long batch_size, long n_workers) {
	auto num_expert = local_expert_count.size(0) / n_workers;
	auto in_feat = input_buf.size(1);
	auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
	auto smgr = getCudaStreamManager(input_buf.device().index());

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
			"moe_cuda_global_scatter", ([&] {
		moe_cuda_global_scatter_impl<scalar_t>(
			input_buf.data_ptr<scalar_t>(),
			local_expert_count.data_ptr<long>(),
			global_expert_count.data_ptr<long>(),
			global_input_buf.data_ptr<scalar_t>(),
			in_feat, num_expert, n_workers,
			smgr
		);
	}));
	return {global_input_buf,};
}

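// Inverse of global_scatter: return each expert's computed outputs to the
// workers their tokens came from. The send side walks the contiguous
// expert-major output_buf via send_ptr; the recv side writes rows back into
// local_output_buf at offsets given by the same prefix sum over
// local_expert_count.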
template<typename scalar_t>
void moe_cuda_global_gather_impl(
	const scalar_t* output_buf,
	const long* local_expert_count,
	const long* global_expert_count,
	scalar_t* local_output_buf,
	size_t out_feat, size_t num_expert, size_t world_size,
	CudaStreamManager* smgr) {
	long send_ptr = 0;
	/* TODO: may save for backward */
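	// Same prefix sum as in global_scatter: expert_ptr[idx] is the row
	// offset of (worker, expert) block idx in local_output_buf.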
	long *expert_ptr = new long[num_expert * world_size];
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert * world_size; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
	}

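	// Mirror of the scatter loop with the send and recv roles swapped.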
	for (int i = 0; i < num_expert; ++i) {
		NCCL_SAFE_CALL(ncclGroupStart());
		for (int j = 0; j < world_size; ++j) {
			int idx = i + j * num_expert;
			if (global_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclSend(
						output_buf + send_ptr * out_feat,
						global_expert_count[idx] * out_feat * sizeof(scalar_t),
						ncclChar,
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
				send_ptr += global_expert_count[idx];
			}
			if (local_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclRecv(
						local_output_buf + expert_ptr[idx] * out_feat, 
						local_expert_count[idx] * out_feat * sizeof(scalar_t),
						ncclChar, 
						j,
						smgr->ncclcomm,
						smgr->stream(0)));
			}
		}
		NCCL_SAFE_CALL(ncclGroupEnd());
	}
	delete [] expert_ptr;
	smgr->sync(1);
}

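// Torch-facing wrapper: batch_size is the number of rows that originated on
// this worker, i.e. the row count of the pre-scatter input buffer.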
std::vector<torch::Tensor> moe_cuda_global_gather(
		torch::Tensor output_buf,
		torch::Tensor local_expert_count,
		torch::Tensor global_expert_count,
		long batch_size, long n_workers) {
	auto num_expert = local_expert_count.size(0) / n_workers;
	auto out_feat = output_buf.size(1);
	auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
	auto smgr = getCudaStreamManager(output_buf.device().index());

	AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
			"moe_cuda_global_gather", ([&] {
		moe_cuda_global_gather_impl<scalar_t>(
			output_buf.data_ptr<scalar_t>(),
			local_expert_count.data_ptr<long>(),
			global_expert_count.data_ptr<long>(),
			local_output_buf.data_ptr<scalar_t>(),
			out_feat, num_expert, n_workers,
			smgr
		);
	}));
	return {local_output_buf,};
}

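// Bind the stream manager of t's device to the given NCCL process group, so
// the communication kernels above have a valid ncclcomm to run on.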
void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
	auto smgr = getCudaStreamManager(t.device().index());
	smgr->ensure((void*)&p, t.device());
}

#endif