Commit ff28081c authored by zms1999
Browse files

[BUG FIX] Wait for the torch stream

parent c8633740
......@@ -139,9 +139,11 @@ void fmoe_cuda_fused_forward_impl(
cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
for (long i = 0; i < n_groups; ++i) {
cudaEventCreate(input_ready + i);
cudaEventCreate(output_ready + i);
cudaEventCreate(output_torch_ready + i);
}
// S_0 ... S_n
......@@ -200,6 +202,7 @@ void fmoe_cuda_fused_forward_impl(
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
}
cudaEventRecord(output_ready[step], smgr->stream(0));
cudaEventRecord(output_torch_ready[step], torch_stream);
}
// Compute over shadowed experts
......@@ -221,6 +224,7 @@ void fmoe_cuda_fused_forward_impl(
// R_0 ... R_n
for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
......@@ -246,12 +250,14 @@ void fmoe_cuda_fused_forward_impl(
for (long i = 0; i < n_groups; ++i) {
cudaEventDestroy(input_ready[i]);
cudaEventDestroy(output_ready[i]);
cudaEventDestroy(output_torch_ready[i]);
}
for (unsigned i = 0; i < params.size(); ++i) {
cudaEventDestroy(evt_shadow[i]);
}
delete [] input_ready;
delete [] output_ready;
delete [] output_torch_ready;
}
......@@ -292,9 +298,11 @@ void fmoe_cuda_fused_backward_impl(
cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
for (long i = 0; i < n_groups; ++i) {
cudaEventCreate(input_ready + i);
cudaEventCreate(output_ready + i);
cudaEventCreate(output_torch_ready + i);
}
// S_0 ... S_n
......@@ -352,6 +360,7 @@ void fmoe_cuda_fused_backward_impl(
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
}
cudaEventRecord(output_ready[step], smgr->stream(0));
cudaEventRecord(output_torch_ready[step], torch_stream);
}
// Collect gradients for shadowed experts
......@@ -369,6 +378,7 @@ void fmoe_cuda_fused_backward_impl(
// R_0 ... R_n
for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
......@@ -396,9 +406,11 @@ void fmoe_cuda_fused_backward_impl(
for (long i = 0; i < n_groups; ++i) {
cudaEventDestroy(input_ready[i]);
cudaEventDestroy(output_ready[i]);
cudaEventDestroy(output_torch_ready[i]);
}
delete [] input_ready;
delete [] output_ready;
delete [] output_torch_ready;
for (long i = 0; i < num_expert; ++i) {
if (stored_models[i + rank * num_expert]) {
cudaEventDestroy(evt_reduce[i]);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment