Unverified Commit 3c42c892 authored by Rick Ho, committed by GitHub

Merge pull request #21 from bias_improvement

Bias improvement #15
parents 26824495 41cfe06c
@@ -12,101 +12,108 @@
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> moe_expert_count(
    torch::Tensor gate,
    size_t num_expert) {
    CHECK_INPUT(gate);
    return moe_cuda_expert_count(gate, num_expert);
}

std::vector<torch::Tensor> moe_local_scatter(
    torch::Tensor input,
    torch::Tensor pos) {
    CHECK_INPUT(input);
    return moe_cuda_local_scatter(input, pos);
}

std::vector<torch::Tensor> moe_local_gather(
    torch::Tensor output_buf,
    torch::Tensor pos) {
    CHECK_INPUT(output_buf);
    return moe_cuda_local_gather(output_buf, pos);
}
std::vector<torch::Tensor> moe_forward(
    torch::Tensor input_buf,            // [batch_size x in_feat]
-   torch::Tensor weight,               // [num_expert x out_feat x in_feat]
-   torch::Tensor expert_count          // [batch_size]
+   torch::Tensor expert_count,         // [num_expert]
+   torch::Tensor weight,               // [num_expert x out_feat x in_feat]
+   at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
) {
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
-   /*
-      The bias term should have been merged into weight. Note the following fact:
-      Wx + b = [W b] [x; 1]
-   */
-   return moe_cuda_forward(input_buf, weight, expert_count);
+   // check that the bias is valid in case it exists
+   if (bias_o.has_value()) {
+       auto bias = bias_o.value();
+       CHECK_INPUT(bias);
+   }
+   return moe_cuda_forward(input_buf, expert_count, weight, bias_o);
}
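Note: the removed comment relied on folding the bias into an augmented weight matrix, while the new interface passes the bias explicitly. A minimal PyTorch sketch of the equivalence (illustrative only, not part of the patch):

import torch

x = torch.randn(4, 3)                     # [batch, in_feat]
W = torch.randn(5, 3)                     # [out_feat, in_feat]
b = torch.randn(5)                        # [out_feat]

explicit = x @ W.t() + b                  # new path: bias passed separately
augmented = torch.cat([x, torch.ones(4, 1)], dim=1) @ torch.cat([W, b[:, None]], dim=1).t()
assert torch.allclose(explicit, augmented)   # Wx + b == [W b] [x; 1]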
std::vector<torch::Tensor> moe_backward(
    torch::Tensor grad_output_buf,      // [batch_size x out_feat]
-   torch::Tensor input_buf,            // [batch_size x out_feat]
-   torch::Tensor weight,               // [num_expert x out_feat x in_feat]
-   torch::Tensor expert_count
+   torch::Tensor input_buf,            // [batch_size x in_feat]
+   torch::Tensor expert_count,         // [num_expert]
+   torch::Tensor weight,               // [num_expert x out_feat x in_feat]
+   at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
) {
    CHECK_INPUT(grad_output_buf);
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
-   /*
-      The bias term should have been merged into weight. Note the following fact:
-      Wx + b = [W b] [x; 1]
-   */
-   return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
+   // check that the bias is valid in case it exists
+   if (bias_o.has_value()) {
+       auto bias = bias_o.value();
+       CHECK_INPUT(bias);
+   }
+   return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias_o);
}
#ifdef MOE_USE_NCCL

std::vector<torch::Tensor> moe_expert_exchange(
    torch::Tensor local_expert_count,
    size_t num_expert, size_t n_workers) {
    return moe_cuda_expert_exchange(local_expert_count, num_expert, n_workers);
}

std::vector<torch::Tensor> moe_global_scatter(
    torch::Tensor input_buf,
    torch::Tensor local_expert_count,
    torch::Tensor global_expert_count,
    size_t batch_size, size_t n_workers) {
    CHECK_INPUT(input_buf);
    return moe_cuda_global_scatter(input_buf,
            local_expert_count, global_expert_count,
            batch_size, n_workers);
}

std::vector<torch::Tensor> moe_global_gather(
    torch::Tensor output_buf,
    torch::Tensor local_expert_count,
    torch::Tensor global_expert_count,
    size_t batch_size, size_t n_workers) {
    CHECK_INPUT(output_buf);
    return moe_cuda_global_gather(output_buf,
            local_expert_count, global_expert_count,
            batch_size, n_workers);
}

std::vector<torch::Tensor> moe_global_fused_forward(
    torch::Tensor input_buf,
    torch::Tensor weight,
    torch::Tensor local_expert_count,
    torch::Tensor global_expert_count,
    long global_batch_size, long local_batch_size, long n_workers) {
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    return moe_cuda_global_fused_forward(
            input_buf, weight, local_expert_count, global_expert_count,
            global_batch_size, local_batch_size, n_workers);
}

#include <c10d/ProcessGroupNCCL.hpp>
@@ -114,47 +121,47 @@ std::vector<torch::Tensor> moe_global_fused_forward(

class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
    ncclComm_t getcomm(at::Device dev) {
        auto key = std::to_string(dev.index());
#ifdef ENABLE_NCCL_P2P_SUPPORT
        ncclUniqueId ncclID;
        int rank = getRank();
        if (rank == 0) {
            ncclGetUniqueId(&ncclID);
        }
        broadcastUniqueNCCLID(&ncclID,
                c10d::OpType::SEND,
                "fastmoe_nccl_comm",
                rank);
        ncclComm_t comm;
        ncclCommInitRank(&comm, getSize(), ncclID, rank);
        return comm;
#else
        auto v = getNCCLComm(key, {dev});
        if (v.size() == 0) {
            std::cerr << "PyTorch has nothing\n";
            return 0;
        }
        int count;
        ncclCommCount(v[0]->getNcclComm(), &count);
        std::cerr << "PyTorch has " << v.size() << " comms, comm 0 size " << count << "\n";
        return v[0]->getNcclComm();
#endif
    }
};

void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
    auto smgr = getCudaStreamManager(t.device().index());
    if (smgr->ncclgood) {
        return;
    }
    HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
    smgr->ncclcomm = h->getcomm(t.device());
    if (smgr->ncclcomm != 0) {
        smgr->ncclgood = 1;
    } else {
        std::cerr << "Nccl initialization failed\n";
    }
}

#endif  // MOE_USE_NCCL
@@ -167,8 +174,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
    m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
    m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
    m.def("global_fused_forward", &moe_global_fused_forward,
            "MoE global gather (CUDA)");
    m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
#endif
    m.def("forward", &moe_forward, "MoE forward (CUDA)");
@@ -19,304 +19,386 @@
template <typename scalar_t>
__global__
void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
        const long* offset, const scalar_t** ptrs) {
    size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx < n) {
        ptrs[idx] = base + stride * offset[idx];
    }
}

template <typename scalar_t>
__global__
void batch_scatter_kernel(size_t wid, const long* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    inbuf += wid * pos[blockIdx.x];
    oubuf += wid * blockIdx.x;
    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}
/*
 * Column-wise reduction. Each thread block handles a tile of blockDim.x
 * columns; one thread column (blockDim.y threads) reduces one matrix column.
 */
template <typename scalar_t>
__global__
void column_reduce(const scalar_t * matrix, scalar_t * result,
        int m /* lines */, int n /* columns */) {
    // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
    extern __shared__ unsigned char my_smem[];
    scalar_t *sdata = reinterpret_cast<scalar_t *>(my_smem);

    // normal tid
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    // transposed tid for shared memory
    int new_tid = threadIdx.y + threadIdx.x * blockDim.y;

    // true x value in the matrix
    int real_x = threadIdx.x + blockDim.x * blockIdx.x;
    int i = real_x + n * threadIdx.y;
    const int it = n * blockDim.y;
    int offset = it;

    float accumulator = 0;
    if (threadIdx.y < m && real_x < n) {
        // store all the values from this column in a warped way
        accumulator = matrix[i];
        while (i + offset < n * m) {
            accumulator += matrix[i + offset];
            offset += it;
        }
    }

    // save column reduction data in a transposed way
    sdata[new_tid] = accumulator;
    __syncthreads();

    for (size_t t = 16; t > 0; t >>= 1) {
        if (tid < 32 * 32 - 16)
            sdata[tid] += sdata[tid + t];
        __syncthreads();
    }

    if (threadIdx.y == 0 && real_x < n)
        result[real_x] = sdata[new_tid];
}
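As a reference for what column_reduce is expected to compute, a minimal PyTorch sketch (illustrative only): the kernel sums an m x n matrix over its rows, producing one value per column, which is how the per-expert bias gradient is accumulated further below.

import torch

matrix = torch.randn(128, 64)    # m rows (tokens of one expert) x n columns (out_feat)
result = matrix.sum(dim=0)       # per-column sum, i.e. what the kernel writes into `result`
assert result.shape == (64,)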
void moe_cuda_expert_count_impl(
        const int* d_gate,
        int* expert_count,
        int* d_pos,
        const size_t num_expert,
        const size_t batch_size) {
    int *gate = new int[batch_size];
    int *expert_ptr = new int[num_expert];
    memset(expert_count, 0, sizeof(int) * num_expert);
    checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
                cudaMemcpyDeviceToHost));
    for (int i = 0; i < batch_size; ++i) {
        ++expert_count[gate[i]];
    }
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
    }
    int *pos = new int[batch_size];
    for (int i = 0; i < batch_size; ++i) {
        pos[i] = expert_ptr[gate[i]]++;
    }
    for (int i = num_expert - 1; i > 0; --i) {
        expert_ptr[i] = expert_ptr[i - 1];
    }
    expert_ptr[0] = 0;
    checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
                cudaMemcpyHostToDevice));
    delete [] gate;
    delete [] expert_ptr;
}
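The host-side counting above is essentially a counting sort over the gate indices. An equivalent PyTorch sketch (illustrative only):

import torch

gate = torch.tensor([2, 0, 2, 1, 0, 2])                      # expert chosen per token
num_expert = 3
expert_count = torch.bincount(gate, minlength=num_expert)    # tokens per expert
expert_ptr = torch.cumsum(expert_count, 0) - expert_count    # start offset of each expert
pos = torch.empty_like(gate)                                 # slot of each token in the expert-sorted buffer
for i, g in enumerate(gate.tolist()):
    pos[i] = expert_ptr[g]
    expert_ptr[g] += 1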
template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
        const long* d_pos,
        scalar_t* input_buf,
        const long batch_size,
        const long in_feat,
        CudaStreamManager* smgr) {
    batch_scatter_kernel<scalar_t>
        <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
            input_buf);
    smgr->sync(1);
}

template <typename scalar_t>
__global__
void batch_gather_kernel(size_t wid, const long* pos,
        const scalar_t* inbuf, scalar_t* oubuf) {
    inbuf += wid * blockIdx.x;
    oubuf += wid * pos[blockIdx.x];
    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
        oubuf[i] = inbuf[i];
    }
}

template <typename scalar_t>
void moe_cuda_local_gather_impl(
        const scalar_t* output_buf,
        const long* d_pos,
        scalar_t* output,
        const size_t batch_size,
        const size_t out_feat,
        CudaStreamManager* smgr) {
    batch_gather_kernel<scalar_t>
        <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
            output);
    smgr->sync(1);
}
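In PyTorch terms, the two local reorderings above amount to fancy indexing with `pos` (a rough sketch, illustrative only):

import torch

x = torch.randn(6, 4)                       # rows in original token order
pos = torch.tensor([3, 0, 5, 1, 4, 2])      # permutation used by the kernels

input_buf = x[pos]                          # local_scatter: buffer row j is x[pos[j]]
restored = torch.empty_like(input_buf)
restored[pos] = input_buf                   # local_gather: write row j back to slot pos[j]
assert torch.equal(restored, x)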
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* output_buf,
+       const bool has_bias,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
-   scalar_t alpha = 1, beta = 0;
+   scalar_t alpha = 1, beta = has_bias ? 1 : 0;
    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_T,
                CUBLAS_OP_N,
                out_feat, expert_count[i], in_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                input_buf + ptr * in_feat, in_feat,
                &beta,
                output_buf + out_feat * ptr, out_feat
        ));
        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}
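The bias enters the forward pass through beta = 1: the output buffer is pre-filled with each token's expert bias (see moe_cuda_forward below), and the GEMM then accumulates on top of it. A rough PyTorch sketch for a single expert (illustrative only):

import torch

x = torch.randn(6, 3)            # tokens routed to this expert [count x in_feat]
W = torch.randn(5, 3)            # expert weight                [out_feat x in_feat]
b = torch.randn(5)               # expert bias                  [out_feat]

out = b.expand(6, 5).clone()     # beta = 1 keeps the pre-filled bias
out += x @ W.t()                 # the cublasXgemm accumulation
assert torch.allclose(out, x @ W.t() + b)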
template <typename scalar_t>
void moe_cuda_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
+       scalar_t* grad_bias,
+       const bool has_bias,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

+   // bias
+   dim3 block_threads(32, 32);
+   dim3 grid_threads(out_feat / 32 + (out_feat % 32 ? 1 : 0), 1);

    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                    sizeof(scalar_t) * in_feat * out_feat);
+           cudaMemset(grad_bias + i * out_feat, 0, sizeof(scalar_t) * out_feat);
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C

        // Backward input: g_i = w @ g_o
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_N,
                in_feat, expert_count[i], out_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_input_buf + in_feat * ptr, in_feat
        ));

        // Backward weight: g_w = i @ g_o
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_T,
                in_feat, out_feat, expert_count[i],
                &alpha,
                input_buf + in_feat * ptr, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_weight + i * in_feat * out_feat, in_feat
        ));

+       if (has_bias) {
+           column_reduce
+               <<<grid_threads, block_threads, sizeof(scalar_t)*1024, smgr->stream(0)>>>
+               (
+                   grad_output_buf + ptr * out_feat,
+                   grad_bias + i * out_feat,
+                   expert_count[i],
+                   out_feat
+               );
+       }

        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}
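For reference, the three per-expert computations above correspond to the following PyTorch math under the row-major convention used here (a sketch, illustrative only):

import torch

x = torch.randn(6, 3)      # rows routed to expert i        [count x in_feat]
W = torch.randn(5, 3)      # expert i weight                [out_feat x in_feat]
g_o = torch.randn(6, 5)    # gradient of the expert output  [count x out_feat]

g_i = g_o @ W              # first GEMM: gradient w.r.t. the input rows
g_w = g_o.t() @ x          # second GEMM: gradient w.r.t. the weight
g_b = g_o.sum(dim=0)       # column_reduce: gradient w.r.t. the bias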
std::vector<torch::Tensor> moe_cuda_expert_count(
        torch::Tensor gate,
        size_t num_expert) {
    const auto batch_size = gate.size(0);

    auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
    auto expert_count = torch::empty(num_expert, ec_options);

    auto pos_options = torch::TensorOptions()
        .device(gate.device())
        .dtype(torch::kInt32);
    auto pos = torch::empty(batch_size, pos_options);

    moe_cuda_expert_count_impl(
            gate.data_ptr<int>(),
            expert_count.data_ptr<int>(),
            pos.data_ptr<int>(),
            num_expert,
            batch_size);

    return {expert_count, pos};
}

std::vector<torch::Tensor> moe_cuda_local_scatter(
        torch::Tensor input,
        torch::Tensor pos) {
    auto smgr = getCudaStreamManager(input.device().index());
    const auto batch_size = pos.size(0);
    const auto in_feat = input.size(1);

    auto opt = torch::TensorOptions()
        .dtype(input.dtype())
        .device(input.device());
    auto input_buf = torch::empty({batch_size, in_feat}, opt);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda",
            ([&] {
        moe_cuda_local_scatter_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            pos.data_ptr<long>(),
            input_buf.data_ptr<scalar_t>(),
            batch_size,
            in_feat,
            smgr);
    }));
    return {input_buf,};
}

std::vector<torch::Tensor> moe_cuda_local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos) {
    auto smgr = getCudaStreamManager(output_buf.device().index());
    const auto batch_size = pos.size(0);
    const auto out_feat = output_buf.size(1);

    auto opt = torch::TensorOptions()
        .dtype(output_buf.dtype())
        .device(output_buf.device());
    auto output = torch::empty({batch_size, out_feat}, opt);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda",
            ([&] {
        moe_cuda_local_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            pos.data_ptr<long>(),
            output.data_ptr<scalar_t>(),
            batch_size,
            out_feat,
            smgr);
    }));
    return {output,};
}
std::vector<torch::Tensor> moe_cuda_forward(
        torch::Tensor input_buf,
+       torch::Tensor expert_count,
        torch::Tensor weight,
-       torch::Tensor expert_count
+       at::optional<torch::Tensor> bias
        ) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

#ifdef MOE_DEBUG
    printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
            num_expert, in_feat, out_feat);
#endif

-   auto out_options = torch::TensorOptions()
-       .device(input_buf.device())
-       .dtype(input_buf.dtype());
-   auto output = torch::empty({batch_size, out_feat}, out_options);
+   torch::Tensor output;
+
+   if (bias.has_value()) {
+       output = bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0);
+   } else {
+       auto out_options = torch::TensorOptions()
+           .device(input_buf.device())
+           .dtype(input_buf.dtype());
+       output = torch::empty({batch_size, out_feat}, out_options);
+   }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
            ([&] {
        moe_cuda_forward_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<long>(),
            output.data_ptr<scalar_t>(),
+           bias.has_value(),
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));
    return {output, };
}
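A small PyTorch illustration (not part of the patch) of how repeat_interleave turns the [num_expert x out_feat] bias into a [batch_size x out_feat] pre-filled output buffer:

import torch

bias = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])   # [num_expert = 3, out_feat = 2]
expert_count = torch.tensor([2, 0, 1])                # tokens routed to each expert
out = bias.repeat_interleave(expert_count, dim=0)
# rows: expert 0 twice, expert 1 never, expert 2 once -> shape [batch_size = 3, out_feat = 2]
print(out)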
std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf,   // [batch_size x out_feat]
    torch::Tensor input_buf,         // [batch_size x out_feat]
-   torch::Tensor weight,            // [num_expert x out_feat x in_feat]
-   torch::Tensor expert_count
+   torch::Tensor expert_count,
+   torch::Tensor weight,            // [num_expert x out_feat x in_feat]
+   at::optional<torch::Tensor> bias
    ) {
    auto smgr = getCudaStreamManager(input_buf.device().index());
    const auto batch_size = input_buf.size(0);
    const auto num_expert = weight.size(0);
    const auto out_feat = weight.size(1);
@@ -324,28 +406,31 @@ std::vector<torch::Tensor> moe_cuda_backward(
#ifdef MOE_DEBUG
    printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
            "out_feat (d_ffn)=%ld\n",
            batch_size, num_expert, in_feat, out_feat);
#endif

    auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
    auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
+   auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
        moe_cuda_backward_impl<scalar_t>(
            grad_output_buf.data_ptr<scalar_t>(),
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            expert_count.data_ptr<long>(),
            grad_input_buf.data_ptr<scalar_t>(),
            grad_weight.data_ptr<scalar_t>(),
+           grad_bias.data_ptr<scalar_t>(),
+           bias.has_value(),
            batch_size,
            in_feat,
            out_feat,
            num_expert,
            smgr
        );
    }));
-   return {grad_input_buf, grad_weight};
+   return {grad_input_buf, grad_weight, grad_bias};
}
@@ -19,14 +19,16 @@ std::vector<torch::Tensor> moe_cuda_local_gather(

std::vector<torch::Tensor> moe_cuda_forward(
    torch::Tensor input_buf,
+   torch::Tensor expert_count,
    torch::Tensor weight,
-   torch::Tensor expert_count);
+   at::optional<torch::Tensor> bias);

std::vector<torch::Tensor> moe_cuda_backward(
    torch::Tensor grad_output_buf,
    torch::Tensor input_buf,
+   torch::Tensor expert_count,
    torch::Tensor weight,
-   torch::Tensor expert_count);
+   at::optional<torch::Tensor> bias);

#ifdef MOE_USE_NCCL
@@ -110,21 +110,25 @@ class MOELinear(Function):
    """

    @staticmethod
-   def forward(ctx, global_input_buf, weight, fwd_expert_count):
+   def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
        (global_output_buf,) = fmoe_cuda.forward(
-           global_input_buf, weight, fwd_expert_count
+           global_input_buf, fwd_expert_count, weight, bias
        )
-       variables = (global_input_buf, weight, fwd_expert_count)
+       variables = (global_input_buf, fwd_expert_count, weight, bias)
        ctx.save_for_backward(*variables)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
-       (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
-       grad_inp_buf, grad_weight = fmoe_cuda.backward(
-           grad_out, input_buf, weight, fwd_expert_count
+       (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
+       grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.backward(
+           grad_out, input_buf, fwd_expert_count, weight, bias
        )
-       return grad_inp_buf, grad_weight, None
+
+       if not torch.is_tensor(bias):
+           grad_bias = None
+
+       return grad_inp_buf, None, grad_weight, grad_bias
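The reworked backward must return one entry per forward argument, with None for the non-differentiable expert-count tensor and for a missing bias. A self-contained sketch of that pattern (illustrative only, not the fmoe_cuda kernels):

import torch
from torch.autograd import Function

class ScaledLinear(Function):
    @staticmethod
    def forward(ctx, x, count, weight, bias=None):
        # `count` mimics fwd_expert_count: consumed for bookkeeping, never differentiated
        ctx.save_for_backward(x, count, weight, bias)
        out = x * weight
        return out + bias if bias is not None else out

    @staticmethod
    def backward(ctx, grad_out):
        x, count, weight, bias = ctx.saved_tensors
        grad_bias = grad_out if torch.is_tensor(bias) else None
        # one gradient per forward input: x, count, weight, bias
        return grad_out * weight, None, grad_out * x, grad_bias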
class MOEGather(Function):
@@ -41,37 +41,7 @@ class FMoELinear(nn.Module):
        r"""
        Call MOE function
        """
-       x = MOELinear.apply(inp, self.weight, fwd_expert_count)
-       if self.bias is not None:
-           # TODO: torch.repeat_interleave seems to have numerical
-           #  instability in backward, leading to incorrect
-           #  gradient computation for solutions 1 and 2.
-           #  Solution 3 uses a for-loop to expand the bias,
-           #  but is 50% slower.
-           #  This part should finally go into MOELinear.apply,
-           #  like MOELinear.apply(x, weight, bias, count)
-
-           # Solution 1
-           bias = torch.repeat_interleave(
-               self.bias, fwd_expert_count.to(self.bias.device), dim=0
-           )
-
-           # Solution 2
-           # bias_idx = torch.arange(self.num_expert)\
-           #         .repeat_interleave(fwd_expert_count)
-           # bias = self.bias[bias_idx]
-
-           # Solution 3
-           # bias = []
-           # for i in range(self.num_expert):
-           #     if fwd_expert_count[i] > 0:
-           #         bias.append(
-           #             self.bias[i].unsqueeze(0).expand(
-           #                 fwd_expert_count[i], -1
-           #             )
-           #         )
-           # bias = torch.cat(bias, dim=0)
-
-           x = x + bias
+       x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
        return x

    def extra_repr(self) -> str:
@@ -41,8 +41,9 @@ def _run_distributed(func, world_size, args: Dict):
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("mp_size", [1, 2])
+@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
def test_fmoe_linear_distributed(
-   num_expert, top_k, batch_size, d_model, d_hidden, mp_size
+   num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
):
    _run_distributed(
        "_test_fmoe_linear",
@@ -54,6 +55,7 @@ def test_fmoe_linear_distributed(
            "d_model": d_model,
            "d_hidden": d_hidden,
            "mp_size": mp_size,
+           "data_type": data_type
        },
    )
@@ -120,5 +122,6 @@ if __name__ == "__main__":
    else:
        test_fmoe_local_ddp(mp_size=2)
        test_fmoe_linear_distributed(
-           num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2
+           num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2,
+           data_type="torch.HalfTensor"
        )
@@ -17,15 +17,15 @@ from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
def _perform_forward(
-   moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group
+   moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group, data_type='torch.FloatTensor'
):
    moe.zero_grad()
    moe_raw.zero_grad()
-   if not mp_group:
-       inp = torch.rand(batch_size, d_model).cuda()
-   else:
+   inp = torch.rand(batch_size, d_model).type(data_type).cuda()
+
+   if mp_group:
        group_sender = rank // mp_group.size() * mp_group.size()
-       inp = torch.rand(batch_size, d_model).cuda()
        torch.distributed.broadcast(inp, group_sender, group=mp_group)
        torch.distributed.broadcast(
            moe.gate.gate.weight.data, group_sender, group=mp_group
@@ -49,15 +49,17 @@ def _perform_forward(
    return moe_out, raw_out, inp.grad, inp_raw.grad

-def _assert_numercial(names, moe_out_list, raw_out_list, rank):
+def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, err))
-       if err > 1e-3:
+       if err > precision:
            sys.stderr.write(f"=========== {name} moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write(f"=========== {name} raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
+           sys.stderr.write(f"=========== {name} diff ==============\n")
+           sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
            assert False
@@ -90,6 +92,7 @@ class MyMoE(FMoE):
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
+@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
def test_fmoe_linear(
    num_expert,
    top_k,
@@ -101,6 +104,7 @@ def test_fmoe_linear(
    mp_group,
    dp_group,
    world_group,
+   data_type,
    activation=torch.nn.functional.gelu,
):
    torch.manual_seed(42 + rank)
@@ -108,7 +112,7 @@ def test_fmoe_linear(
    moe = MyMoE(
        num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
-   ).cuda()
+   ).type(data_type).cuda()
    moe_raw = BruteForceMoELinear(
        activation=activation,
@@ -117,7 +121,7 @@
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
-   ).cuda()
+   ).type(data_type).cuda()
    if world_size == 1:
        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
@@ -148,7 +152,7 @@
        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)
    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
-       moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
+       moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type=data_type
    )
    moe_out_list = (
@@ -198,7 +202,10 @@ def test_fmoe_linear(
        "h4toh bias grad",
    ]
-   _assert_numercial(names, moe_out_list, raw_out_list, rank)
+
+   precision = 5e-1 if data_type == 'torch.HalfTensor' else 1e-3
+   _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=precision)
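The looser half-precision tolerance is plausible given fp16's machine epsilon; a quick check (illustrative only):

import torch

print(torch.finfo(torch.float16).eps)   # ~9.8e-4: per-element rounding error in fp16
print(torch.finfo(torch.float32).eps)   # ~1.2e-7
# a summed absolute error over whole tensors therefore needs a much larger bound for fp16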
@pytest.mark.parametrize("batch_size", [4])
@@ -299,7 +306,7 @@ def test_fmoe(
    raw_out_list = [raw_out, raw_grad, raw_grad_in]
    names = ["forward", "backward", "grad_in"]
-   _assert_numercial(names, moe_out_list, raw_out_list, rank)
+   _assert_numerical(names, moe_out_list, raw_out_list, rank)

class MyModule(nn.Module):
@@ -375,7 +382,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
    names = ["mp grad", "dp grad", "wp grad"]
-   _assert_numercial(names, ddp_out_list, raw_out_list, rank)
+   _assert_numerical(names, ddp_out_list, raw_out_list, rank)

if __name__ == "__main__":