Commit c8633740 authored by zms1999

[BUG FIX] make smart scheduling great again, fix bugs in streams management

parent c1c19f3e
@@ -157,11 +157,11 @@ void fmoe_cuda_fused_forward_impl(
                     local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                     global_input_buf + global_ptr[gidx_recv] * d_model,
                     global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
-        cudaEventRecord(input_ready[step], smgr->stream(0));
+        cudaEventRecord(input_ready[step], smgr->stream(num_expert));
     }
     // Broadcast shadowed experts
@@ -173,21 +173,22 @@ void fmoe_cuda_fused_forward_impl(
         if (stored_models[i]) {
             if (i / num_expert == rank) {
                 cudaEventCreate(&evt_get);
-                cudaEventRecord(evt_get, torch_stream);
-                FMOE_SWE(smgr->stream(0), evt_get);
+                cudaEventRecord(evt_get, smgr->stream(0));
+                FMOE_SWE(smgr->stream(num_expert), evt_get);
                 cudaEventDestroy(evt_get);
             }
             NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
                     expert_size * sizeof(scalar_t), ncclChar,
-                    i / num_expert, smgr->ncclcomm, smgr->stream(0)));
+                    i / num_expert, smgr->ncclcomm, smgr->stream(num_expert)));
             cudaEventCreate(evt_shadow + si);
-            cudaEventRecord(evt_shadow[si], smgr->stream(0));
+            cudaEventRecord(evt_shadow[si], smgr->stream(num_expert));
             ++si;
         }
     }
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
+        FMOE_SWE(smgr->stream(0), input_ready[step]);
         FMOE_SWE(torch_stream, input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
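The hunks above hinge on one-way event ordering between the dedicated communication stream and the compute streams: an event is recorded on the stream that produces the data, and every stream that consumes it waits on that event before launching work. A minimal sketch of that pattern follows; it assumes FMOE_SWE(stream, event) is a thin wrapper around cudaStreamWaitEvent, and the stream names are illustrative rather than taken from the patch.

// Sketch only: record-on-producer / wait-on-consumer ordering between two
// CUDA streams, mirroring the input_ready / output_ready events above.
#include <cuda_runtime.h>

void order_streams(cudaStream_t producer, cudaStream_t consumer) {
    cudaEvent_t ready;
    cudaEventCreate(&ready);
    // Marks the point where all work queued on `producer` so far has finished.
    cudaEventRecord(ready, producer);
    // `consumer` stalls on the GPU (the host is not blocked) until `ready` fires.
    cudaStreamWaitEvent(consumer, ready, 0);
    // Safe to destroy immediately; CUDA defers destruction until completion.
    cudaEventDestroy(ready);
}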
@@ -198,12 +199,13 @@ void fmoe_cuda_fused_forward_impl(
                 global_input_buf, global_output_buf,
                 (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
-        cudaEventRecord(output_ready[step], torch_stream);
+        cudaEventRecord(output_ready[step], smgr->stream(0));
     }
     // Compute over shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
+            FMOE_SWE(smgr->stream(0), evt_shadow[si]);
             FMOE_SWE(torch_stream, evt_shadow[si]);
             stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
             long offset = local_ptr[i];
@@ -218,7 +220,7 @@ void fmoe_cuda_fused_forward_impl(
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(0), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -230,12 +232,12 @@ void fmoe_cuda_fused_forward_impl(
                     global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                     output_buf + local_ptr[idx_recv] * d_model,
                     local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
     }
-    smgr->sync(1);
+    smgr->sync(num_expert + 1);
     delete [] local_ptr;
     delete [] global_ptr;
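Taken together, the forward changes give each role its own stream: streams 0..num_expert-1 (plus the PyTorch stream) run expert computation, while the extra stream at index num_expert carries all NCCL traffic, which is why the final barrier widens from sync(1) to sync(num_expert + 1). As a rough illustration of what that wider barrier has to cover; the loop below is an assumption about what CudaStreamManager::sync does, not code from the repository.

// Hypothetical sketch: synchronizing the first `n` managed streams, i.e.
// num_expert compute streams plus the one communication stream.
#include <cuda_runtime.h>

void sync_streams(cudaStream_t* streams, int n) {
    for (int i = 0; i < n; ++i) {
        // Host blocks until every kernel and NCCL op queued on streams[i] is done.
        cudaStreamSynchronize(streams[i]);
    }
}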
@@ -308,11 +310,11 @@ void fmoe_cuda_fused_backward_impl(
                     local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                     global_grad_out + global_ptr[gidx_recv] * d_model,
                     global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
-        cudaEventRecord(input_ready[step], smgr->stream(0));
+        cudaEventRecord(input_ready[step], smgr->stream(num_expert));
     }
     // Shadowed experts backward and reduce
@@ -328,7 +330,7 @@ void fmoe_cuda_fused_backward_impl(
             collect_fn(si, i / num_expert, 0);
             if (i / num_expert == rank) {
                 cudaEventCreate(evt_reduce + i % num_expert);
-                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
+                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(num_expert));
             }
             ++si;
         }
@@ -337,6 +339,7 @@ void fmoe_cuda_fused_backward_impl(
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
+        FMOE_SWE(smgr->stream(0), input_ready[step]);
         FMOE_SWE(torch_stream, input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
@@ -348,13 +351,14 @@ void fmoe_cuda_fused_backward_impl(
                 global_grad_out, global_grad_in,
                 (long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
         }
-        cudaEventRecord(output_ready[step], torch_stream);
+        cudaEventRecord(output_ready[step], smgr->stream(0));
     }
     // Collect gradients for shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
             if (i / num_expert == rank) {
+                FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
                 FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
                 set_grad_fn(si, i % num_expert);
             }
@@ -364,7 +368,7 @@ void fmoe_cuda_fused_backward_impl(
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(0), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -376,13 +380,13 @@ void fmoe_cuda_fused_backward_impl(
                     global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                     grad_in + local_ptr[idx_recv] * d_model,
                     local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
-                    d_model, smgr->stream(0), smgr->ncclcomm);
+                    d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
     }
-    smgr->sync(1);
+    smgr->sync(num_expert + 1);
     checkCudaErrors(cudaGetLastError());
     delete [] local_ptr;
...
@@ -45,7 +45,11 @@ void CudaStreamManager::setup(const int device) {
     streams = new cudaStream_t[SMGR_N_STREAMS];
     handles = new cublasHandle_t[SMGR_N_STREAMS];
     for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
-        checkCudaErrors(cudaStreamCreate(streams + i));
+        // SHOULD NOT USE: cudaStreamCreate(...)
+        // more details in
+        // https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
+        checkCudaErrors(cudaStreamCreateWithFlags(streams + i,
+                    cudaStreamNonBlocking));
         checkCudaErrors(cublasCreate(handles + i));
         cublasSetStream(handles[i], streams[i]);
     }
...
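The stream-manager change matters because streams created with plain cudaStreamCreate synchronize implicitly with the legacy default stream (see the stream-sync-behavior page cited in the diff comment), which can serialize the manager's compute and communication streams against work PyTorch issues on its own stream. A standalone sketch of the difference, for illustration only:

// Sketch: blocking vs. non-blocking stream creation.
#include <cuda_runtime.h>

int main() {
    cudaStream_t legacy_synced, independent;
    // Implicitly synchronizes with the legacy default stream.
    cudaStreamCreate(&legacy_synced);
    // Does not synchronize with the legacy default stream; work queued here
    // can overlap with default-stream work, which is what the patch relies on.
    cudaStreamCreateWithFlags(&independent, cudaStreamNonBlocking);
    cudaStreamDestroy(legacy_synced);
    cudaStreamDestroy(independent);
    return 0;
}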