Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
3273bc20
Commit
3273bc20
authored
May 27, 2025
by
yuguo
Browse files
Merge branch 'develop_v2.3' of
http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine
parents
75e9ef24
521f8d3b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
39 additions
and
3 deletions
+39
-3
transformer_engine/pytorch/module/batched_linear.py
transformer_engine/pytorch/module/batched_linear.py
+8
-1
transformer_engine/pytorch/module/grouped_linear.py
transformer_engine/pytorch/module/grouped_linear.py
+6
-1
transformer_engine/pytorch/module/layernorm_linear.py
transformer_engine/pytorch/module/layernorm_linear.py
+3
-0
transformer_engine/pytorch/module/layernorm_mlp.py
transformer_engine/pytorch/module/layernorm_mlp.py
+18
-1
transformer_engine/pytorch/module/linear.py
transformer_engine/pytorch/module/linear.py
+4
-0
No files found.
transformer_engine/pytorch/module/batched_linear.py
View file @
3273bc20
...
...
@@ -304,7 +304,14 @@ class _BatchLinear(torch.autograd.Function):
wgrad_list
=
[
None
]
*
ctx
.
num_gemms
if
ctx
.
wgrad_store
is
not
None
and
ctx
.
wgrad_store
.
delay_wgrad_compute
():
wgrad_list
=
[
None
]
*
ctx
.
num_gemms
# overlap_grad_reduce, dongcl
if
int
(
os
.
getenv
(
"NVTE_OVERLAP_GRAD_REDUCE"
,
"0"
)):
wgrad_list
=
[
torch
.
empty
(
w
.
size
(),
dtype
=
ctx
.
activation_dtype
,
device
=
w
.
device
)
for
w
in
weights
]
else
:
wgrad_list
=
[
None
]
*
ctx
.
num_gemms
if
not
ctx
.
use_bias
or
(
ctx
.
wgrad_store
is
not
None
...
...
transformer_engine/pytorch/module/grouped_linear.py
View file @
3273bc20
...
...
@@ -5,6 +5,7 @@
"""GroupedLinear API"""
from
typing
import
Union
,
Optional
,
Callable
,
Tuple
,
List
import
warnings
import
os
import
functools
import
torch
...
...
@@ -394,7 +395,11 @@ class _GroupedLinear(torch.autograd.Function):
wgrad_list
=
[
None
]
*
ctx
.
num_gemms
if
ctx
.
wgrad_store
is
not
None
and
ctx
.
wgrad_store
.
delay_wgrad_compute
():
wgrad_list
=
[
None
]
*
ctx
.
num_gemms
# overlap_grad_reduce, dongcl
if
int
(
os
.
getenv
(
"NVTE_OVERLAP_GRAD_REDUCE"
,
"0"
)):
wgrad_list
=
[
torch
.
empty
(
w
.
size
(),
dtype
=
ctx
.
activation_dtype
,
device
=
ctx
.
device
)
for
w
in
weights
]
else
:
wgrad_list
=
[
None
]
*
ctx
.
num_gemms
if
not
ctx
.
use_bias
or
(
ctx
.
wgrad_store
is
not
None
...
...
transformer_engine/pytorch/module/layernorm_linear.py
View file @
3273bc20
...
...
@@ -861,6 +861,9 @@ class _LayerNormLinear(torch.autograd.Function):
"with Userbuffers (tensor-parallel communication overlapping)"
)
ctx
.
wgrad_store
.
put
([
ln_out_total
,
grad_output
],
wgrad_gemm
)
# overlap_grad_reduce, dongcl
if
int
(
os
.
getenv
(
"NVTE_OVERLAP_GRAD_REDUCE"
,
"0"
)):
wgrad
=
torch
.
empty
(
weight
.
size
(),
dtype
=
ctx
.
activation_dtype
,
device
=
weight
.
device
)
else
:
# Call wgrad GEMM now
...
...
transformer_engine/pytorch/module/layernorm_mlp.py
View file @
3273bc20
...
...
@@ -913,6 +913,14 @@ class _LayerNormMLP(torch.autograd.Function):
# Choose whether to call wgrad GEMM now or delay
if
ctx
.
wgrad_store
is
not
None
and
ctx
.
wgrad_store
.
delay_wgrad_compute
():
ctx
.
wgrad_store
.
put
([
act_out
,
grad_output
],
fc2_wgrad_gemm
)
# overlap_grad_reduce, dongcl
if
int
(
os
.
getenv
(
"NVTE_OVERLAP_GRAD_REDUCE"
,
"0"
)):
fc2_wgrad
=
torch
.
empty
(
origin_fc2_weight
.
shape
,
dtype
=
origin_fc2_weight
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
)
else
:
# Call wgrad GEMM now
...
...
@@ -1166,7 +1174,16 @@ class _LayerNormMLP(torch.autograd.Function):
"with Userbuffers (tensor-parallel communication overlapping)"
)
ctx
.
wgrad_store
.
put
([
ln_out_total
,
dact
],
fc1_wgrad_gemm
)
fc1_wgrad
=
None
# overlap_grad_reduce, dongcl
if
int
(
os
.
getenv
(
"NVTE_OVERLAP_GRAD_REDUCE"
,
"0"
)):
fc1_wgrad
=
torch
.
empty
(
origin_fc1_weight
.
shape
,
dtype
=
origin_fc1_weight
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
)
else
:
fc1_wgrad
=
None
if
fuse_gemm_and_bias_fc1_wgrad
:
fc1_bias_grad
=
None
else
:
...
...
transformer_engine/pytorch/module/linear.py
View file @
3273bc20
...
...
@@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Tuple, Union
from
functools
import
reduce
from
operator
import
mul
as
multiply_op
import
warnings
import
os
import
torch
...
...
@@ -782,6 +783,9 @@ class _Linear(torch.autograd.Function):
"with Userbuffers (tensor-parallel communication overlapping)"
)
ctx
.
wgrad_store
.
put
([
inputmat_total
,
grad_output
],
wgrad_gemm
)
# overlap_grad_reduce, dongcl
if
int
(
os
.
getenv
(
"NVTE_OVERLAP_GRAD_REDUCE"
,
"0"
)):
wgrad
=
torch
.
empty
(
weight
.
size
(),
dtype
=
ctx
.
activation_dtype
,
device
=
weight
.
device
)
else
:
# Call wgrad GEMM now
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment