Unverified Commit fc1b91c2 authored by vasunvidia's avatar vasunvidia Committed by GitHub
Browse files

Launch GEMM on compute_stream which has low priority. (#1522)


Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
parent c5d6a069
...@@ -262,6 +262,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te ...@@ -262,6 +262,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
// Catch up the default torch stream // Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main)); NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0));
// Communication: AG and RS // Communication: AG and RS
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
...@@ -288,14 +289,17 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te ...@@ -288,14 +289,17 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
assert(pre_gelu_out.numel() == 0); assert(pre_gelu_out.numel() == 0);
// When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch // When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch
if (_comm_launch_event) if (_comm_launch_event)
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _comm_launch_event, 0));
nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
stream_main); _stream_compute[0]);
_ub_comm->sms = ori_sms; _ub_comm->sms = ori_sms;
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::bulk_overlap } // CommOverlapBase::bulk_overlap
/* /*
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment