Unverified Commit 3a41edb8 authored by Rick Ho, committed by GitHub

Merge pull request #172 from laekov/smgr_bug

[BUG FIX] Fix bugs in stream manager.
parents c1c19f3e 1f82fb16
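
The change set below moves FastMoE's kernels and NCCL calls onto PyTorch's current CUDA stream, exposed through two new CudaStreamManager helpers, torchStream() and syncTorch(), instead of assuming the manager's own stream 0 is ordered with PyTorch's work. A minimal sketch of the intended usage pattern (torchStream/syncTorch/stream come from this PR; the include path, the surrounding function, and the kernel are illustrative assumptions):

#include <cuda_runtime.h>
#include "stream_manager.h"   // CudaStreamManager, assumed header name from this repo

// Illustrative only: how callers are expected to combine the new helpers.
void example_usage(CudaStreamManager* smgr) {
    // Work that must follow pending PyTorch ops is enqueued on PyTorch's
    // current stream, e.g.
    //   some_kernel<<<grid, block, 0, smgr->torchStream()>>>(...);

    // Before the host reads back a result produced on that stream,
    // drain it explicitly:
    smgr->syncTorch();   // wraps cudaStreamSynchronize(torchStream())

    // When one of the manager's private streams must wait for PyTorch's
    // stream (or vice versa), use an event rather than a blocking sync:
    cudaEvent_t evt;
    cudaEventCreate(&evt);
    cudaEventRecord(evt, smgr->torchStream());
    cudaStreamWaitEvent(smgr->stream(0), evt, 0);
    cudaEventDestroy(evt);
}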
@@ -104,6 +104,7 @@ std::vector<torch::Tensor> _swipe_once(
     }
     long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
     fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
+    smgr->syncTorch();
     long *gec = _d2h(d_gec, n_worker);
     /* Limit number of incoming samples */
@@ -123,17 +124,20 @@ std::vector<torch::Tensor> _swipe_once(
     /* Send limit information back */
     _h2d(gec, d_gec, n_worker);
     fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
+    smgr->syncTorch();
     _d2h(d_lec, lec, n_worker);
     auto d_dropcount = _h2d(drop_count, n_worker);
     ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
-            smgr->ncclcomm, smgr->stream());
+            smgr->ncclcomm, smgr->torchStream());
+    smgr->syncTorch();
     _d2h(d_dropcount, drop_count, n_worker);
     auto d_gcap = _cudamalloc<long>(n_worker);
     _h2d(&cap, d_gcap + rank, 1);
     ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
-            smgr->ncclcomm, smgr->stream());
+            smgr->ncclcomm, smgr->torchStream());
+    smgr->syncTorch();
     auto gcap = _d2h(d_gcap, n_worker);
     /* Re-assign and update counters */
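
The smgr->syncTorch() calls added above guard host-side reads: each exchange or collective now runs asynchronously on PyTorch's current stream, and the very next statement copies its result back to the host. Condensed, the pattern is (sketch built from the lines above; _d2h is the device-to-host copy helper already used in this file):

ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
        smgr->ncclcomm, smgr->torchStream());   // asynchronous on the torch stream
smgr->syncTorch();                              // drain that stream first ...
_d2h(d_dropcount, drop_count, n_worker);        // ... then read the result on the host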
...
@@ -25,9 +25,8 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
         CudaStreamManager* smgr) {
     dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
     dim3 block_dim(1024);
-    limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
+    limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
             ec, cap, eca, n_expert, n_worker);
-    smgr->sync(1);
 }
 __global__
@@ -51,8 +50,7 @@ void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
         CudaStreamManager* smgr) {
     dim3 grid_dim(CEIL(batch_size, 1024));
     dim3 block_dim(1024);
-    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
+    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
             gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker
     );
-    smgr->sync(1);
 }
@@ -44,10 +44,9 @@ void _reduce_grad(
         long expert_size) {
     auto smgr = getCudaStreamManager(t.device().index());
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
     cudaEvent_t evt_stash;
     cudaEventCreate(&evt_stash);
-    cudaEventRecord(evt_stash, torch_stream);
+    cudaEventRecord(evt_stash, smgr->torchStream());
     FMOE_SWE(smgr->stream(0), evt_stash);
     cudaEventDestroy(evt_stash);
...
@@ -122,7 +122,7 @@ void fmoe_cuda_fused_forward_impl(
         long d_model,
         long num_expert, long rank, long world_size, long expert_size,
         long pipeline_gran, CudaStreamManager* smgr) {
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
+    smgr->syncTorch();
     int *local_ptr = new int[num_expert * world_size + 1];
     int *global_ptr = new int[num_expert * world_size + 1];
@@ -139,9 +139,11 @@ void fmoe_cuda_fused_forward_impl(
     cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
     cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
+    cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
     for (long i = 0; i < n_groups; ++i) {
         cudaEventCreate(input_ready + i);
         cudaEventCreate(output_ready + i);
+        cudaEventCreate(output_torch_ready + i);
     }
     // S_0 ... S_n
@@ -157,11 +159,11 @@ void fmoe_cuda_fused_forward_impl(
                         local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                         global_input_buf + global_ptr[gidx_recv] * d_model,
                         global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
-                        d_model, smgr->stream(0), smgr->ncclcomm);
+                        d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
-        cudaEventRecord(input_ready[step], smgr->stream(0));
+        cudaEventRecord(input_ready[step], smgr->stream(num_expert));
     }
     // Broadcast shadowed experts
@@ -173,22 +175,23 @@ void fmoe_cuda_fused_forward_impl(
         if (stored_models[i]) {
             if (i / num_expert == rank) {
                 cudaEventCreate(&evt_get);
-                cudaEventRecord(evt_get, torch_stream);
+                cudaEventRecord(evt_get, smgr->stream(0));
-                FMOE_SWE(smgr->stream(0), evt_get);
+                FMOE_SWE(smgr->stream(num_expert), evt_get);
                 cudaEventDestroy(evt_get);
             }
             NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
                         expert_size * sizeof(scalar_t), ncclChar,
-                        i / num_expert, smgr->ncclcomm, smgr->stream(0)));
+                        i / num_expert, smgr->ncclcomm, smgr->stream(num_expert)));
             cudaEventCreate(evt_shadow + si);
-            cudaEventRecord(evt_shadow[si], smgr->stream(0));
+            cudaEventRecord(evt_shadow[si], smgr->stream(num_expert));
             ++si;
         }
     }
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(torch_stream, input_ready[step]);
+        FMOE_SWE(smgr->stream(0), input_ready[step]);
+        FMOE_SWE(smgr->torchStream(), input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -198,13 +201,15 @@ void fmoe_cuda_fused_forward_impl(
                     global_input_buf, global_output_buf,
                     (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
-        cudaEventRecord(output_ready[step], torch_stream);
+        cudaEventRecord(output_ready[step], smgr->stream(0));
+        cudaEventRecord(output_torch_ready[step], smgr->torchStream());
     }
     // Compute over shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
-            FMOE_SWE(torch_stream, evt_shadow[si]);
+            FMOE_SWE(smgr->stream(0), evt_shadow[si]);
+            FMOE_SWE(smgr->torchStream(), evt_shadow[si]);
             stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
             long offset = local_ptr[i];
             long micro_batch_size = local_expert_count[i];
@@ -218,7 +223,8 @@ void fmoe_cuda_fused_forward_impl(
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(0), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -230,12 +236,12 @@ void fmoe_cuda_fused_forward_impl(
                     global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                     output_buf + local_ptr[idx_recv] * d_model,
                     local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
     }
-    smgr->sync(1);
+    smgr->sync(num_expert + 1);
     delete [] local_ptr;
     delete [] global_ptr;
@@ -244,12 +250,14 @@ void fmoe_cuda_fused_forward_impl(
     for (long i = 0; i < n_groups; ++i) {
         cudaEventDestroy(input_ready[i]);
         cudaEventDestroy(output_ready[i]);
+        cudaEventDestroy(output_torch_ready[i]);
     }
     for (unsigned i = 0; i < params.size(); ++i) {
         cudaEventDestroy(evt_shadow[i]);
     }
     delete [] input_ready;
     delete [] output_ready;
+    delete [] output_torch_ready;
 }
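
In the pipelined forward above, each step now records two completion events: output_ready on the manager's compute stream and the new output_torch_ready on PyTorch's stream, so the receive phase on the communication stream (smgr->stream(num_expert)) waits for expert work regardless of which stream the expert callback used. A condensed sketch of the per-step handshake (FMOE_SWE is assumed to wrap cudaStreamWaitEvent, as in the existing code):

// Streams, as used above:
//   smgr->stream(0)          - compute stream
//   smgr->stream(num_expert) - communication (NCCL) stream
//   smgr->torchStream()      - PyTorch's current stream
cudaEventRecord(input_ready[step], smgr->stream(num_expert));       // after S_step

cudaStreamWaitEvent(smgr->stream(0), input_ready[step], 0);         // C_step waits on
cudaStreamWaitEvent(smgr->torchStream(), input_ready[step], 0);     // both streams
// ... expert computation for this step ...
cudaEventRecord(output_ready[step], smgr->stream(0));
cudaEventRecord(output_torch_ready[step], smgr->torchStream());

cudaStreamWaitEvent(smgr->stream(num_expert), output_ready[step], 0);        // R_step waits
cudaStreamWaitEvent(smgr->stream(num_expert), output_torch_ready[step], 0);  // on both events
// ... R_step all-to-all on smgr->stream(num_expert) ...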
@@ -273,7 +281,7 @@ void fmoe_cuda_fused_backward_impl(
         long d_model,
         long num_expert, long rank, long world_size,
         long pipeline_gran, CudaStreamManager* smgr) {
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
+    smgr->syncTorch();
     int *local_ptr = new int[num_expert * world_size + 1];
     int *global_ptr = new int[num_expert * world_size + 1];
@@ -290,9 +298,11 @@ void fmoe_cuda_fused_backward_impl(
     cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
     cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
+    cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
     for (long i = 0; i < n_groups; ++i) {
         cudaEventCreate(input_ready + i);
         cudaEventCreate(output_ready + i);
+        cudaEventCreate(output_torch_ready + i);
     }
     // S_0 ... S_n
@@ -308,11 +318,11 @@ void fmoe_cuda_fused_backward_impl(
                         local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                         global_grad_out + global_ptr[gidx_recv] * d_model,
                         global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
-                        d_model, smgr->stream(0), smgr->ncclcomm);
+                        d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
-        cudaEventRecord(input_ready[step], smgr->stream(0));
+        cudaEventRecord(input_ready[step], smgr->stream(num_expert));
     }
     // Shadowed experts backward and reduce
@@ -328,7 +338,7 @@ void fmoe_cuda_fused_backward_impl(
             collect_fn(si, i / num_expert, 0);
             if (i / num_expert == rank) {
                 cudaEventCreate(evt_reduce + i % num_expert);
-                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
+                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(num_expert));
             }
             ++si;
         }
@@ -337,7 +347,8 @@ void fmoe_cuda_fused_backward_impl(
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(torch_stream, input_ready[step]);
+        FMOE_SWE(smgr->stream(0), input_ready[step]);
+        FMOE_SWE(smgr->torchStream(), input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -348,14 +359,16 @@ void fmoe_cuda_fused_backward_impl(
                     global_grad_out, global_grad_in,
                     (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
-        cudaEventRecord(output_ready[step], torch_stream);
+        cudaEventRecord(output_ready[step], smgr->stream(0));
+        cudaEventRecord(output_torch_ready[step], smgr->torchStream());
     }
     // Collect gradients for shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
             if (i / num_expert == rank) {
-                FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
+                FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
+                FMOE_SWE(smgr->torchStream(), evt_reduce[i % num_expert]);
                 set_grad_fn(si, i % num_expert);
             }
             ++si;
@@ -364,7 +377,8 @@ void fmoe_cuda_fused_backward_impl(
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(0), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -376,13 +390,13 @@ void fmoe_cuda_fused_backward_impl(
                     global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                     grad_in + local_ptr[idx_recv] * d_model,
                     local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
     }
-    smgr->sync(1);
+    smgr->sync(num_expert + 1);
     checkCudaErrors(cudaGetLastError());
     delete [] local_ptr;
@@ -392,9 +406,11 @@ void fmoe_cuda_fused_backward_impl(
     for (long i = 0; i < n_groups; ++i) {
         cudaEventDestroy(input_ready[i]);
         cudaEventDestroy(output_ready[i]);
+        cudaEventDestroy(output_torch_ready[i]);
     }
     delete [] input_ready;
     delete [] output_ready;
+    delete [] output_torch_ready;
     for (long i = 0; i < num_expert; ++i) {
         if (stored_models[i + rank * num_expert]) {
             cudaEventDestroy(evt_reduce[i]);
...
@@ -19,17 +19,16 @@ void fmoe_cuda_expert_exchange_impl(
                     ncclInt64,
                     i,
                     smgr->ncclcomm,
-                    smgr->stream(0)));
+                    smgr->torchStream()));
         NCCL_SAFE_CALL(ncclRecv(
                     global_expert_count + n_expert * i,
                     n_expert,
                     ncclInt64,
                     i,
                     smgr->ncclcomm,
-                    smgr->stream(0)));
+                    smgr->torchStream()));
     }
     NCCL_SAFE_CALL(ncclGroupEnd());
-    smgr->sync(1);
 }
 torch::Tensor _expert_exchange(
...
@@ -36,7 +36,7 @@ void fmoe_cuda_global_scatter_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
             }
             if (global_expert_count[idx]) {
                 NCCL_SAFE_CALL(ncclRecv(
@@ -45,14 +45,13 @@ void fmoe_cuda_global_scatter_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
                 recv_ptr += global_expert_count[idx];
             }
         }
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
     delete [] expert_ptr;
-    smgr->sync(1);
 }
 template<typename scalar_t>
@@ -82,7 +81,7 @@ void fmoe_cuda_global_gather_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
                 send_ptr += global_expert_count[idx];
             }
             if (local_expert_count[idx]) {
@@ -92,13 +91,12 @@ void fmoe_cuda_global_gather_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
             }
         }
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
     delete [] expert_ptr;
-    smgr->sync(1);
 }
...
@@ -21,9 +21,8 @@ void fmoe_cuda_assign_pos_impl(
         CudaStreamManager* smgr) {
     size_t numel = batch_size * topk;
     assign_pos_kernel
-        <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
+        <<<CEIL(numel, 256), 256, 0, smgr->torchStream()>>>
         (cum_count, gate, pos, numel, topk);
-    smgr->sync(1);
 }
 #define PERTHREAD_EXPERTS 256
@@ -74,7 +73,6 @@ void fmoe_cuda_expert_count_impl(
         const size_t batch_size, const size_t n_expert,
         CudaStreamManager* smgr) {
     expert_count_kernel
-        <<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->stream(0)>>>
+        <<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->torchStream()>>>
         (gate_idx, expert_count, batch_size, n_expert);
-    smgr->sync(1);
 }
...
@@ -65,6 +65,7 @@ void fmoe_cuda_linear_forward_impl(
         CudaStreamManager* smgr) {
     scalar_t alpha = 1, beta = has_bias ? 1 : 0;
+    smgr->syncTorch();
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
             continue;
@@ -102,6 +103,7 @@ void fmoe_cuda_linear_backward_impl(
         const size_t out_feat,
         const size_t num_expert,
         CudaStreamManager* smgr) {
+    smgr->syncTorch();
     scalar_t alpha = 1, beta = 0;
     // bias
...
@@ -19,6 +19,10 @@ cudaStream_t CudaStreamManager::stream(size_t idx) {
     return this->streams[idx % SMGR_N_STREAMS];
 }
+cudaStream_t CudaStreamManager::torchStream() {
+    return c10::cuda::getCurrentCUDAStream().stream();
+}
 cublasHandle_t CudaStreamManager::handle(size_t idx) {
     if (this->use_default) {
         return at::cuda::getCurrentCUDABlasHandle();
@@ -27,6 +31,10 @@ cublasHandle_t CudaStreamManager::handle(size_t idx) {
 }
+void CudaStreamManager::syncTorch() {
+    cudaStreamSynchronize(this->torchStream());
+}
 void CudaStreamManager::sync(int idx) {
     if (this->use_default) {
         return;
@@ -45,7 +53,11 @@ void CudaStreamManager::setup(const int device) {
     streams = new cudaStream_t[SMGR_N_STREAMS];
     handles = new cublasHandle_t[SMGR_N_STREAMS];
     for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
-        checkCudaErrors(cudaStreamCreate(streams + i));
+        // SHOULD NOT USE: cudaStreamCreate(...)
+        // more details in
+        // https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
+        checkCudaErrors(cudaStreamCreateWithFlags(streams + i,
+                    cudaStreamNonBlocking));
         checkCudaErrors(cublasCreate(handles + i));
         cublasSetStream(handles[i], streams[i]);
     }
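
The comment added above refers to the CUDA stream synchronization rules: streams created with plain cudaStreamCreate are blocking streams that implicitly synchronize with the legacy default (NULL) stream, which can serialize the manager's worker streams against anything launched on that default stream. Creating them with cudaStreamNonBlocking removes that implicit coupling. A minimal illustration (not FastMoE code):

#include <cuda_runtime.h>

cudaStream_t blocking, non_blocking;
// Implicitly synchronizes with work on the legacy default (NULL) stream.
cudaStreamCreate(&blocking);
// Runs concurrently with the legacy default stream; no implicit sync.
cudaStreamCreateWithFlags(&non_blocking, cudaStreamNonBlocking);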
...
@@ -34,8 +34,10 @@ public:
     void setup(int);
     void sync(int=0);
+    void syncTorch();
     void destroy();
+    cudaStream_t torchStream();
     cudaStream_t stream(size_t=0);
     cublasHandle_t handle(size_t=0);
...
@@ -37,7 +37,7 @@ class MoEForward(Function):
         try:
             # To skip torch autograd's version check.
             with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
-                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
+                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
         except Exception as e:
             # Ignore the error and fall back for compatibility to older
             # versions of PyTorch
...