Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
aec86199
Commit
aec86199
authored
May 20, 2025
by
yuguo
Browse files
[DCU] Add cudaStreamSynchronize calls for TP GEMM overlap
parent
460b006c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
0 deletions
+8
-0
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+8
-0
No files found.
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
aec86199
...
...
@@ -265,6 +265,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
bool
accumulate
,
bool
use_split_accumulator
,
CommOverlapType
comm_type
,
TensorWrapper
&
rs_output
,
cudaStream_t
stream_main
)
{
NVTE_CHECK_CUDA
(
cudaStreamSynchronize
(
stream_main
));
int
ori_sms
=
_ub_comm
->
sms
;
_ub_comm
->
use_ce
=
_use_ce
;
_ub_comm
->
sms
=
_num_comm_sm
;
...
...
@@ -314,6 +315,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_comm
,
0
));
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_stop_comm
,
_stream_compute
[
0
]));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_comm
,
0
));
NVTE_CHECK_CUDA
(
cudaStreamSynchronize
(
stream_main
));
}
// CommOverlapBase::bulk_overlap
...
...
@@ -422,6 +424,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
TensorWrapper
&
pre_gelu_out
,
TensorWrapper
&
workspace
,
bool
grad
,
bool
accumulate
,
bool
use_split_accumulator
,
TensorWrapper
&
rs_output
,
cudaStream_t
stream_main
)
{
NVTE_CHECK_CUDA
(
cudaStreamSynchronize
(
stream_main
));
// Get GEMM dimensions
int
ori_sms
=
_ub_comm
->
sms
;
_ub_comm
->
use_ce
=
_use_ce
;
...
...
@@ -562,6 +565,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_stop_comm
,
_stream_comm
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_comm
,
0
));
NVTE_CHECK_CUDA
(
cudaStreamSynchronize
(
stream_main
));
}
// CommOverlapBase::split_overlap_rs
/***************************************************************************************************
...
...
@@ -781,6 +785,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
TensorWrapper
&
workspace
,
bool
grad
,
bool
accumulate
,
bool
use_split_accumulator
,
TensorWrapper
&
B_copy
,
cudaStream_t
stream_main
)
{
NVTE_CHECK_CUDA
(
cudaStreamSynchronize
(
stream_main
));
int
ori_sms
=
_ub_comm
->
sms
;
_ub_comm
->
use_ce
=
_use_ce
;
_ub_comm
->
sms
=
_num_comm_sm
;
...
...
@@ -938,6 +943,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_send
,
0
));
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_stop_recv
,
_stream_recv
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_recv
,
0
));
NVTE_CHECK_CUDA
(
cudaStreamSynchronize
(
stream_main
));
}
// CommOverlapP2PBase::split_overlap_ag
/*
...
...
@@ -1012,6 +1018,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
TensorWrapper
&
workspace
,
bool
grad
,
bool
accumulate
,
bool
use_split_accumulator
,
TensorWrapper
&
rs_output
,
cudaStream_t
stream_main
)
{
NVTE_CHECK_CUDA
(
cudaStreamSynchronize
(
stream_main
));
int
ori_sms
=
_ub_comm
->
sms
;
_ub_comm
->
use_ce
=
_use_ce
;
_ub_comm
->
sms
=
_num_comm_sm
;
...
...
@@ -1086,6 +1093,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
}
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_stop_recv
,
_stream_recv
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_recv
,
0
));
NVTE_CHECK_CUDA
(
cudaStreamSynchronize
(
stream_main
));
// Reduce GEMM output chunks
char
*
reduce_buf_ptr
=
reinterpret_cast
<
char
*>
(
_ubufs
[
_tp_size
-
1
].
dptr
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment