Unverified Commit 0757149d authored by vasunvidia's avatar vasunvidia Committed by GitHub
Browse files

Fix ring_exchange RS to support CUDA graph capture (#811)


Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 816dd457
...@@ -99,7 +99,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -99,7 +99,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
} }
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream; cudaStream_t stream;
cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1);
...@@ -596,7 +596,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -596,7 +596,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
ubuf_byte_ptr += ubuf_chunk_bytes; ubuf_byte_ptr += ubuf_chunk_bytes;
} }
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
for (int i = 0; i < std::min(num_max_streams, tp_size); i++) { for (int i = 0; i < std::min(num_max_streams, tp_size); i++) {
cudaStream_t stream; cudaStream_t stream;
cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1);
...@@ -691,7 +691,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -691,7 +691,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
assert(pre_gelu_out.numel() == 0); assert(pre_gelu_out.numel() == 0);
// Catch up the default torch stream // Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
...@@ -974,7 +974,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -974,7 +974,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
B_scale_inverse = B_scale_inverse[B_fp8_tensor]; B_scale_inverse = B_scale_inverse[B_fp8_tensor];
// Catch up the main stream // Catch up the main stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
...@@ -1055,8 +1055,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1055,8 +1055,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
B_scale_inverse = B_scale_inverse[B_fp8_tensor]; B_scale_inverse = B_scale_inverse[B_fp8_tensor];
// Catch up the main stream // Catch up the main stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
for (int i = 0; i < _stream_compute.size(); i++) { for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_compute[i], _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_compute[i], _start_compute, 0));
} }
...@@ -1113,13 +1115,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1113,13 +1115,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0); torch::sum_out(rs_output, reduce_buf, 0);
} }
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
}
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
} }
/* /*
** Copy input to _ubufs[0] ** Copy input to _ubufs[0]
*/ */
void copy_input_to_ubuf(torch::Tensor input, bool chunk) { void copy_input_to_ubuf(torch::Tensor input, bool chunk) {
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
if (chunk) { if (chunk) {
// Copy input to the target ubuf chunk by rank offset // Copy input to the target ubuf chunk by rank offset
if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) { if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment