Commit e92773a3 authored by yuguo's avatar yuguo
Browse files

Merge branch 'develop_v2.3' into 'main'

[DCU] cudaStreamSynchronize for tp gemm overlap

See merge request dcutoolkit/deeplearing/TransformerEngine!11
parents 9e6e1871 aec86199
......@@ -265,6 +265,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
bool accumulate, bool use_split_accumulator,
CommOverlapType comm_type, TensorWrapper &rs_output,
cudaStream_t stream_main) {
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
......@@ -314,6 +315,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
} // CommOverlapBase::bulk_overlap
......@@ -422,6 +424,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output, cudaStream_t stream_main) {
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
// Get GEMM dimensions
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
......@@ -562,6 +565,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
} // CommOverlapBase::split_overlap_rs
/***************************************************************************************************
......@@ -781,6 +785,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) {
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
......@@ -938,6 +943,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
} // CommOverlapP2PBase::split_overlap_ag
/*
......@@ -1012,6 +1018,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) {
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
......@@ -1086,6 +1093,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment