Commit 92d59fe4 authored by yuguo
Browse files

[DCU] remove cudaStreamSynchronize for tp overlap

parent aec86199
@@ -265,7 +265,6 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
                                    bool accumulate, bool use_split_accumulator,
                                    CommOverlapType comm_type, TensorWrapper &rs_output,
                                    cudaStream_t stream_main) {
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -315,7 +314,6 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
 }  // CommOverlapBase::bulk_overlap
@@ -424,7 +422,6 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
                                        TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
                                        bool grad, bool accumulate, bool use_split_accumulator,
                                        TensorWrapper &rs_output, cudaStream_t stream_main) {
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
   // Get GEMM dimensions
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
@@ -565,7 +562,6 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   }
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
 }  // CommOverlapBase::split_overlap_rs
 /***************************************************************************************************
@@ -785,7 +781,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
                                           TensorWrapper &workspace, bool grad, bool accumulate,
                                           bool use_split_accumulator, TensorWrapper &B_copy,
                                           cudaStream_t stream_main) {
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -943,7 +938,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
 }  // CommOverlapP2PBase::split_overlap_ag
 /*
@@ -1018,7 +1012,6 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
                                           TensorWrapper &workspace, bool grad, bool accumulate,
                                           bool use_split_accumulator, TensorWrapper &rs_output,
                                           cudaStream_t stream_main) {
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
   int ori_sms = _ub_comm->sms;
   _ub_comm->use_ce = _use_ce;
   _ub_comm->sms = _num_comm_sm;
@@ -1093,7 +1086,6 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
   }
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream_main));
   // Reduce GEMM output chunks
   char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment