Commit 2bd187cb authored by zms1999
Browse files

[BUG FIX] sync torch stream before nccl send/recv

parent ff28081c
...@@ -123,6 +123,7 @@ void fmoe_cuda_fused_forward_impl( ...@@ -123,6 +123,7 @@ void fmoe_cuda_fused_forward_impl(
long num_expert, long rank, long world_size, long expert_size, long num_expert, long rank, long world_size, long expert_size,
long pipeline_gran, CudaStreamManager* smgr) { long pipeline_gran, CudaStreamManager* smgr) {
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream(); auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
cudaStreamSynchronize(torch_stream);
int *local_ptr = new int[num_expert * world_size + 1]; int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1]; int *global_ptr = new int[num_expert * world_size + 1];
...@@ -282,6 +283,7 @@ void fmoe_cuda_fused_backward_impl( ...@@ -282,6 +283,7 @@ void fmoe_cuda_fused_backward_impl(
long num_expert, long rank, long world_size, long num_expert, long rank, long world_size,
long pipeline_gran, CudaStreamManager* smgr) { long pipeline_gran, CudaStreamManager* smgr) {
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream(); auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
cudaStreamSynchronize(torch_stream);
int *local_ptr = new int[num_expert * world_size + 1]; int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1]; int *global_ptr = new int[num_expert * world_size + 1];
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment