Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
fc1b91c2
Unverified
Commit
fc1b91c2
authored
Mar 03, 2025
by
vasunvidia
Committed by
GitHub
Mar 03, 2025
Browse files
Launch GEMM on compute_stream which has low priority. (#1522)
Signed-off-by:
Vasudevan Rengasamy
<
vrengasamy@nvidia.com
>
parent
c5d6a069
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
2 deletions
+6
-2
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
...mer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+6
-2
No files found.
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
fc1b91c2
...
@@ -262,6 +262,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
...
@@ -262,6 +262,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
// Catch up the default torch stream
// Catch up the default torch stream
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_start_comm
,
stream_main
));
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_start_comm
,
stream_main
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
_stream_comm
,
_start_comm
,
0
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
_stream_comm
,
_start_comm
,
0
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
_stream_compute
[
0
],
_start_comm
,
0
));
// Communication: AG and RS
// Communication: AG and RS
int
comm_elements
=
(
_ubuf
.
numel
()
/
2
)
*
_ubuf
.
element_size
();
// UBUF uses 2Byte element size
int
comm_elements
=
(
_ubuf
.
numel
()
/
2
)
*
_ubuf
.
element_size
();
// UBUF uses 2Byte element size
...
@@ -288,14 +289,17 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
...
@@ -288,14 +289,17 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
assert
(
pre_gelu_out
.
numel
()
==
0
);
assert
(
pre_gelu_out
.
numel
()
==
0
);
// When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch
// When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch
if
(
_comm_launch_event
)
if
(
_comm_launch_event
)
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
((
cudaStream_t
)
stream_
main
,
_comm_launch_event
,
0
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
((
cudaStream_t
)
_
stream_
compute
[
0
]
,
_comm_launch_event
,
0
));
nvte_cublas_gemm
(
A
.
data
(),
B
.
data
(),
D
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
nvte_cublas_gemm
(
A
.
data
(),
B
.
data
(),
D
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
grad
,
workspace
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
stream_
main
);
_
stream_
compute
[
0
]
);
_ub_comm
->
sms
=
ori_sms
;
_ub_comm
->
sms
=
ori_sms
;
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_stop_comm
,
_stream_comm
));
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_stop_comm
,
_stream_comm
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_comm
,
0
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_comm
,
0
));
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_stop_comm
,
_stream_compute
[
0
]));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_comm
,
0
));
}
// CommOverlapBase::bulk_overlap
}
// CommOverlapBase::bulk_overlap
/*
/*
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment