".github/git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "296cd091bc609ff330b41b2acc0470b33161f5d5"
Unverified Commit ec49a52b authored by vasunvidia's avatar vasunvidia Committed by GitHub
Browse files

Dgrad ReduceScatter overlap fix (#1088)



* DGRAD-RS overlap bug fix

This PR fixes a bug in enabling DGRAD-RS overlap by adding the
layer to the correct method list. Previously, the RS-DGRAD overlap
layer was incorrectly added to the pipeline method list even when the
ring_exchange method was specified in the config.
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix for ring_exchange ReduceScatter

ring_exchange RS uses main_stream for the last GEMM chunk, but the
send/recv streams were waiting on stream_compute during the last chunk.
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b4840386
...@@ -1205,11 +1205,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1205,11 +1205,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
torch::Tensor workspace_chunk = torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options()); {workspace_size_chunk}, workspace.options());
if (i == _tp_size - 1) { at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
at::cuda::setCurrentCUDAStream(stream_main);
} else {
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
}
te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb,
_ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms);
...@@ -1230,6 +1226,13 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1230,6 +1226,13 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
recv_rank, (cudaStream_t)_stream_recv); recv_rank, (cudaStream_t)_stream_recv);
} }
} }
at::cuda::setCurrentCUDAStream(stream_main);
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
...@@ -1248,12 +1251,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1248,12 +1251,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0); torch::sum_out(rs_output, reduce_buf, 0);
} }
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
_ub_comm->sms = ori_sms; _ub_comm->sms = ori_sms;
} }
......
...@@ -340,7 +340,9 @@ def initialize_ub( ...@@ -340,7 +340,9 @@ def initialize_ub(
layers_reduce_scatter_overlap.remove(wgrad_name) layers_reduce_scatter_overlap.remove(wgrad_name)
layers_all_gather_overlap.remove(name) layers_all_gather_overlap.remove(name)
layers_reduce_scatter_overlap.append(name) layers_reduce_scatter_overlap.append(name)
methods["pipeline"].append(name) methods["bulk"].remove(name)
new_method = ub_cfgs[name]["method"]
methods[new_method].append(name)
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
ub_cfg = get_default_config(name) ub_cfg = get_default_config(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment