Unverified Commit 1f82fb16 authored by Rick Ho's avatar Rick Ho Committed by GitHub
Browse files

Merge pull request #173 from laekov/fit-new-smgr

Fit old code with new smgr
parents 2bd187cb 945004e7
...@@ -104,6 +104,7 @@ std::vector<torch::Tensor> _swipe_once( ...@@ -104,6 +104,7 @@ std::vector<torch::Tensor> _swipe_once(
} }
long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker); long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr); fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
smgr->syncTorch();
long *gec = _d2h(d_gec, n_worker); long *gec = _d2h(d_gec, n_worker);
/* Limit number of incoming samples */ /* Limit number of incoming samples */
...@@ -123,17 +124,20 @@ std::vector<torch::Tensor> _swipe_once( ...@@ -123,17 +124,20 @@ std::vector<torch::Tensor> _swipe_once(
/* Send limit information back */ /* Send limit information back */
_h2d(gec, d_gec, n_worker); _h2d(gec, d_gec, n_worker);
fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr); fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
smgr->syncTorch();
_d2h(d_lec, lec, n_worker); _d2h(d_lec, lec, n_worker);
auto d_dropcount = _h2d(drop_count, n_worker); auto d_dropcount = _h2d(drop_count, n_worker);
ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum, ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
smgr->ncclcomm, smgr->stream()); smgr->ncclcomm, smgr->torchStream());
smgr->syncTorch();
_d2h(d_dropcount, drop_count, n_worker); _d2h(d_dropcount, drop_count, n_worker);
auto d_gcap = _cudamalloc<long>(n_worker); auto d_gcap = _cudamalloc<long>(n_worker);
_h2d(&cap, d_gcap + rank, 1); _h2d(&cap, d_gcap + rank, 1);
ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64, ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
smgr->ncclcomm, smgr->stream()); smgr->ncclcomm, smgr->torchStream());
smgr->syncTorch();
auto gcap = _d2h(d_gcap, n_worker); auto gcap = _d2h(d_gcap, n_worker);
/* Re-assign and update counters */ /* Re-assign and update counters */
......
...@@ -25,9 +25,8 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap, ...@@ -25,9 +25,8 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
CudaStreamManager* smgr) { CudaStreamManager* smgr) {
dim3 grid_dim(CEIL(n_worker, 1024), n_expert); dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
dim3 block_dim(1024); dim3 block_dim(1024);
limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>( limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
ec, cap, eca, n_expert, n_worker); ec, cap, eca, n_expert, n_worker);
smgr->sync(1);
} }
__global__ __global__
...@@ -51,8 +50,7 @@ void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx, ...@@ -51,8 +50,7 @@ void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
CudaStreamManager* smgr) { CudaStreamManager* smgr) {
dim3 grid_dim(CEIL(batch_size, 1024)); dim3 grid_dim(CEIL(batch_size, 1024));
dim3 block_dim(1024); dim3 block_dim(1024);
prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>( prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker
); );
smgr->sync(1);
} }
...@@ -44,10 +44,9 @@ void _reduce_grad( ...@@ -44,10 +44,9 @@ void _reduce_grad(
long expert_size) { long expert_size) {
auto smgr = getCudaStreamManager(t.device().index()); auto smgr = getCudaStreamManager(t.device().index());
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
cudaEvent_t evt_stash; cudaEvent_t evt_stash;
cudaEventCreate(&evt_stash); cudaEventCreate(&evt_stash);
cudaEventRecord(evt_stash, torch_stream); cudaEventRecord(evt_stash, smgr->torchStream());
FMOE_SWE(smgr->stream(0), evt_stash); FMOE_SWE(smgr->stream(0), evt_stash);
cudaEventDestroy(evt_stash); cudaEventDestroy(evt_stash);
......
...@@ -122,8 +122,7 @@ void fmoe_cuda_fused_forward_impl( ...@@ -122,8 +122,7 @@ void fmoe_cuda_fused_forward_impl(
long d_model, long d_model,
long num_expert, long rank, long world_size, long expert_size, long num_expert, long rank, long world_size, long expert_size,
long pipeline_gran, CudaStreamManager* smgr) { long pipeline_gran, CudaStreamManager* smgr) {
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream(); smgr->syncTorch();
cudaStreamSynchronize(torch_stream);
int *local_ptr = new int[num_expert * world_size + 1]; int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1]; int *global_ptr = new int[num_expert * world_size + 1];
...@@ -192,7 +191,7 @@ void fmoe_cuda_fused_forward_impl( ...@@ -192,7 +191,7 @@ void fmoe_cuda_fused_forward_impl(
// C_0 ... C_n // C_0 ... C_n
for (long step = 0; step < n_groups; ++step) { for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(smgr->stream(0), input_ready[step]); FMOE_SWE(smgr->stream(0), input_ready[step]);
FMOE_SWE(torch_stream, input_ready[step]); FMOE_SWE(smgr->torchStream(), input_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) { for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step); GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base]; long offset = global_ptr[ei * world_size + from_base];
...@@ -203,14 +202,14 @@ void fmoe_cuda_fused_forward_impl( ...@@ -203,14 +202,14 @@ void fmoe_cuda_fused_forward_impl(
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr); (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
} }
cudaEventRecord(output_ready[step], smgr->stream(0)); cudaEventRecord(output_ready[step], smgr->stream(0));
cudaEventRecord(output_torch_ready[step], torch_stream); cudaEventRecord(output_torch_ready[step], smgr->torchStream());
} }
// Compute over shadowed experts // Compute over shadowed experts
for (long i = 0, si = 0; i < world_size * num_expert; ++i) { for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) { if (stored_models[i]) {
FMOE_SWE(smgr->stream(0), evt_shadow[si]); FMOE_SWE(smgr->stream(0), evt_shadow[si]);
FMOE_SWE(torch_stream, evt_shadow[si]); FMOE_SWE(smgr->torchStream(), evt_shadow[si]);
stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0 stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
long offset = local_ptr[i]; long offset = local_ptr[i];
long micro_batch_size = local_expert_count[i]; long micro_batch_size = local_expert_count[i];
...@@ -282,8 +281,7 @@ void fmoe_cuda_fused_backward_impl( ...@@ -282,8 +281,7 @@ void fmoe_cuda_fused_backward_impl(
long d_model, long d_model,
long num_expert, long rank, long world_size, long num_expert, long rank, long world_size,
long pipeline_gran, CudaStreamManager* smgr) { long pipeline_gran, CudaStreamManager* smgr) {
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream(); smgr->syncTorch();
cudaStreamSynchronize(torch_stream);
int *local_ptr = new int[num_expert * world_size + 1]; int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1]; int *global_ptr = new int[num_expert * world_size + 1];
...@@ -350,7 +348,7 @@ void fmoe_cuda_fused_backward_impl( ...@@ -350,7 +348,7 @@ void fmoe_cuda_fused_backward_impl(
// C_0 ... C_n // C_0 ... C_n
for (long step = 0; step < n_groups; ++step) { for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(smgr->stream(0), input_ready[step]); FMOE_SWE(smgr->stream(0), input_ready[step]);
FMOE_SWE(torch_stream, input_ready[step]); FMOE_SWE(smgr->torchStream(), input_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) { for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step); GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base]; long offset = global_ptr[ei * world_size + from_base];
...@@ -362,7 +360,7 @@ void fmoe_cuda_fused_backward_impl( ...@@ -362,7 +360,7 @@ void fmoe_cuda_fused_backward_impl(
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr); (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
} }
cudaEventRecord(output_ready[step], smgr->stream(0)); cudaEventRecord(output_ready[step], smgr->stream(0));
cudaEventRecord(output_torch_ready[step], torch_stream); cudaEventRecord(output_torch_ready[step], smgr->torchStream());
} }
// Collect gradients for shadowed experts // Collect gradients for shadowed experts
...@@ -370,7 +368,7 @@ void fmoe_cuda_fused_backward_impl( ...@@ -370,7 +368,7 @@ void fmoe_cuda_fused_backward_impl(
if (stored_models[i]) { if (stored_models[i]) {
if (i / num_expert == rank) { if (i / num_expert == rank) {
FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]); FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
FMOE_SWE(torch_stream, evt_reduce[i % num_expert]); FMOE_SWE(smgr->torchStream(), evt_reduce[i % num_expert]);
set_grad_fn(si, i % num_expert); set_grad_fn(si, i % num_expert);
} }
++si; ++si;
......
...@@ -19,17 +19,16 @@ void fmoe_cuda_expert_exchange_impl( ...@@ -19,17 +19,16 @@ void fmoe_cuda_expert_exchange_impl(
ncclInt64, ncclInt64,
i, i,
smgr->ncclcomm, smgr->ncclcomm,
smgr->stream(0))); smgr->torchStream()));
NCCL_SAFE_CALL(ncclRecv( NCCL_SAFE_CALL(ncclRecv(
global_expert_count + n_expert * i, global_expert_count + n_expert * i,
n_expert, n_expert,
ncclInt64, ncclInt64,
i, i,
smgr->ncclcomm, smgr->ncclcomm,
smgr->stream(0))); smgr->torchStream()));
} }
NCCL_SAFE_CALL(ncclGroupEnd()); NCCL_SAFE_CALL(ncclGroupEnd());
smgr->sync(1);
} }
torch::Tensor _expert_exchange( torch::Tensor _expert_exchange(
......
...@@ -36,7 +36,7 @@ void fmoe_cuda_global_scatter_impl( ...@@ -36,7 +36,7 @@ void fmoe_cuda_global_scatter_impl(
ncclChar, ncclChar,
j, j,
smgr->ncclcomm, smgr->ncclcomm,
smgr->stream(0))); smgr->torchStream()));
} }
if (global_expert_count[idx]) { if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv( NCCL_SAFE_CALL(ncclRecv(
...@@ -45,14 +45,13 @@ void fmoe_cuda_global_scatter_impl( ...@@ -45,14 +45,13 @@ void fmoe_cuda_global_scatter_impl(
ncclChar, ncclChar,
j, j,
smgr->ncclcomm, smgr->ncclcomm,
smgr->stream(0))); smgr->torchStream()));
recv_ptr += global_expert_count[idx]; recv_ptr += global_expert_count[idx];
} }
} }
NCCL_SAFE_CALL(ncclGroupEnd()); NCCL_SAFE_CALL(ncclGroupEnd());
} }
delete [] expert_ptr; delete [] expert_ptr;
smgr->sync(1);
} }
template<typename scalar_t> template<typename scalar_t>
...@@ -82,7 +81,7 @@ void fmoe_cuda_global_gather_impl( ...@@ -82,7 +81,7 @@ void fmoe_cuda_global_gather_impl(
ncclChar, ncclChar,
j, j,
smgr->ncclcomm, smgr->ncclcomm,
smgr->stream(0))); smgr->torchStream()));
send_ptr += global_expert_count[idx]; send_ptr += global_expert_count[idx];
} }
if (local_expert_count[idx]) { if (local_expert_count[idx]) {
...@@ -92,13 +91,12 @@ void fmoe_cuda_global_gather_impl( ...@@ -92,13 +91,12 @@ void fmoe_cuda_global_gather_impl(
ncclChar, ncclChar,
j, j,
smgr->ncclcomm, smgr->ncclcomm,
smgr->stream(0))); smgr->torchStream()));
} }
} }
NCCL_SAFE_CALL(ncclGroupEnd()); NCCL_SAFE_CALL(ncclGroupEnd());
} }
delete [] expert_ptr; delete [] expert_ptr;
smgr->sync(1);
} }
......
...@@ -21,9 +21,8 @@ void fmoe_cuda_assign_pos_impl( ...@@ -21,9 +21,8 @@ void fmoe_cuda_assign_pos_impl(
CudaStreamManager* smgr) { CudaStreamManager* smgr) {
size_t numel = batch_size * topk; size_t numel = batch_size * topk;
assign_pos_kernel assign_pos_kernel
<<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>> <<<CEIL(numel, 256), 256, 0, smgr->torchStream()>>>
(cum_count, gate, pos, numel, topk); (cum_count, gate, pos, numel, topk);
smgr->sync(1);
} }
#define PERTHREAD_EXPERTS 256 #define PERTHREAD_EXPERTS 256
...@@ -74,7 +73,6 @@ void fmoe_cuda_expert_count_impl( ...@@ -74,7 +73,6 @@ void fmoe_cuda_expert_count_impl(
const size_t batch_size, const size_t n_expert, const size_t batch_size, const size_t n_expert,
CudaStreamManager* smgr) { CudaStreamManager* smgr) {
expert_count_kernel expert_count_kernel
<<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->stream(0)>>> <<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->torchStream()>>>
(gate_idx, expert_count, batch_size, n_expert); (gate_idx, expert_count, batch_size, n_expert);
smgr->sync(1);
} }
...@@ -65,6 +65,7 @@ void fmoe_cuda_linear_forward_impl( ...@@ -65,6 +65,7 @@ void fmoe_cuda_linear_forward_impl(
CudaStreamManager* smgr) { CudaStreamManager* smgr) {
scalar_t alpha = 1, beta = has_bias ? 1 : 0; scalar_t alpha = 1, beta = has_bias ? 1 : 0;
smgr->syncTorch();
for (int i = 0, ptr = 0; i < num_expert; ++i) { for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) { if (expert_count[i] == 0) {
continue; continue;
...@@ -102,6 +103,7 @@ void fmoe_cuda_linear_backward_impl( ...@@ -102,6 +103,7 @@ void fmoe_cuda_linear_backward_impl(
const size_t out_feat, const size_t out_feat,
const size_t num_expert, const size_t num_expert,
CudaStreamManager* smgr) { CudaStreamManager* smgr) {
smgr->syncTorch();
scalar_t alpha = 1, beta = 0; scalar_t alpha = 1, beta = 0;
// bias // bias
......
...@@ -19,6 +19,10 @@ cudaStream_t CudaStreamManager::stream(size_t idx) { ...@@ -19,6 +19,10 @@ cudaStream_t CudaStreamManager::stream(size_t idx) {
return this->streams[idx % SMGR_N_STREAMS]; return this->streams[idx % SMGR_N_STREAMS];
} }
cudaStream_t CudaStreamManager::torchStream() {
return c10::cuda::getCurrentCUDAStream().stream();
}
cublasHandle_t CudaStreamManager::handle(size_t idx) { cublasHandle_t CudaStreamManager::handle(size_t idx) {
if (this->use_default) { if (this->use_default) {
return at::cuda::getCurrentCUDABlasHandle(); return at::cuda::getCurrentCUDABlasHandle();
...@@ -27,6 +31,10 @@ cublasHandle_t CudaStreamManager::handle(size_t idx) { ...@@ -27,6 +31,10 @@ cublasHandle_t CudaStreamManager::handle(size_t idx) {
} }
void CudaStreamManager::syncTorch() {
cudaStreamSynchronize(this->torchStream());
}
void CudaStreamManager::sync(int idx) { void CudaStreamManager::sync(int idx) {
if (this->use_default) { if (this->use_default) {
return; return;
......
...@@ -34,8 +34,10 @@ public: ...@@ -34,8 +34,10 @@ public:
void setup(int); void setup(int);
void sync(int=0); void sync(int=0);
void syncTorch();
void destroy(); void destroy();
cudaStream_t torchStream();
cudaStream_t stream(size_t=0); cudaStream_t stream(size_t=0);
cublasHandle_t handle(size_t=0); cublasHandle_t handle(size_t=0);
......
...@@ -37,7 +37,7 @@ class MoEForward(Function): ...@@ -37,7 +37,7 @@ class MoEForward(Function):
try: try:
# To skip torch autograd's version check. # To skip torch autograd's version check.
with torch.autograd.graph.saved_tensors_hooks(nothing, nothing): with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64)) y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
except Exception as e: except Exception as e:
# Ignore the error and fall back for compatibility to older # Ignore the error and fall back for compatibility to older
# versions of PyTorch # versions of PyTorch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment