Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
bd5a6e86
Commit
bd5a6e86
authored
Oct 30, 2025
by
tabuchixiangcai3
Browse files
[DCU]fix main_grad no exit
Signed-off-by:
Tangao
<
2205747538@qq.com
>
parent
29271c40
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
17 deletions
+2
-17
transformer_engine/pytorch/module/batched_linear.py
transformer_engine/pytorch/module/batched_linear.py
+2
-17
No files found.
transformer_engine/pytorch/module/batched_linear.py
View file @
bd5a6e86
...
@@ -153,13 +153,7 @@ class _BatchLinear(torch.autograd.Function):
...
@@ -153,13 +153,7 @@ class _BatchLinear(torch.autograd.Function):
if
cpu_offloading
:
if
cpu_offloading
:
if
fuse_wgrad_accumulation
:
if
fuse_wgrad_accumulation
:
for
w
in
weights
:
for
w
in
weights
:
if
getattr
(
w
,
"main_grad"
,
None
)
is
not
None
:
w
.
main_grad
.
weight_offloading
=
True
w
.
main_grad
.
weight_offloading
=
True
else
:
# Optional: log a warning if fuse requested but buffer missing
# logger = logging.getLogger("BatchLinear")
# logger.debug("fuse_wgrad_accumulation=True but weight.main_grad is missing; skipping weight_offloading for this weight.")
pass
for
w
in
weights
:
for
w
in
weights
:
w
.
weight_offloading
=
True
w
.
weight_offloading
=
True
for
t
in
saved_inputmats
:
for
t
in
saved_inputmats
:
...
@@ -168,7 +162,7 @@ class _BatchLinear(torch.autograd.Function):
...
@@ -168,7 +162,7 @@ class _BatchLinear(torch.autograd.Function):
for
i
in
range
(
num_gemms
):
for
i
in
range
(
num_gemms
):
weights
[
i
].
offloading_activation
=
False
weights
[
i
].
offloading_activation
=
False
if
getattr
(
weights
[
i
],
"
main_grad
"
,
None
)
is
not
None
:
if
fuse_wgrad_accumulation
and
hasattr
(
weights
[
i
],
'
main_grad
'
)
:
weights
[
i
].
main_grad
.
offloading_activation
=
False
weights
[
i
].
main_grad
.
offloading_activation
=
False
if
weights_fp8
[
i
]
is
not
None
:
if
weights_fp8
[
i
]
is
not
None
:
weights_fp8
[
i
].
offloading_activation
=
False
weights_fp8
[
i
].
offloading_activation
=
False
...
@@ -561,16 +555,7 @@ class BatchedLinear(TransformerEngineBaseModule):
...
@@ -561,16 +555,7 @@ class BatchedLinear(TransformerEngineBaseModule):
if
self
.
primary_weights_in_fp8
:
if
self
.
primary_weights_in_fp8
:
self
.
init_fp8_metadata
(
num_gemms
=
self
.
num_gemms
)
self
.
init_fp8_metadata
(
num_gemms
=
self
.
num_gemms
)
# Ensure main_grad buffers exist when fuse_wgrad_accumulation is enabled.
# Skip allocation under meta device (deferred init).
self
.
reset_parameters
(
defer_init
=
(
device
==
"meta"
))
self
.
reset_parameters
(
defer_init
=
(
device
==
"meta"
))
if
self
.
fuse_wgrad_accumulation
and
device
!=
"meta"
:
for
i
in
range
(
int
(
self
.
num_gemms
)):
w
=
getattr
(
self
,
f
"weight
{
i
}
"
)
if
getattr
(
w
,
"main_grad"
,
None
)
is
None
:
# use float32 buffer for main_grad (tests use float32)
w
.
main_grad
=
torch
.
empty_like
(
w
,
dtype
=
torch
.
float32
,
device
=
w
.
device
)
w
.
main_grad
.
zero_
()
# For RPL, bias has to be added after TP collectives
# For RPL, bias has to be added after TP collectives
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment