Unverified Commit ec49a52b authored by vasunvidia, committed by GitHub
Browse files

Dgrad ReduceScatter overlap fix (#1088)



* DGRAD-RS overlap bug fix

This PR fixes a bug in enabling DGRAD-RS overlap by adding the
layer to the correct method list. Previously, the RS-DGRAD overlap
layer was incorrectly added to the pipeline method list even when the
ring_exchange method was specified in the config.
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Bug fix for ring_exchange ReduceScatter

ring_exchange RS uses main_stream for the last GEMM chunk, but the
send/recv streams wait for stream_compute during the last chunk.
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b4840386
......@@ -1205,11 +1205,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
if (i == _tp_size - 1) {
at::cuda::setCurrentCUDAStream(stream_main);
} else {
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
}
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb,
_ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms);
......@@ -1230,6 +1226,13 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
recv_rank, (cudaStream_t)_stream_recv);
}
}
at::cuda::setCurrentCUDAStream(stream_main);
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
......@@ -1248,12 +1251,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
_ub_comm->sms = ori_sms;
}
......
......@@ -340,7 +340,9 @@ def initialize_ub(
layers_reduce_scatter_overlap.remove(wgrad_name)
layers_all_gather_overlap.remove(name)
layers_reduce_scatter_overlap.append(name)
methods["pipeline"].append(name)
methods["bulk"].remove(name)
new_method = ub_cfgs[name]["method"]
methods[new_method].append(name)
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
ub_cfg = get_default_config(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment