OpenDAS / TransformerEngine · Commits · ecdd8251

Commit ecdd8251 authored Jun 13, 2025 by yuguo

[DCU] fix blockwise int8 train issues in megatron
parent 7f946529
Showing 2 changed files with 8 additions and 12 deletions
transformer_engine/pytorch/cpp_extensions/gemm.py                    +3 -12
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py    +5 -0
transformer_engine/pytorch/cpp_extensions/gemm.py @ ecdd8251
@@ -82,8 +82,6 @@ def general_gemm(
         if accumulate:
             assert out is not None
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None, None
     elif layout == "NN":
@@ -103,8 +101,6 @@ def general_gemm(
         if accumulate:
             assert out is not None
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None, None
     elif layout == "NT":
@@ -124,8 +120,6 @@ def general_gemm(
         if accumulate:
             assert out is not None
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None, None
     else:
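The three hunks above make the same change in each layout branch of general_gemm: the else branch asserting that out is None when accumulate is False is dropped, so a caller (here, Megatron's blockwise int8 training path) can pass a pre-allocated output buffer without tripping the assert. A minimal sketch of the relaxed contract, using a hypothetical gemm_like stand-in and assumed shapes rather than the real extension:

    import torch

    def gemm_like(a, b, out=None, accumulate=False):
        y = a @ b.t()               # "TN" layout: y = a @ b^T
        if accumulate:
            assert out is not None  # accumulation still needs a buffer
            y = y + out
        return y, None, None, None  # mirrors general_gemm's return tuple

    a, b = torch.randn(4, 8), torch.randn(6, 8)
    buf = torch.zeros(4, 6)
    y, *_ = gemm_like(a, b, out=buf, accumulate=False)  # previously asserted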
@@ -234,9 +228,8 @@ def general_grouped_gemm(
         )
         if accumulate:
             assert out is not None
+            out = torch.stack(out).contiguous()
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None
     elif layout == "NN":
@@ -255,9 +248,8 @@ def general_grouped_gemm(
         )
         if accumulate:
             assert out is not None
+            out = torch.stack(out).contiguous()
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None
     elif layout == "NT":
@@ -276,9 +268,8 @@ def general_grouped_gemm(
         )
         if accumulate:
             assert out is not None
+            out = torch.stack(out).contiguous()
             y = y + out
-        else:
-            assert out is None, "Output tensor should be None when accumulate is False."
         return y, None, None
     else:
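general_grouped_gemm gets the same relaxation plus one added line: out is stacked into a single contiguous tensor before accumulation. A short sketch of why that stack is needed, assuming (as the added line suggests) that out arrives as a list of per-group buffers while y is already a stacked batch:

    import torch

    groups = 3
    a = [torch.randn(4, 8) for _ in range(groups)]
    b = [torch.randn(6, 8) for _ in range(groups)]
    y = torch.stack([ai @ bi.t() for ai, bi in zip(a, b)])  # (3, 4, 6)

    out = [torch.ones(4, 6) for _ in range(groups)]  # per-group buffers
    out = torch.stack(out).contiguous()              # (3, 4, 6), now matches y
    y = y + out                                      # elementwise add works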
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py @ ecdd8251
@@ -324,6 +324,8 @@ def w8a8_block_int8_matmul_wgrad(
     """
     assert len(block_size) == 2
     block_n, block_k = block_size[0], block_size[1]
+    B = B.view(B.size(0), -1)
+    assert A.ndim == 2
     assert A.shape[-1] == B.shape[-1]
     # print(f"A.shape[:-1] : {A.shape[:-1]}, As.shape[:-1]: {As.shape[:-1]}")
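A sketch of what the two added lines in w8a8_block_int8_matmul_wgrad buy: if the caller hands in a B with extra trailing dimensions (an assumption consistent with a Megatron weight layout), flattening it to 2-D restores the K-dimension match that the existing assert checks. All shapes below are illustrative:

    import torch

    A = torch.randint(-128, 127, (32, 64), dtype=torch.int8)    # (M, K)
    B = torch.randint(-128, 127, (16, 8, 8), dtype=torch.int8)  # (N, 8, 8), 8*8 == K

    B = B.view(B.size(0), -1)          # -> (16, 64)
    assert A.ndim == 2
    assert A.shape[-1] == B.shape[-1]  # K dimensions agree again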
@@ -451,6 +453,9 @@ def w8a8_block_int8_matmul_wgrad_batched(
     As = torch.stack(As_list).contiguous()
     Bs = torch.stack(Bs_list).contiguous()
+    B_new_shape = B.size()[:2] + (-1,)
+    B = B.view(*B_new_shape)
+    assert A.ndim == 3
     assert A.shape[-1] == B.shape[-1]
     M = A.numel() // A.shape[-1] // A.shape[0]
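The batched variant applies the same normalization per group: keep the two leading (group, row) dimensions and flatten the rest, mirroring the three added lines above. Shapes are again assumptions:

    import torch

    A = torch.randint(-128, 127, (4, 32, 64), dtype=torch.int8)    # (G, M, K)
    B = torch.randint(-128, 127, (4, 16, 8, 8), dtype=torch.int8)  # (G, N, 8, 8)

    B_new_shape = B.size()[:2] + (-1,)  # (4, 16, -1)
    B = B.view(*B_new_shape)            # -> (4, 16, 64)
    assert A.ndim == 3
    assert A.shape[-1] == B.shape[-1]
    M = A.numel() // A.shape[-1] // A.shape[0]  # per-group rows: 32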