Commit c8633740 authored by zms1999

[BUG FIX] make smart scheduling great again, fix bugs in stream management

parent c1c19f3e
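
The diff below moves every NCCL call from smgr->stream(0) onto a dedicated communication stream at index num_expert, and re-plumbs the cudaEvent record/wait points between torch_stream, the compute streams, and that communication stream. The underlying pattern is event-based cross-stream ordering: the producer stream records an event, and the consumer stream waits on it on-device instead of blocking the host. Below is a minimal sketch of that handoff, assuming FMOE_SWE is a thin wrapper over cudaStreamWaitEvent (consistent with how the diff uses it); the stream roles and the kernel are illustrative only.

#include <cuda_runtime.h>

// Assumed: FMOE_SWE wraps cudaStreamWaitEvent, making `stream` wait
// on-device until `event` completes.
#define FMOE_SWE(stream, event) cudaStreamWaitEvent(stream, event, 0)

__global__ void expert_compute(float *buf) { buf[threadIdx.x] += 1.f; }

int main() {
    float *buf;
    cudaMalloc(&buf, 32 * sizeof(float));

    // Illustrative roles: one communication stream and one compute stream,
    // mirroring smgr->stream(num_expert) and smgr->stream(0) in the patch.
    cudaStream_t comm_stream, comp_stream;
    cudaStreamCreateWithFlags(&comm_stream, cudaStreamNonBlocking);
    cudaStreamCreateWithFlags(&comp_stream, cudaStreamNonBlocking);

    cudaEvent_t input_ready;
    cudaEventCreate(&input_ready);

    // Producer: an NCCL exchange would run on comm_stream here, then its
    // completion is marked with an event.
    cudaEventRecord(input_ready, comm_stream);

    // Consumer: the compute stream waits on the event, not on the host.
    FMOE_SWE(comp_stream, input_ready);
    expert_compute<<<1, 32, 0, comp_stream>>>(buf);

    cudaStreamSynchronize(comp_stream);
    cudaEventDestroy(input_ready);
    cudaStreamDestroy(comm_stream);
    cudaStreamDestroy(comp_stream);
    cudaFree(buf);
    return 0;
}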
@@ -157,11 +157,11 @@ void fmoe_cuda_fused_forward_impl(
local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
global_input_buf + global_ptr[gidx_recv] * d_model,
global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
- d_model, smgr->stream(0), smgr->ncclcomm);
+ d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
- cudaEventRecord(input_ready[step], smgr->stream(0));
+ cudaEventRecord(input_ready[step], smgr->stream(num_expert));
}
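
The S_0 ... S_n loop above issues its point-to-point exchange on the dedicated communication stream before recording input_ready there. For reference, a standalone pairwise exchange built on NCCL's p2p primitives (ncclSend/ncclRecv, available since NCCL 2.7) could look like the sketch below; the helper name and argument list are illustrative, not FastMoE's actual signature.

#include <nccl.h>
#include <cuda_runtime.h>

// Illustrative helper: swap one shard with a pair of peers on `stream`.
// Grouping the send and recv lets NCCL fuse them and avoids deadlock
// when every rank posts both sides at once. A zero count (as produced
// by the `* !stored_models[...]` masks in the diff) simply skips that side.
void exchange_shard(const float *send_buf, size_t send_count, int rank_send,
                    float *recv_buf, size_t recv_count, int rank_recv,
                    ncclComm_t comm, cudaStream_t stream) {
    ncclGroupStart();
    if (send_count > 0) {
        ncclSend(send_buf, send_count, ncclFloat, rank_send, comm, stream);
    }
    if (recv_count > 0) {
        ncclRecv(recv_buf, recv_count, ncclFloat, rank_recv, comm, stream);
    }
    ncclGroupEnd();
}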
// Broadcast shadowed experts
@@ -173,21 +173,22 @@ void fmoe_cuda_fused_forward_impl(
if (stored_models[i]) {
if (i / num_expert == rank) {
cudaEventCreate(&evt_get);
- cudaEventRecord(evt_get, torch_stream);
- FMOE_SWE(smgr->stream(0), evt_get);
+ cudaEventRecord(evt_get, smgr->stream(0));
+ FMOE_SWE(smgr->stream(num_expert), evt_get);
cudaEventDestroy(evt_get);
}
NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
expert_size * sizeof(scalar_t), ncclChar,
- i / num_expert, smgr->ncclcomm, smgr->stream(0)));
+ i / num_expert, smgr->ncclcomm, smgr->stream(num_expert)));
cudaEventCreate(evt_shadow + si);
- cudaEventRecord(evt_shadow[si], smgr->stream(0));
+ cudaEventRecord(evt_shadow[si], smgr->stream(num_expert));
++si;
}
}
// C_0 ... C_n
for (long step = 0; step < n_groups; ++step) {
- FMOE_SWE(smgr->stream(0), input_ready[step]);
+ FMOE_SWE(torch_stream, input_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
@@ -198,12 +199,13 @@ void fmoe_cuda_fused_forward_impl(
global_input_buf, global_output_buf,
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
}
- cudaEventRecord(output_ready[step], torch_stream);
+ cudaEventRecord(output_ready[step], smgr->stream(0));
}
// Compute over shadowed experts
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
- FMOE_SWE(smgr->stream(0), evt_shadow[si]);
+ FMOE_SWE(torch_stream, evt_shadow[si]);
stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
long offset = local_ptr[i];
@@ -218,7 +220,7 @@ void fmoe_cuda_fused_forward_impl(
// R_0 ... R_n
for (long step = 0; step < n_groups; ++step) {
- FMOE_SWE(smgr->stream(0), output_ready[step]);
+ FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
@@ -230,12 +232,12 @@ void fmoe_cuda_fused_forward_impl(
global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
output_buf + local_ptr[idx_recv] * d_model,
local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
- d_model, smgr->stream(0), smgr->ncclcomm);
+ d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
}
- smgr->sync(1);
+ smgr->sync(num_expert + 1);
delete [] local_ptr;
delete [] global_ptr;
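
The sync bound grows from 1 to num_expert + 1 because the communication stream now sits at index num_expert, behind the num_expert compute streams. A plausible sketch of CudaStreamManager::sync, assuming it drains the first n streams (the manager's real implementation may differ):

// Assumed: sync(n) synchronizes streams[0..n-1]. With the communication
// stream at index num_expert, sync(num_expert + 1) drains both the
// per-expert compute streams and the comm stream; the old sync(1)
// covered only stream 0.
void CudaStreamManager::sync(int n) {
    for (int i = 0; i < n; ++i) {
        checkCudaErrors(cudaStreamSynchronize(streams[i]));
    }
}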
@@ -308,11 +310,11 @@ void fmoe_cuda_fused_backward_impl(
local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
global_grad_out + global_ptr[gidx_recv] * d_model,
global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
- d_model, smgr->stream(0), smgr->ncclcomm);
+ d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
- cudaEventRecord(input_ready[step], smgr->stream(0));
+ cudaEventRecord(input_ready[step], smgr->stream(num_expert));
}
// Shadowed experts backward and reduce
@@ -328,7 +330,7 @@ void fmoe_cuda_fused_backward_impl(
collect_fn(si, i / num_expert, 0);
if (i / num_expert == rank) {
cudaEventCreate(evt_reduce + i % num_expert);
- cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
+ cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(num_expert));
}
++si;
}
@@ -337,6 +339,7 @@ void fmoe_cuda_fused_backward_impl(
// C_0 ... C_n
for (long step = 0; step < n_groups; ++step) {
- FMOE_SWE(smgr->stream(0), input_ready[step]);
+ FMOE_SWE(torch_stream, input_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
@@ -348,13 +351,14 @@ void fmoe_cuda_fused_backward_impl(
global_grad_out, global_grad_in,
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
}
- cudaEventRecord(output_ready[step], torch_stream);
+ cudaEventRecord(output_ready[step], smgr->stream(0));
}
// Collect gradients for shadowed experts
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
if (i / num_expert == rank) {
- FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
+ FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
set_grad_fn(si, i % num_expert);
}
@@ -364,7 +368,7 @@ void fmoe_cuda_fused_backward_impl(
// R_0 ... R_n
for (long step = 0; step < n_groups; ++step) {
- FMOE_SWE(smgr->stream(0), output_ready[step]);
+ FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
......@@ -376,13 +380,13 @@ void fmoe_cuda_fused_backward_impl(
global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
grad_in + local_ptr[idx_recv] * d_model,
local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
- d_model, smgr->stream(0), smgr->ncclcomm);
+ d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
}
- smgr->sync(1);
+ smgr->sync(num_expert + 1);
checkCudaErrors(cudaGetLastError());
delete [] local_ptr;
@@ -45,7 +45,11 @@ void CudaStreamManager::setup(const int device) {
streams = new cudaStream_t[SMGR_N_STREAMS];
handles = new cublasHandle_t[SMGR_N_STREAMS];
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
- checkCudaErrors(cudaStreamCreate(streams + i));
+ // SHOULD NOT USE: cudaStreamCreate(...)
+ // more details in
+ // https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
+ checkCudaErrors(cudaStreamCreateWithFlags(streams + i,
+     cudaStreamNonBlocking));
checkCudaErrors(cublasCreate(handles + i));
cublasSetStream(handles[i], streams[i]);
}
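
Per the stream-sync-behavior page linked above, a stream from cudaStreamCreate is a blocking stream: work on it synchronizes implicitly with the legacy default stream, so any kernel launched on stream 0 would serialize against all of the manager's streams and defeat the overlap this schedule depends on. cudaStreamNonBlocking opts out of that implicit ordering. A minimal illustration:

#include <cuda_runtime.h>

int main() {
    cudaStream_t blocking, non_blocking;

    // Blocking stream: implicitly ordered against the legacy default stream.
    cudaStreamCreate(&blocking);
    // Non-blocking stream: no implicit ordering against the default stream.
    cudaStreamCreateWithFlags(&non_blocking, cudaStreamNonBlocking);

    // A kernel launched on the legacy default stream here would act as a
    // barrier for `blocking`, but would overlap freely with `non_blocking`.

    cudaStreamDestroy(blocking);
    cudaStreamDestroy(non_blocking);
    return 0;
}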