OpenDAS / TransformerEngine · Commits

Commit 521f8d3b
Authored May 27, 2025 by yuguo
Parent 291fcf52

[DCU] combine 1f1b needs NVTE_OVERLAP_GRAD_REDUCE
Showing 5 changed files with 41 additions and 6 deletions (+41 −6)
transformer_engine/pytorch/module/batched_linear.py   (+8 −1)
transformer_engine/pytorch/module/grouped_linear.py   (+6 −2)
transformer_engine/pytorch/module/layernorm_linear.py (+3 −0)
transformer_engine/pytorch/module/layernorm_mlp.py    (+20 −2)
transformer_engine/pytorch/module/linear.py           (+4 −1)
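All five changes follow the same pattern: when the weight-gradient GEMM is deferred through the wgrad store (delayed wgrad compute, as used by the combined 1F1B schedule), the backward pass previously left the weight gradient as None; with NVTE_OVERLAP_GRAD_REDUCE set, it now allocates an uninitialized placeholder tensor of the weight's shape instead, presumably so the overlapped gradient reduction has a real buffer to work with. A minimal sketch of the gate; the helper name is ours, not anything in the diff:

import os

def overlap_grad_reduce_enabled() -> bool:
    # Mirrors int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")) from the hunks:
    # unset or "0" disables the placeholder allocation, any non-zero integer enables it.
    return bool(int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")))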
transformer_engine/pytorch/module/batched_linear.py

@@ -304,7 +304,14 @@ class _BatchLinear(torch.autograd.Function):
         wgrad_list = [None] * ctx.num_gemms
         if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-            wgrad_list = [None] * ctx.num_gemms
+            # overlap_grad_reduce, dongcl
+            if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                wgrad_list = [
+                    torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
+                    for w in weights
+                ]
+            else:
+                wgrad_list = [None] * ctx.num_gemms
         if not ctx.use_bias or (ctx.wgrad_store is not None
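The placeholder is torch.empty rather than zeros: its contents are never read before the delayed wgrad GEMM overwrites them, so only shape, dtype, and device matter. A standalone sketch of the allocation pattern, with weights and activation_dtype as hypothetical stand-ins for the ctx fields:

import torch

# Hypothetical stand-ins for weights and ctx.activation_dtype.
weights = [torch.randn(8, 4), torch.randn(8, 4)]
activation_dtype = torch.float32
overlap_grad_reduce = True  # i.e. int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")) != 0

if overlap_grad_reduce:
    # Uninitialized per-GEMM placeholders with the right shape/dtype/device;
    # the delayed wgrad GEMM fills them in later.
    wgrad_list = [
        torch.empty(w.size(), dtype=activation_dtype, device=w.device)
        for w in weights
    ]
else:
    wgrad_list = [None] * len(weights)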
transformer_engine/pytorch/module/grouped_linear.py

@@ -4,7 +4,7 @@
 """GroupedLinear API"""
 from typing import Union, Optional, Callable, Tuple, List
+import os
 import functools
 import torch

@@ -393,7 +393,11 @@ class _GroupedLinear(torch.autograd.Function):
         wgrad_list = [None] * ctx.num_gemms
         if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-            wgrad_list = [None] * ctx.num_gemms
+            # overlap_grad_reduce, dongcl
+            if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                wgrad_list = [torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights]
+            else:
+                wgrad_list = [None] * ctx.num_gemms
         if not ctx.use_bias or (ctx.wgrad_store is not None
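One small difference from batched_linear.py: here the placeholders are allocated on ctx.device rather than each weight's own device, and the module gains an import os for the environment lookup. To exercise this path, the variable must be set in the training process before backward runs; a sketch (the launcher command is illustrative):

import os

# Either in the launcher, e.g. NVTE_OVERLAP_GRAD_REDUCE=1 torchrun train.py,
# or programmatically, before the first backward pass:
os.environ.setdefault("NVTE_OVERLAP_GRAD_REDUCE", "1")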
transformer_engine/pytorch/module/layernorm_linear.py

@@ -783,6 +783,9 @@ class _LayerNormLinear(torch.autograd.Function):
         if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
             ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
+            # overlap_grad_reduce, dongcl
+            if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
         else:
             wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
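For the non-grouped modules a single placeholder suffices. The surrounding logic defers the actual GEMM through ctx.wgrad_store.put(...); below is a toy mock of that put-then-flush flow, not TE's actual WgradStore API, just a sketch of the semantics the hunk relies on:

import torch

class TinyWgradStore:
    # Toy stand-in: queue (tensors, fn) pairs now, run the GEMMs later.
    def __init__(self, delay: bool):
        self._delay = delay
        self._queue = []

    def delay_wgrad_compute(self) -> bool:
        return self._delay

    def put(self, tensors, fn):
        self._queue.append((tensors, fn))

    def flush(self):
        results = [fn(*tensors) for tensors, fn in self._queue]
        self._queue.clear()
        return results

store = TinyWgradStore(delay=True)
ln_out_total = torch.randn(16, 32)
grad_output = torch.randn(16, 8)
wgrad_fn = lambda inp, g: g.t() @ inp  # stands in for general_gemm_wgrad
store.put([ln_out_total, grad_output], wgrad_fn)
wgrad = torch.empty(8, 32)  # placeholder, as in the hunk
(real_wgrad,) = store.flush()
wgrad.copy_(real_wgrad)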
transformer_engine/pytorch/module/layernorm_mlp.py

@@ -843,7 +843,16 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
                 ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad)
-                fc2_wgrad = None
+                # overlap_grad_reduce, dongcl
+                if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                    fc2_wgrad = torch.empty(
+                        origin_fc2_weight.shape,
+                        dtype=origin_fc2_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
+                else:
+                    fc2_wgrad = None
             else:
                 fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad(
                     act_out,

@@ -1057,7 +1066,16 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
                 ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad)
-                fc1_wgrad = None
+                # overlap_grad_reduce, dongcl
+                if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                    fc1_wgrad = torch.empty(
+                        origin_fc1_weight.shape,
+                        dtype=origin_fc1_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
+                else:
+                    fc1_wgrad = None
                 if fuse_gemm_and_bias_fc1_wgrad:
                     fc1_bias_grad = None
             else:
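The MLP hunks differ slightly from the linear ones: the placeholder takes its dtype from the original weight rather than ctx.activation_dtype, is allocated on torch.cuda.current_device(), and is marked requires_grad=False since it is produced inside backward rather than tracked by autograd. A sketch with a hypothetical stand-in weight (CPU fallback added so it runs anywhere):

import torch

origin_fc2_weight = torch.nn.Parameter(torch.randn(8, 32))  # hypothetical stand-in

device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
# Same shape/dtype as the weight, no autograd tracking, contents left
# uninitialized until the delayed FC2 wgrad GEMM fills them.
fc2_wgrad = torch.empty(
    origin_fc2_weight.shape,
    dtype=origin_fc2_weight.dtype,
    device=device,
    requires_grad=False,
)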
transformer_engine/pytorch/module/linear.py

@@ -6,7 +6,7 @@
 from typing import Callable, Dict, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
+import os
 import functools
 import torch

@@ -699,6 +699,9 @@ class _Linear(torch.autograd.Function):
         if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
             ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad)
+            # overlap_grad_reduce, dongcl
+            if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
         else:
             wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output)
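Per the commit title, the combined 1F1B schedule depends on this flag being set; without it, the delayed-wgrad path still hands back None gradients. A defensive startup check one might add (hypothetical helper, not part of this commit):

import os

def check_combined_1f1b_env() -> None:
    # Fail fast if the combined 1F1B schedule would run without the flag
    # that makes the delayed-wgrad path allocate placeholder gradients.
    if not int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
        raise RuntimeError(
            "combined 1F1B requires NVTE_OVERLAP_GRAD_REDUCE=1 so that "
            "delayed wgrad placeholders are allocated"
        )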