OpenDAS / TransformerEngine · Commit 11bc1775

Authored Aug 25, 2025 by yuguo

Merge branch 'develop_v2.5' of http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine

Parents: e12a1085, 059d92e2
Showing 3 changed files with 19 additions and 19 deletions.
tests/pytorch/test_int8_channelwise_gemm_exact.py      +4 −4
transformer_engine/pytorch/cpp_extensions/gemm.py      +14 −14
transformer_engine/pytorch/module/layernorm_linear.py  +1 −1
tests/pytorch/test_int8_channelwise_gemm_exact.py
@@ -756,7 +756,7 @@ for i in range(b):
         tensorwise_dequantize(dy_scales[i], x_scales[i], dw_int32, dw_ref[i])
     else:
         assert False
-dw_ref_tensor = torch.stack(dw_ref).contiguous()
+dw_ref_tensor = torch.stack(dw_ref).contiguous().view(-1, dw_ref[0].size(-1))
 # print("dw_ref_tensor: ", dw_ref_tensor)
 torch.cuda.synchronize()
@@ -771,13 +771,13 @@ dw_tensor = torch.stack(dw).contiguous()
 out_dtype = torch.bfloat16
 dw_tensor = tex.tensorwise_int8_batchgemm(
-    x_int8_tensor.view(-1, x_int8.size(-1)),
+    x_int8_tensor.view(-1, x_int8_tensor.size(-1)),
     transa,
-    dy_int8_tensor.view(-1, dy_int8.size(-1)),
+    dy_int8_tensor.view(-1, dy_int8_tensor.size(-1)),
     transb,
     x_scales_tensor,
     dy_scales_tensor,
-    dw_tensor.view(-1, dw.size(-1)),
+    dw_tensor.view(-1, dw_tensor.size(-1)),
     b,
     out_quantizer,
     TE_DType[out_dtype],
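The two hunks above flatten the stacked per-batch reference matrices to two dimensions before they are compared against the batched GEMM output. The snippet below is not taken from the test; it is a minimal sketch with hypothetical shapes showing what torch.stack(...).contiguous().view(-1, last_dim) produces.

import torch

# Hypothetical shapes for illustration only.
b, m, n = 4, 8, 16
dw_ref = [torch.randn(m, n) for _ in range(b)]              # b per-batch reference matrices

dw_ref_tensor = torch.stack(dw_ref).contiguous()            # shape [b, m, n]
dw_ref_flat = dw_ref_tensor.view(-1, dw_ref[0].size(-1))    # shape [b*m, n], matching a 2-D GEMM output

assert dw_ref_flat.shape == (b * m, n)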
transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -499,11 +499,11 @@ def general_grouped_gemm(
             out[0],
             num_gemms,
             None,
-            TE_DType[out_dtype],
-            bias[0],
+            out_dtype,
+            None,
             bias_dtype,
             gelu,
-            gelu_input[0],
+            None,
             grad,  # grad
             workspaces[0],
             workspaces[0].shape[0],
@@ -534,11 +534,11 @@ def general_grouped_gemm(
             out[0],
             num_gemms,
             None,
-            TE_DType[out_dtype],
-            bias[0],
+            out_dtype,
+            None,
             bias_dtype,
             gelu,
-            gelu_input[0],
+            None,
             grad,  # grad
             workspaces[0],
             workspaces[0].shape[0],
@@ -571,10 +571,10 @@ def general_grouped_gemm(
             num_gemms,
             None,
             TE_DType[out_dtype],
-            bias[0],
+            None,
             bias_dtype,
             gelu,
-            gelu_input[0],
+            None,
             grad,  # grad
             workspaces[0],
             workspaces[0].shape[0],
@@ -623,10 +623,10 @@ def general_grouped_gemm(
             num_gemms,
             None,
             TE_DType[torch.int32],
-            bias[0],
+            None,
             bias_dtype,
             gelu,
-            gelu_input[0],
+            None,
             grad,  # grad
             workspaces[0],
             workspaces[0].shape[0],
@@ -671,10 +671,10 @@ def general_grouped_gemm(
             num_gemms,
             None,
             TE_DType[torch.int32],
-            bias[0],
+            None,
             bias_dtype,
             gelu,
-            gelu_input[0],
+            None,
             grad,  # grad
             workspaces[0],
             workspaces[0].shape[0],
@@ -718,10 +718,10 @@ def general_grouped_gemm(
             num_gemms,
             None,
             TE_DType[torch.int32],
-            bias[0],
+            None,
             bias_dtype,
             gelu,
-            gelu_input[0],
+            None,
             grad,  # grad
             workspaces[0],
             workspaces[0].shape[0],
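Across the hunks above, bias[0] and gelu_input[0] are replaced with None in every call inside general_grouped_gemm, and the first two hunks additionally pass out_dtype directly instead of the TE_DType[out_dtype] lookup. The sketch below only illustrates that lookup-versus-pass-through distinction; BackendDType and TORCH_TO_BACKEND are hypothetical stand-ins, not the library's actual definitions.

from enum import Enum, auto
import torch

class BackendDType(Enum):
    # hypothetical stand-in for the backend dtype enum that TE_DType maps into
    kBFloat16 = auto()
    kFloat16 = auto()
    kInt32 = auto()

TORCH_TO_BACKEND = {
    # hypothetical stand-in for the TE_DType mapping referenced in the diff
    torch.bfloat16: BackendDType.kBFloat16,
    torch.float16: BackendDType.kFloat16,
    torch.int32: BackendDType.kInt32,
}

out_dtype = torch.bfloat16
mapped = TORCH_TO_BACKEND[out_dtype]   # enum value, as in TE_DType[out_dtype]
passed_through = out_dtype             # torch dtype handed to the callee unchanged
print(mapped, passed_through)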
transformer_engine/pytorch/module/layernorm_linear.py
@@ -937,7 +937,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 )
                 dgrad = dgrad.reshape(inputmat.size())
             elif ctx.normalization == "RMSNorm":
-                if enable_lightop:
+                if enable_lightop and (rsigma is torch.bfloat16 or rsigma is torch.float16):
                     dgrad, dgamma = rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight)
                 else:
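The hunk above gates the light-op RMSNorm backward path on dtype. Below is a minimal dispatch sketch under stated assumptions: it checks rsigma.dtype, which is one plausible reading of the committed condition (the diff compares rsigma itself against torch.bfloat16 / torch.float16), and fused_rmsnorm_backward / generic_rmsnorm_backward are hypothetical stand-ins for the fused kernel and the fallback path.

import torch

def fused_rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight):
    # stand-in for the fused kernel (rmsnorm_backward in the diff); not implemented here
    raise NotImplementedError

def generic_rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight):
    # stand-in for the non-fused fallback path; not implemented here
    raise NotImplementedError

def rmsnorm_backward_dispatch(dgrad, inputmat, rsigma, ln_weight, enable_lightop):
    # Take the fused path only when the light-op kernels are enabled and the
    # tensors use a half-precision dtype the kernel supports (assumption).
    if enable_lightop and rsigma.dtype in (torch.bfloat16, torch.float16):
        return fused_rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight)
    return generic_rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight)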