Commit e92773a3 authored by yuguo's avatar yuguo
Browse files

Merge branch 'develop_v2.3' into 'main'

[DCU] cudaStreamSynchronize for tp gemm overlap

See merge request dcutoolkit/deeplearing/TransformerEngine!11
parents 9e6e1871 aec86199
@@ -265,6 +265,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
                                    bool accumulate, bool use_split_accumulator,
                                    CommOverlapType comm_type, TensorWrapper &rs_output,
                                    cudaStream_t stream_main) {
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -314,6 +315,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
 }  // CommOverlapBase::bulk_overlap
@@ -422,6 +424,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
                                        TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
                                        bool grad, bool accumulate, bool use_split_accumulator,
                                        TensorWrapper &rs_output, cudaStream_t stream_main) {
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
   // Get GEMM dimensions
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
@@ -562,6 +565,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   }
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
 }  // CommOverlapBase::split_overlap_rs

 /***************************************************************************************************
@@ -781,6 +785,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
                                           TensorWrapper &workspace, bool grad, bool accumulate,
                                           bool use_split_accumulator, TensorWrapper &B_copy,
                                           cudaStream_t stream_main) {
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -938,6 +943,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
 }  // CommOverlapP2PBase::split_overlap_ag

 /*
@@ -1012,6 +1018,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
                                           TensorWrapper &workspace, bool grad, bool accumulate,
                                           bool use_split_accumulator, TensorWrapper &rs_output,
                                           cudaStream_t stream_main) {
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -1086,6 +1093,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
   }
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
   // Reduce GEMM output chunks
   char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment