OpenDAS / TransformerEngine / Commits

Commit 162e32d4
authored Sep 24, 2025 by dongcl
parent 8aca187f

support activation offloading
Showing 4 changed files with 88 additions and 11 deletions.
transformer_engine/pytorch/module/grouped_linear.py   +1  -1
transformer_engine/pytorch/module/layernorm_linear.py  +32  -5
transformer_engine/pytorch/module/linear.py            +27  -5
transformer_engine/pytorch/utils.py                    +28  -0
transformer_engine/pytorch/module/grouped_linear.py
@@ -298,8 +298,8 @@ class _GroupedLinear(torch.autograd.Function):
         if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
                 w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                 w.main_grad = main_grads[i]
                 weights[i] = w
                 weights[i].main_grad = main_grads[i]
                 if ctx.offload_activation and weights[0].requires_grad:
                     weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
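For context, the torch.nn.Parameter re-wrapping in this loop mirrors the existing CPU-offloading workaround described later in this diff: a weight that has been round-tripped through the offloading hooks comes back as a bare torch.Tensor, so backward rebuilds a Parameter and reattaches main_grad before fused wgrad accumulation uses it. A minimal standalone sketch of that step with toy tensors (the names below are invented for illustration, not taken from this commit):

import torch

# Toy stand-ins: a weight that lost its Parameter wrapper after passing
# through offloading hooks, and its separately kept accumulation buffer.
bare_weight = torch.randn(16, 16, requires_grad=True)
main_grad_buffer = torch.zeros(16, 16)

# Re-wrap the tensor as a Parameter and reattach the accumulation buffer,
# as the loop over ctx.num_gemms above does for each expert weight.
w = torch.nn.Parameter(bare_weight, requires_grad=bare_weight.requires_grad)
w.main_grad = main_grad_buffer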
transformer_engine/pytorch/module/layernorm_linear.py
@@ -38,6 +38,7 @@ from ..utils import (
     nvtx_range_push,
     requires_grad,
     needs_quantized_gemm,
+    get_activation_offloading,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
@@ -431,10 +432,33 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
+        offload_activation = False
+        if get_activation_offloading():
+            offload_activation = True
+            if not inputmat.is_contiguous():
+                inputmat = inputmat.contiguous()
+            inputmat.offloading_activation = True
+        ctx.offload_activation = offload_activation
+        if offload_activation and cpu_offloading:
+            raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.")
+        if offload_activation and weight.requires_grad and fuse_wgrad_accumulation:
+            if hasattr(weight, "grad_added_to_main_grad"):
+                ctx.has_grad_added_to_main_grad = True
+                ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
+                weight.grad_added_to_main_grad = True
+                ctx.weight_object = weight
+            else:
+                ctx.has_grad_added_to_main_grad = False
         if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
-            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
-            if ctx.grad_added_to_main_grad:
+            if ctx.has_grad_added_to_main_grad:
                 # If you are passing torch.nn.Parameter through the Torch hooks, you will
                 # get back torch.Tensor. Torch rips off the Parameter wrapper.
                 # You need to preserve the weight object to have all the attributes user
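The block added above (repeated almost verbatim in linear.py further down) follows one pattern: read the new process-wide flag, make the activation that will be saved for backward contiguous, tag it with an offloading_activation attribute, presumably so a downstream offloading hook can identify it, and stash the grad_added_to_main_grad bookkeeping on ctx for backward to restore. A standalone toy illustration of that tagging pattern; the TaggedMatmul function and the _offload_enabled flag are invented for this sketch and are not part of the commit:

import torch

# Stand-in for the module-level ActivationOffloadEnabled flag added in utils.py.
_offload_enabled = True


class TaggedMatmul(torch.autograd.Function):
    """Toy autograd function that tags its saved activation the way
    _LayerNormLinear and _Linear do when activation offloading is on."""

    @staticmethod
    def forward(ctx, inp, weight):
        out = inp @ weight.t()
        saved = inp
        if _offload_enabled:
            # The activation must be contiguous before it is tagged so a
            # later offloading pass can move it as a single buffer.
            if not saved.is_contiguous():
                saved = saved.contiguous()
            saved.offloading_activation = True  # marker attribute, as in the diff
        ctx.save_for_backward(saved, weight)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        return grad_out @ weight, grad_out.t() @ inp


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
TaggedMatmul.apply(x, w).sum().backward()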
@@ -567,9 +591,11 @@ class _LayerNormLinear(torch.autograd.Function):
         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
         # we need to connect them into one.
-        if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
-            if ctx.grad_added_to_main_grad:
+        if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
+            if ctx.has_grad_added_to_main_grad:
                 origin_weight = ctx.weight_object
+                if ctx.offload_activation:
+                    origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
             if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                 origin_weight.main_grad = main_grad
@@ -949,7 +975,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 )
                 dgrad = dgrad.reshape(inputmat.size())
             elif ctx.normalization == "RMSNorm":
-                if enable_lightop and (rsigma.dtype is torch.bfloat16 or rsigma.dtype is torch.float16):
+                if enable_lightop and (rsigma is torch.bfloat16 or rsigma is torch.float16):
                     dgrad, dgamma = rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight)
                 else:
@@ -1546,6 +1572,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         else:
             fwd_fn = _LayerNormLinear.forward
             args = [None]
         args += (
             inp,
             self.layer_norm_weight,
transformer_engine/pytorch/module/linear.py
@@ -38,6 +38,7 @@ from ..utils import (
     assert_dim_for_fp8_exec,
     nvtx_range_pop,
     nvtx_range_push,
+    get_activation_offloading,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
@@ -396,10 +397,30 @@ class _Linear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
+        offload_activation = False
+        if get_activation_offloading():
+            offload_activation = True
+            if not saved_inputmat.is_contiguous():
+                saved_inputmat = saved_inputmat.contiguous()
+            saved_inputmat.offloading_activation = True
+        ctx.offload_activation = offload_activation
+        if offload_activation and cpu_offloading:
+            raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.")
+        if offload_activation and weight.requires_grad and fuse_wgrad_accumulation:
+            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            if ctx.has_grad_added_to_main_grad:
+                ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
+                ctx.weight_object = weight
+                weight.grad_added_to_main_grad = True
         if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
-            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
-            if ctx.grad_added_to_main_grad:
+            if ctx.has_grad_added_to_main_grad:
                 # If you are passing torch.nn.Parameter through the Torch hooks, you will
                 # get back torch.Tensor. Torch rips off the Parameter wrapper.
                 # You need to preserve the weight object to have all the attributes user
@@ -482,7 +503,6 @@ class _Linear(torch.autograd.Function):
         inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
             restore_from_saved(ctx.tensor_objects, saved_tensors)
         )
         # Delete the references to tensor objects once they've been consumed
         # by the `restore_from_saved` method to construct back the actual tensors.
         ctx.tensor_objects = None
@@ -494,9 +514,11 @@ class _Linear(torch.autograd.Function):
             else None
         )
-        if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
-            if ctx.grad_added_to_main_grad:
+        if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
+            if ctx.has_grad_added_to_main_grad:
                 weight = ctx.weight_object
+                if ctx.offload_activation:
+                    weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
             if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                 weight.main_grad = main_grad
transformer_engine/pytorch/utils.py
@@ -14,6 +14,34 @@ import transformer_engine.pytorch.cpp_extensions as ext
 from . import torch_version
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
+ActivationOffloadEnabled = False
+
+
+def get_activation_offloading():
+    global ActivationOffloadEnabled
+    return ActivationOffloadEnabled
+
+
+def set_activation_offloading(activation_offloading):
+    global ActivationOffloadEnabled
+    ActivationOffloadEnabled = activation_offloading
+
+
+class ActivationOffloadContextManager:
+    """A reusable context manager for switch ActivationOffloadEnabled"""
+
+    def __init__(self, activation_offloading):
+        self.activation_offloading = activation_offloading
+
+    def __enter__(self):
+        self.origin_cpu_offloading = get_activation_offloading()
+        set_activation_offloading(self.activation_offloading)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        set_activation_offloading(self.origin_cpu_offloading)
+
+
 def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     """Check if any of the given tensors require gradient."""
     for tensor in tensors:
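The utils.py additions are plain module-level state plus a context manager that snapshots the previous value in __enter__ and restores it in __exit__. A minimal usage sketch against this fork's transformer_engine.pytorch.utils module; only the three names added by this commit are assumed, everything else is generic Python:

from transformer_engine.pytorch.utils import (
    ActivationOffloadContextManager,
    get_activation_offloading,
    set_activation_offloading,
)

# The flag is off by default.
assert not get_activation_offloading()

# Toggle it directly around the region whose activations should be tagged ...
set_activation_offloading(True)
# ... run the forward passes here ...
set_activation_offloading(False)

# ... or scope it with the context manager, which restores the old value on exit.
with ActivationOffloadContextManager(True):
    assert get_activation_offloading()
assert not get_activation_offloading()

Note that the forward passes changed in this commit raise a ValueError when this flag and cpu_offloading are enabled at the same time, so a caller picks at most one of the two mechanisms.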