Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
d8041744
Commit
d8041744
authored
Aug 05, 2025
by
yuguo
Browse files
[DCU] fix all-gather usage
parent
a397dcb7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
4 deletions
+7
-4
transformer_engine/pytorch/distributed.py
transformer_engine/pytorch/distributed.py
+2
-2
transformer_engine/pytorch/utils.py
transformer_engine/pytorch/utils.py
+5
-2
No files found.
transformer_engine/pytorch/distributed.py
View file @
d8041744
...
@@ -1004,10 +1004,10 @@ def _post_process_fp8_blockwise_gather(
...
@@ -1004,10 +1004,10 @@ def _post_process_fp8_blockwise_gather(
return
out
return
out
needs_columnwise_data_transpose
=
(
needs_columnwise_data_transpose
=
(
quantizer
is
not
None
and
quantizer
.
columnwise_usage
and
not
is_non_tn_fp8_gemm_supported
()
quantizer
is
not
None
and
quantizer
.
columnwise_usage
and
not
is_non_tn_fp8_gemm_supported
(
is_blockwise
=
True
)
)
)
need_rowwise_scale_transpose
=
(
need_rowwise_scale_transpose
=
(
quantizer
is
not
None
and
quantizer
.
rowwise_usage
and
not
is_non_tn_fp8_gemm_supported
()
quantizer
is
not
None
and
quantizer
.
rowwise_usage
and
not
is_non_tn_fp8_gemm_supported
(
is_blockwise
=
True
)
)
)
# CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024
# CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024
...
...
transformer_engine/pytorch/utils.py
View file @
d8041744
...
@@ -488,12 +488,15 @@ def is_bf16_compatible() -> None:
...
@@ -488,12 +488,15 @@ def is_bf16_compatible() -> None:
@
functools
.
lru_cache
(
maxsize
=
None
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
is_non_tn_fp8_gemm_supported
()
->
bool
:
def
is_non_tn_fp8_gemm_supported
(
is_blockwise
:
Optional
[
bool
]
=
False
)
->
bool
:
"""Checks whether the device supports
"""Checks whether the device supports
non-TN layouts for FP8 GEMMs.
non-TN layouts for FP8 GEMMs.
"""
"""
if
IS_HIP_EXTENSION
:
if
IS_HIP_EXTENSION
:
return
True
if
is_blockwise
:
return
False
else
:
return
True
device_capability
=
torch
.
cuda
.
get_device_capability
()
device_capability
=
torch
.
cuda
.
get_device_capability
()
return
(
10
,
0
)
<=
device_capability
<
(
12
,
0
)
or
device_capability
>=
(
13
,
0
)
return
(
10
,
0
)
<=
device_capability
<
(
12
,
0
)
or
device_capability
>=
(
13
,
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment