OpenDAS / TransformerEngine · Commits

Commit 2389ed3f, authored Aug 27, 2025 by yuguo

Merge branch 'release_v2.7' of https://github.com/NVIDIA/TransformerEngine into release_v2.7

Parents: 87e3e56e, 58c3ac80
Changes: 22. Showing 2 changed files with 26 additions and 11 deletions (+26 -11):

transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py   +24 -10
transformer_engine/pytorch/ops/linear.py                               +2 -1
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
@@ -14,11 +14,12 @@ from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_exter
 from ...cpp_extensions import general_gemm
 from ...distributed import get_distributed_world_size
 from ...module.base import (
-    _2X_ACC_DGRAD,
-    _2X_ACC_WGRAD,
     fill_userbuffers_buffer_for_all_gather,
+    get_dummy_wgrad,
     get_ub,
     get_workspace,
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
 )
 from ...tensor.quantized_tensor import Quantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
@@ -513,20 +514,22 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors
 
-        # wgrad fusion
+        # Megatron-LM wgrad fusion
+        # Note: Get grad tensor from param so we can accumulate directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-            if not hasattr(linear_op.weight, "main_grad"):
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
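For context, the "Megatron-LM wgrad fusion" path in this hunk assumes the training framework has already attached a `main_grad` buffer to the weight parameter; the fused backward then accumulates the weight gradient directly into that buffer instead of producing a fresh `grad`. The following is a minimal illustrative sketch of that convention, not code from this commit; the helper name `attach_main_grad` is hypothetical, and only `main_grad`, `__fsdp_param__`, and `get_main_grad()` appear in the diff itself.

import torch

def attach_main_grad(param: torch.nn.Parameter) -> None:
    # Hypothetical helper: give the parameter a persistent fp32 buffer that
    # fused backward kernels can accumulate the weight gradient into.
    param.main_grad = torch.zeros(param.shape, dtype=torch.float32, device=param.device)

weight = torch.nn.Parameter(torch.randn(16, 16))
attach_main_grad(weight)

# The fused op above then roughly does:
#   grad_weight = weight.main_grad.detach()
# and passes that tensor as the accumulation output of the wgrad GEMM.
# For FSDP-managed parameters (marked with `__fsdp_param__`), the diff first
# materializes `main_grad` via `weight_param.get_main_grad()`.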
@@ -558,10 +561,21 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Clear input tensor if possible
         clear_tensor_data(x_local)
 
-        # Return gradients
-        grad_params = [() for _ in range(len(self.basic_ops))]
+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
         if accumulate_into_main_grad:
             grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+                grad_weight = get_dummy_wgrad(
+                    list(weight_param.size()),
+                    weight_param.dtype,
+                    zero=getattr(weight_param, "zero_out_wgrad", False),
+                )
+
+        # Return gradients
+        grad_params = [() for _ in range(len(self.basic_ops))]
         grad_params[self._op_idxs["linear"]] = (grad_weight,)
         if bias_op is not None:
             grad_params[self._op_idxs["bias"]] = (grad_bias,)
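The dummy weight gradient above exists because PyTorch autograd still expects a tensor of the weight's shape from the backward pass even when the real gradient has already been accumulated into `main_grad`. Setting `grad_added_to_main_grad = True` signals that `param.grad` can be ignored. Below is a rough stand-in for what a helper like `get_dummy_wgrad` might return; the actual TE implementation is not part of this diff, so treat the body as an assumption.

import torch

def dummy_wgrad(shape, dtype, zero=False, device="cpu"):
    # Throwaway buffer with the right shape and dtype. It is only zeroed when
    # the caller will actually read it (the diff gates this on `zero_out_wgrad`).
    out = torch.empty(shape, dtype=dtype, device=device)
    if zero:
        out.zero_()
    return out

# Downstream, Megatron-LM-style code can then skip `param.grad` entirely:
#   if getattr(param, "grad_added_to_main_grad", False):
#       ...  # gradient already lives in param.main_grad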
transformer_engine/pytorch/ops/linear.py

@@ -54,7 +54,8 @@ class Linear(FusedOperation):
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and
         there is no guarantee that `grad` will be set or be
-        meaningful.
+        meaningful. This is primarily intended to integrate with
+        Megatron-LM.
 
     """
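To make the documented behavior concrete, here is a hedged usage sketch of the operation-based `Linear` with `accumulate_into_main_grad=True`. It assumes the `ops` package exports `Linear`, that the constructor accepts this flag as the docstring describes, and that TE is installed with a CUDA device available; the parameter-iteration loop is used because the exact parameter layout of the fused op is not shown in this diff.

import torch
from transformer_engine.pytorch.ops import Linear

linear = Linear(1024, 1024, accumulate_into_main_grad=True, device="cuda")

# The caller owns the accumulation buffer: attach an fp32 `main_grad` to each
# weight parameter before running backward, as the docstring requires.
for name, param in linear.named_parameters():
    if "weight" in name:
        param.main_grad = torch.zeros(param.shape, dtype=torch.float32, device=param.device)

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
y = linear(x)
y.sum().backward()

# The weight gradient was accumulated into `main_grad`; `param.grad` is not
# guaranteed to be set or meaningful.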