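// moe_fused_kernel.cu
// Fused forward pass for experts spread across several workers: input rows are
// exchanged over NCCL, one GEMM is run per local expert, and the results are
// exchanged back over NCCL.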
#include "moe_cuda_kernel.h"

#include <cstdio>
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h> 
#include <c10/cuda/CUDAGuard.h>

#include "cuda_stream_manager.h"
// cublasXgemm, called below for the per-expert GEMM, is presumably declared in
// the project header included next.
#include "cublas_wrapper.h"
// Everything below is compiled only when MOE_USE_NCCL is defined.

#ifdef MOE_USE_NCCL
#include <nccl.h>

// ---- Fused global forward: implementation template, then its tensor-level entry point ----
template<typename scalar_t>
void moe_cuda_global_fused_forward_impl(
		const scalar_t* input_buf,
		const scalar_t* weight,
		scalar_t* global_input_buf,
		scalar_t* global_output_buf,
		scalar_t* output_buf,
26
27
		const long* local_expert_count, 
		const long* global_expert_count, 
Rick Ho's avatar
Rick Ho committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
		long in_feat, long out_feat, 
		long num_expert, long world_size,
		CudaStreamManager* smgr) {

	int ptr = 0;
	int send_ptr = 0;
	int recv_ptr = 0;

	int *expert_ptr = new int[num_expert * world_size];
	expert_ptr[0] = 0;
	for (int i = 1; i < num_expert * world_size; ++i) {
		expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
	}

	scalar_t alpha = 1, beta = 0; 

	for (int i = 0; i < num_expert; ++i) {
		int expert_count = 0;
		NCCL_SAFE_CALL(ncclGroupStart());
		for (int j = 0; j < world_size; ++j) {
			int idx = i + j * num_expert;
			if (local_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclSend(
						input_buf + expert_ptr[idx] * in_feat, 
						local_expert_count[idx] * in_feat * sizeof(scalar_t),
						ncclChar, 
						j,
						smgr->ncclcomm,
						smgr->stream(i)));
			}
			if (global_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclRecv(
						global_input_buf + recv_ptr * in_feat,
						global_expert_count[idx] * in_feat * sizeof(scalar_t),
						ncclChar,
						j,
						smgr->ncclcomm,
						smgr->stream(i)));
				recv_ptr += global_expert_count[idx];
				expert_count += global_expert_count[idx];
			}
		}
		NCCL_SAFE_CALL(ncclGroupEnd());

		checkCudaErrors(cublasXgemm(
				smgr->handle(i),
				CUBLAS_OP_T,
				CUBLAS_OP_N,
				out_feat, expert_count, in_feat,
				&alpha,
				weight + i * in_feat * out_feat, in_feat,
				global_input_buf + ptr * in_feat, in_feat,
				&beta,
				global_output_buf + out_feat * ptr, out_feat
				));

		ptr += expert_count;

		NCCL_SAFE_CALL(ncclGroupStart());
		for (int j = 0; j < world_size; ++j) {
			int idx = i + j * num_expert;
			if (global_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclSend(
						global_output_buf + send_ptr * out_feat,
						global_expert_count[idx] * out_feat * sizeof(scalar_t),
						ncclChar,
						j,
						smgr->ncclcomm,
						smgr->stream(i)));
				send_ptr += global_expert_count[idx];
			}
			if (local_expert_count[idx]) {
				NCCL_SAFE_CALL(ncclRecv(
						output_buf + expert_ptr[idx] * out_feat, 
						local_expert_count[idx] * out_feat * sizeof(scalar_t),
						ncclChar, 
						j,
						smgr->ncclcomm,
						smgr->stream(i)));
			}
		}
		NCCL_SAFE_CALL(ncclGroupEnd());
	}
	delete [] expert_ptr;
	smgr->sync(num_expert);
}

/*
 * Tensor-level entry point for the fused global forward pass.
 *
 * Allocates the intermediate and output buffers with input_buf's dtype and
 * device, then dispatches moe_cuda_global_fused_forward_impl on input_buf's
 * floating-point scalar type.
 *
 * Parameters:
 *   input_buf            local input rows (presumably [local_batch_size,
 *                        in_feat] and contiguous -- TODO confirm with caller).
 *   weight               expert weights; size(1) is taken as out_feat and
 *                        size(2) as in_feat.
 *   local_expert_count   long tensor with num_expert * n_workers entries; it
 *                        is read on the host, so it must be a CPU tensor.
 *   global_expert_count  long tensor indexed the same way; also read on the
 *                        host, so it must be a CPU tensor.
 *   global_batch_size    number of rows allocated for the global buffers.
 *   local_batch_size     number of rows allocated for the returned output.
 *   n_workers            number of workers taking part; must be positive.
 *
 * Returns {output_buf, global_input_buf}: the [local_batch_size, out_feat]
 * result and the [global_batch_size, in_feat] buffer of received rows.
 */
std::vector<torch::Tensor> moe_cuda_global_fused_forward(
		torch::Tensor input_buf,
		torch::Tensor weight,
		torch::Tensor local_expert_count,
		torch::Tensor global_expert_count,
		long global_batch_size, long local_batch_size, long n_workers) {
	// n_workers is a divisor below; reject zero/negative before dividing.
	TORCH_CHECK(n_workers > 0,
			"moe_cuda_global_fused_forward: n_workers must be positive");
	// The implementation indexes both count arrays from host code, so a CUDA
	// tensor here would be dereferenced as a host pointer.
	TORCH_CHECK(!local_expert_count.is_cuda() && !global_expert_count.is_cuda(),
			"moe_cuda_global_fused_forward: expert count tensors must be on the CPU");

	const auto num_expert = local_expert_count.size(0) / n_workers;
	const auto out_feat = weight.size(1);
	const auto in_feat = weight.size(2);

	auto smgr = getCudaStreamManager(input_buf.device().index());

	auto global_input_buf = input_buf.new_empty({global_batch_size, in_feat});
	auto global_output_buf = input_buf.new_empty({global_batch_size, out_feat});
	auto output_buf = input_buf.new_empty({local_batch_size, out_feat});
	AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
			"moe_cuda_global_fused_forward", ([&] {
		moe_cuda_global_fused_forward_impl(
			input_buf.data_ptr<scalar_t>(),
			weight.data_ptr<scalar_t>(),
			global_input_buf.data_ptr<scalar_t>(),
			global_output_buf.data_ptr<scalar_t>(),
			output_buf.data_ptr<scalar_t>(),
			local_expert_count.data_ptr<long>(),
			global_expert_count.data_ptr<long>(),
			in_feat, out_feat, num_expert, n_workers,
			smgr);
	}));
	return {output_buf, global_input_buf};
}
// End of the section guarded by MOE_USE_NCCL.

#endif