Commit 7c3e5149 authored by Rick Ho

separate kernel files

parent 969ef607
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_stream_manager.h"
#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
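
// Exchange per-expert token counts across workers: an MPI all-to-all swaps the
// num_expert-sized blocks of local_expert_count, then the loop below sums, for
// each local expert, the rows it will receive from every worker.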
void moe_cuda_expert_exchange_impl(
        const int* local_expert_count,
        int* global_expert_count,
        int* fwd_expert_count,
        int num_expert, int world_size) {
    MPI_Alltoall(local_expert_count, num_expert, MPI_INT,
            global_expert_count, num_expert, MPI_INT, MPI_COMM_WORLD);
    for (int i = 0; i < num_expert; ++i) {
        for (int j = 0; j < world_size; ++j) {
            fwd_expert_count[i] += global_expert_count[i + j * num_expert];
        }
    }
}
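
// Torch-side wrapper: allocates global_expert_count and fwd_expert_count,
// then performs the exchange.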
std::vector<torch::Tensor> moe_cuda_expert_exchange(
        torch::Tensor local_expert_count,
        long num_expert, long n_workers) {
    auto global_expert_count = torch::empty_like(local_expert_count);
    auto fwe_options = torch::TensorOptions()
        .dtype(local_expert_count.dtype());
    auto fwd_expert_count = torch::zeros({num_expert}, fwe_options);
    moe_cuda_expert_exchange_impl(
            local_expert_count.data_ptr<int>(),
            global_expert_count.data_ptr<int>(),
            fwd_expert_count.data_ptr<int>(),
            num_expert, n_workers);
    return {global_expert_count, fwd_expert_count};
}
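
// Global scatter: for each (local expert i, peer worker j) pair, send the rows
// this worker holds for expert i on worker j and receive the rows other
// workers hold for our own expert i. expert_ptr is a prefix sum over
// local_expert_count giving each slice's starting row in local_input_buf;
// received rows are packed densely into input_buf in (expert, source worker)
// order.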
template<typename scalar_t>
void moe_cuda_global_scatter_impl(
        const scalar_t* local_input_buf,
        const int* local_expert_count,
        const int* global_expert_count,
        scalar_t* input_buf,
        size_t in_feat, size_t num_expert, size_t world_size,
        CudaStreamManager* smgr) {
    // assert world_size > 1
    int recv_ptr = 0;
    /* TODO: may save for backward */
    int *expert_ptr = new int[num_expert * world_size];
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert * world_size; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }
    for (int i = 0; i < num_expert; ++i) {
        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        local_input_buf + expert_ptr[idx] * in_feat,
                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
            }
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        input_buf + recv_ptr * in_feat,
                        global_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
                recv_ptr += global_expert_count[idx];
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(1);
}
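
// Torch-side wrapper: allocates the receive buffer of batch_size rows and
// dispatches on the floating-point dtype of input_buf.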
std::vector<torch::Tensor> moe_cuda_global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    auto num_expert = local_expert_count.size(0) / n_workers;
    auto in_feat = input_buf.size(1);
    auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
    auto smgr = getCudaStreamManager(input_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
            "moe_cuda_global_scatter", ([&] {
        moe_cuda_global_scatter_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<int>(),
            global_expert_count.data_ptr<int>(),
            global_input_buf.data_ptr<scalar_t>(),
            in_feat, num_expert, n_workers,
            smgr
        );
    }));
    return {global_input_buf,};
}
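
// Global gather: the inverse of global scatter. Expert outputs, laid out in
// the same (expert, source worker) order as the scattered inputs, are sent
// back to the workers they came from and written to local_output_buf at the
// original prefix-sum offsets.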
template<typename scalar_t>
void moe_cuda_global_gather_impl(
        const scalar_t* output_buf,
        const int* local_expert_count,
        const int* global_expert_count,
        scalar_t* local_output_buf,
        size_t out_feat, size_t num_expert, size_t world_size,
        CudaStreamManager* smgr) {
    int send_ptr = 0;
    /* TODO: may save for backward */
    int *expert_ptr = new int[num_expert * world_size];
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert * world_size; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }
    for (int i = 0; i < num_expert; ++i) {
        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        output_buf + send_ptr * out_feat,
                        global_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
                send_ptr += global_expert_count[idx];
            }
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        local_output_buf + expert_ptr[idx] * out_feat,
                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(1);
}
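
// Torch-side wrapper: allocates the local output buffer of batch_size rows and
// dispatches on the floating-point dtype of output_buf.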
std::vector<torch::Tensor> moe_cuda_global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    auto num_expert = local_expert_count.size(0) / n_workers;
    auto out_feat = output_buf.size(1);
    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
    auto smgr = getCudaStreamManager(output_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(),
            "moe_cuda_global_gather", ([&] {
        moe_cuda_global_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<int>(),
            global_expert_count.data_ptr<int>(),
            local_output_buf.data_ptr<scalar_t>(),
            out_feat, num_expert, n_workers,
            smgr
        );
    }));
    return {local_output_buf,};
}

#endif
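
Aside (not part of the commit): a minimal CPU-only sketch of the count layout and prefix-sum offsets that the scatter/gather code above relies on. The toy sizes and values below are made up for illustration; no MPI or NCCL is involved.

#include <cstdio>
#include <vector>

int main() {
    const int num_expert = 2, world_size = 2;
    // Rows this worker holds, counted per global expert id e = j * num_expert + i,
    // i.e. grouped by destination worker j, then by that worker's local expert i.
    std::vector<int> local_expert_count = {3, 1, 0, 2};

    // Prefix sum over the counts, exactly as computed before the send/recv loops:
    // expert_ptr[e] is the first row of slice e in the local input buffer.
    std::vector<int> expert_ptr(num_expert * world_size, 0);
    for (int e = 1; e < num_expert * world_size; ++e) {
        expert_ptr[e] = expert_ptr[e - 1] + local_expert_count[e - 1];
    }

    for (int i = 0; i < num_expert; ++i) {
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;  // slice for expert i hosted on worker j
            printf("expert %d on worker %d: %d rows starting at local row %d\n",
                    i, j, local_expert_count[idx], expert_ptr[idx]);
        }
    }
    return 0;
}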

@@ -10,14 +10,8 @@
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>
#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
#endif
#include "timer.hh" #include "timer.hh"
#include "cublas_wrapper.h" #include "cublas_wrapper.h"
#include "cuda_stream_manager.h" #include "cuda_stream_manager.h"
...@@ -80,181 +74,6 @@ void moe_cuda_expert_count_impl( ...@@ -80,181 +74,6 @@ void moe_cuda_expert_count_impl(
delete [] expert_ptr; delete [] expert_ptr;
} }
#ifdef MOE_USE_NCCL
void moe_cuda_expert_exchange_impl(
        const int* local_expert_count,
        int* global_expert_count,
        int* fwd_expert_count,
        int num_expert, int world_size) {
    MPI_Alltoall(local_expert_count, num_expert, MPI_INT,
            global_expert_count, num_expert, MPI_INT, MPI_COMM_WORLD);
    for (int i = 0; i < num_expert; ++i) {
        for (int j = 0; j < world_size; ++j) {
            fwd_expert_count[i] += global_expert_count[i + j * num_expert];
        }
    }
}

std::vector<torch::Tensor> moe_cuda_expert_exchange(
        torch::Tensor local_expert_count,
        long num_expert, long n_workers) {
    auto global_expert_count = torch::empty_like(local_expert_count);
    auto fwe_options = torch::TensorOptions()
        .dtype(local_expert_count.dtype());
    auto fwd_expert_count = torch::zeros({num_expert}, fwe_options);
    moe_cuda_expert_exchange_impl(
            local_expert_count.data_ptr<int>(),
            global_expert_count.data_ptr<int>(),
            fwd_expert_count.data_ptr<int>(),
            num_expert, n_workers);
    return {global_expert_count, fwd_expert_count};
}

template<typename scalar_t>
void moe_cuda_global_scatter_impl(
        const scalar_t* local_input_buf,
        const int* local_expert_count,
        const int* global_expert_count,
        scalar_t* input_buf,
        size_t in_feat, size_t num_expert, size_t world_size,
        CudaStreamManager* smgr) {
    // assert world_size > 1
    int recv_ptr = 0;
    /* TODO: may save for backward */
    int *expert_ptr = new int[num_expert * world_size];
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert * world_size; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }
    for (int i = 0; i < num_expert; ++i) {
        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        local_input_buf + expert_ptr[idx] * in_feat,
                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
            }
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        input_buf + recv_ptr * in_feat,
                        global_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
                recv_ptr += global_expert_count[idx];
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(1);
}

std::vector<torch::Tensor> moe_cuda_global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    auto num_expert = local_expert_count.size(0) / n_workers;
    auto in_feat = input_buf.size(1);
    auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
    auto smgr = getCudaStreamManager(input_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
            "moe_cuda_global_scatter", ([&] {
        moe_cuda_global_scatter_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<int>(),
            global_expert_count.data_ptr<int>(),
            global_input_buf.data_ptr<scalar_t>(),
            in_feat, num_expert, n_workers,
            smgr
        );
    }));
    return {global_input_buf,};
}

template<typename scalar_t>
void moe_cuda_global_gather_impl(
        const scalar_t* output_buf,
        const int* local_expert_count,
        const int* global_expert_count,
        scalar_t* local_output_buf,
        size_t out_feat, size_t num_expert, size_t world_size,
        CudaStreamManager* smgr) {
    int send_ptr = 0;
    /* TODO: may save for backward */
    int *expert_ptr = new int[num_expert * world_size];
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert * world_size; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }
    for (int i = 0; i < num_expert; ++i) {
        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        output_buf + send_ptr * out_feat,
                        global_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
                send_ptr += global_expert_count[idx];
            }
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        local_output_buf + expert_ptr[idx] * out_feat,
                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(0)));
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(1);
}

std::vector<torch::Tensor> moe_cuda_global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    auto num_expert = local_expert_count.size(0) / n_workers;
    auto out_feat = output_buf.size(1);
    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
    auto smgr = getCudaStreamManager(output_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(),
            "moe_cuda_global_gather", ([&] {
        moe_cuda_global_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<int>(),
            global_expert_count.data_ptr<int>(),
            local_output_buf.data_ptr<scalar_t>(),
            out_feat, num_expert, n_workers,
            smgr
        );
    }));
    return {local_output_buf,};
}
#endif // MOE_USE_NCCL

template <typename scalar_t>
void moe_cuda_local_scatter_impl(
        const scalar_t* input,
...
#include "moe_cuda_kernel.h"
#include <cstdio>
#include <iostream>
#include <vector>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_stream_manager.h"
#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>
// TODO
#endif
@@ -17,7 +17,9 @@ setup(
         sources=[
             'moe.cpp',
             'cuda_stream_manager.cpp',
-            'moe_cuda_kernel.cu',
+            'moe_compute_kernel.cu',
+            'moe_comm_kernel.cu',
+            'moe_fused_kernel.cu',
         ],
         extra_compile_args={
             'cxx': cxx_flags,
...