OpenDAS / TransformerEngine · Commits

Commit 9711d439, authored Oct 17, 2025 by dongcl

    Update activation offload code to align with the official version

Parent: 712d526a

Showing 4 changed files with 105 additions and 70 deletions (+105 -70)
    transformer_engine/pytorch/module/batched_linear.py    +31 -22
    transformer_engine/pytorch/module/grouped_linear.py     +27 -21
    transformer_engine/pytorch/module/layernorm_linear.py   +21 -12
    transformer_engine/pytorch/module/linear.py              +26 -15
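All four diffs below follow the same pattern: the ad-hoc `offload_activation` handling is replaced or supplemented by an explicit `fine_grained_activation_offloading` argument that is threaded from each module's constructor into its autograd function, a guard forbids combining it with CPU offloading, and the placeholder weight gradients in backward come from `get_dummy_wgrad` instead of fresh `torch.zeros`/`torch.empty` allocations. A minimal sketch of the shared guard, assuming nothing beyond what the hunks show (the helper name `check_offloading_flags` is illustrative, not part of the commit):

    def check_offloading_flags(fine_grained_activation_offloading: bool, cpu_offloading: bool) -> None:
        # The two offloading mechanisms are mutually exclusive in every module touched here.
        if fine_grained_activation_offloading and cpu_offloading:
            raise ValueError(
                "Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
            )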
transformer_engine/pytorch/module/batched_linear.py

@@ -13,6 +13,7 @@ import transformer_engine_torch as tex
 from .base import (
     get_multi_stream_cublas_batchgemm_workspace,
+    get_dummy_wgrad,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,

@@ -94,6 +95,7 @@ class _BatchLinear(torch.autograd.Function):
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         is_grad_enabled: bool,
+        fine_grained_activation_offloading,
         *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
     ) -> torch.Tensor:
         batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))

@@ -158,19 +160,24 @@ class _BatchLinear(torch.autograd.Function):
                 if t is not None:
                     t.activation_offloading = True

         offload_activation = False
         if hasattr(inp, "offloading_activation"):
             offload_activation = True
             for i in range(num_gemms):
                 saved_inputmats[i].offloading_activation = inp.offloading_activation
         ctx.offload_activation = offload_activation
         for i in range(num_gemms):
             weights[i].offloading_activation = False
             weights[i].main_grad.offloading_activation = False
             if weights_fp8[i] is not None:
                 weights_fp8[i].offloading_activation = False
-        if offload_activation and cpu_offloading:
+        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
+        if fine_grained_activation_offloading and cpu_offloading:
             raise ValueError(
-                f"Do not use offload_activation and cpu_offloading at the same time."
+                f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
             )
-        if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation:
+        if (
+            fine_grained_activation_offloading
+            and weights[0].requires_grad
+            and fuse_wgrad_accumulation
+        ):
             grad_added_to_main_grad_list = []
             for weight in weights:
                 if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):

@@ -187,7 +194,7 @@ class _BatchLinear(torch.autograd.Function):
             *weights,
             *weights_fp8,
             *[
-                w.main_grad if (cpu_offloading or offload_activation) and fuse_wgrad_accumulation else None
+                w.main_grad if (cpu_offloading or fine_grained_activation_offloading) and fuse_wgrad_accumulation else None
                 for w in weights
             ],
         )

@@ -226,12 +233,12 @@ class _BatchLinear(torch.autograd.Function):
         weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms]
         weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
         main_grads = saved_tensors[4 * ctx.num_gemms :]
-        if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
+        if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
                 w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                 w.main_grad = main_grads[i]
                 weights[i] = w
-                if ctx.offload_activation and weights[i].requires_grad:
+                if ctx.fine_grained_activation_offloading and weights[i].requires_grad:
                     weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]

         global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT

@@ -304,18 +311,15 @@ class _BatchLinear(torch.autograd.Function):
                 if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
                     w.grad_added_to_main_grad = True
                     if getattr(w, "zero_out_wgrad", False):
-                        wgrad = torch.zeros(
-                            w.main_grad.shape,
-                            dtype=w.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(w.main_grad.shape),
+                            w.dtype,
+                            zero=True,
                         )
                     else:
-                        wgrad = torch.empty(
-                            w.main_grad.shape,
-                            dtype=w.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(w.main_grad.shape),
+                            w.dtype,
                         )
                 elif ctx.fuse_wgrad_accumulation:
                     wgrad = None

@@ -367,6 +371,7 @@ class _BatchLinear(torch.autograd.Function):
             None,  # activation_dtype
             None,  # parallel_mode
             None,  # is_grad_enabled
+            None,  # fine_grained_activation_offloading
             *wgrad_list,
             *([None] * ctx.num_gemms),  # weights_fp8
             *grad_biases,

@@ -457,6 +462,7 @@ class BatchedLinear(TransformerEngineBaseModule):
         device: Union[torch.device, str] = "cuda",
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
+        fine_grained_activation_offloading: bool = False,
         ub_name: Optional[str] = None,
         delay_wgrad_compute: bool = False,
     ) -> None:

@@ -480,6 +486,8 @@ class BatchedLinear(TransformerEngineBaseModule):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
+
         self.wgrad_store = WeightGradStore(delay_wgrad_compute)

         global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT

@@ -657,6 +665,7 @@ class BatchedLinear(TransformerEngineBaseModule):
                 self.activation_dtype,
                 self.parallel_mode,
                 torch.is_grad_enabled(),
+                self.fine_grained_activation_offloading,
                 *weight_tensors,
                 *weight_tensors_fp8,
                 *bias_tensors,
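In the backward hunk above, the dummy weight gradient returned when the real gradient has already been accumulated into `main_grad` is no longer allocated with `torch.zeros`/`torch.empty`; both branches now call `get_dummy_wgrad(list(shape), dtype, zero=...)`, newly imported from `.base`. The commit does not show that helper's body, so the sketch below is only a plausible illustration of a cached, reusable scratch buffer, not the actual implementation in `base.py`:

    import torch

    _DUMMY_WGRAD_CACHE: dict[torch.dtype, torch.Tensor] = {}

    def dummy_wgrad_sketch(shape: list, dtype: torch.dtype, zero: bool = False) -> torch.Tensor:
        # Hypothetical stand-in for get_dummy_wgrad: reuse one scratch buffer per dtype
        # instead of allocating a fresh tensor for every weight on every backward pass.
        numel = 1
        for dim in shape:
            numel *= dim
        buf = _DUMMY_WGRAD_CACHE.get(dtype)
        if buf is None or buf.numel() < numel:
            buf = torch.empty(numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False)
            _DUMMY_WGRAD_CACHE[dtype] = buf
        out = buf[:numel].view(shape)
        if zero:
            out.zero_()
        return out

Whatever the real helper does, the call sites only rely on it returning a correctly shaped tensor that does not track gradients, zeroed when `zero_out_wgrad` is set.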
transformer_engine/pytorch/module/grouped_linear.py

@@ -15,6 +15,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from .base import (
     get_multi_stream_cublas_workspace,
+    get_dummy_wgrad,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,

@@ -82,6 +83,7 @@ class _GroupedLinear(torch.autograd.Function):
         module,
         skip_fp8_weight_update,
         save_original_input,
+        fine_grained_activation_offloading,
         *weights_and_biases,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring

@@ -211,19 +213,22 @@ class _GroupedLinear(torch.autograd.Function):
                 if isinstance(weight, QuantizedTensorBase):
                     weight.update_usage(columnwise_usage=True)

         offload_activation = False
         if hasattr(inp, "offloading_activation"):
             offload_activation = True
             for i in range(num_gemms):
                 inputmats[i].offloading_activation = inp.offloading_activation
         ctx.offload_activation = offload_activation
         for i in range(num_gemms):
             weights[i].offloading_activation = False
             weights_fp8[i].offloading_activation = False
             biases[i].offloading_activation = False
+        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
-        if offload_activation and cpu_offloading:
+        if fine_grained_activation_offloading and cpu_offloading:
             raise ValueError(
-                f"Do not use offload_activation and cpu_offloading at the same time."
+                f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
             )
-        if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation:
+        if (
+            fine_grained_activation_offloading
+            and weights[0].requires_grad
+            and fuse_wgrad_accumulation
+        ):
             grad_added_to_main_grad_list = []
             for weight in weights:
                 if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):

@@ -295,12 +300,12 @@ class _GroupedLinear(torch.autograd.Function):
         biases = saved_tensors[3 * N : 4 * N]
         main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
-        if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
+        if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
                 w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                 w.main_grad = main_grads[i]
                 weights[i] = w
-                if ctx.offload_activation and weights[0].requires_grad:
+                if ctx.fine_grained_activation_offloading and weights[0].requires_grad:
                     weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]

         # Preprocess grad output

@@ -452,18 +457,15 @@ class _GroupedLinear(torch.autograd.Function):
                     ):
                         weight.grad_added_to_main_grad = True
                         if getattr(weight, "zero_out_wgrad", False):
-                            wgrad = torch.zeros(
-                                weight.main_grad.shape,
-                                dtype=weight.dtype,
-                                device=torch.cuda.current_device(),
-                                requires_grad=False,
+                            wgrad = get_dummy_wgrad(
+                                list(weight.main_grad.shape),
+                                weight.dtype,
+                                zero=True,
                             )
                         else:
-                            wgrad = torch.empty(
-                                weight.main_grad.shape,
-                                dtype=weight.dtype,
-                                device=torch.cuda.current_device(),
-                                requires_grad=False,
+                            wgrad = get_dummy_wgrad(
+                                list(weight.main_grad.shape),
+                                weight.dtype,
                             )
                     elif ctx.fuse_wgrad_accumulation:
                         wgrad = None

@@ -514,6 +516,7 @@ class _GroupedLinear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
             *wgrad_list,
             *grad_biases,
         )

@@ -595,6 +598,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
+        fine_grained_activation_offloading: bool = False,
         delay_wgrad_compute: bool = False,
         save_original_input: bool = False,
     ) -> None:

@@ -617,6 +621,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         ), "GroupedLinear doesn't support Userbuffer overlap."
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute)

@@ -836,6 +841,7 @@ class GroupedLinear(TransformerEngineBaseModule):
                 self,
                 skip_fp8_weight_update,
                 self.save_original_input,
+                self.fine_grained_activation_offloading,
                 *weight_tensors,
                 *bias_tensors,
             )
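The extra `None` added to the tuple returned from `backward` in this file (and the `None,  # fine_grained_activation_offloading` entries in the other three) mirrors the new positional argument in `forward`: `torch.autograd.Function.backward` must return exactly one value per `forward` input, and non-tensor flags receive `None`. A self-contained illustration of that rule in plain PyTorch (not TransformerEngine code):

    import torch

    class FlaggedLinear(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp, weight, some_flag):
            ctx.save_for_backward(inp, weight)
            ctx.some_flag = some_flag  # plain Python flag, no gradient
            return inp @ weight.t()

        @staticmethod
        def backward(ctx, grad_out):
            inp, weight = ctx.saved_tensors
            return (
                grad_out @ weight,   # grad w.r.t. inp
                grad_out.t() @ inp,  # grad w.r.t. weight
                None,                # grad w.r.t. some_flag: always None
            )

    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(3, 8, requires_grad=True)
    FlaggedLinear.apply(x, w, True).sum().backward()

Dropping the matching `None` after adding a `forward` argument makes PyTorch complain about an incorrect number of returned gradients, which is why every file in this commit updates the signature and the return tuple together.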
transformer_engine/pytorch/module/layernorm_linear.py

@@ -130,6 +130,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ub_bulk_wgrad: bool,
         ub_bulk_dgrad: bool,
         ub_name: str,
+        fine_grained_activation_offloading: bool,
         fsdp_group: Union[dist_group_type, None],
         module: torch.nn.Module,
         skip_fp8_weight_update: bool,

@@ -435,21 +436,25 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

         offload_activation = False
         if get_activation_offloading():
             offload_activation = True
             if not inputmat.is_contiguous():
                 inputmat = inputmat.contiguous()
             inputmat.offloading_activation = True
             # Do not offload weights and biases
             weight.offloading_activation = False
             weightmat.offloading_activation = False
             if bias is not None:
                 bias.offloading_activation = False
             ln_weight.offloading_activation = False
+        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
         ctx.offload_activation = offload_activation
-        if offload_activation and cpu_offloading:
+        if fine_grained_activation_offloading and cpu_offloading:
             raise ValueError(
-                f"Do not use offload_activation and cpu_offloading at the same time."
+                f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
             )
-        if offload_activation and weight.requires_grad and fuse_wgrad_accumulation:
+        if (
+            fine_grained_activation_offloading
+            and weight.requires_grad
+            and fuse_wgrad_accumulation
+        ):
             if hasattr(weight, "grad_added_to_main_grad"):
                 ctx.has_grad_added_to_main_grad = True
                 ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad

@@ -594,10 +599,10 @@ class _LayerNormLinear(torch.autograd.Function):
         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
         # we need to connect them into one.
-        if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
+        if ctx.cpu_offloading or ctx.fine_grained_activation_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
             if ctx.has_grad_added_to_main_grad:
                 origin_weight = ctx.weight_object
-                if ctx.offload_activation:
+                if ctx.fine_grained_activation_offloading:
                     origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
             if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                 origin_weight.main_grad = main_grad

@@ -1074,6 +1079,7 @@ class _LayerNormLinear(torch.autograd.Function):
             None,  # ub_bulk_dgrad
             None,  # ub_bulk_wgrad
             None,  # ub_name
+            None,  # fine_grained_activation_offloading
             None,  # fsdp_group
             None,  # debug
             None,  # module

@@ -1209,6 +1215,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         delay_wgrad_compute: bool = False,
         symmetric_ar_type: Optional[str] = None,
         name: str = None,
+        fine_grained_activation_offloading: bool = False,
     ) -> None:
         super().__init__()

@@ -1227,6 +1234,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         )
         self.zero_centered_gamma = zero_centered_gamma
         self.symmetric_ar_type = symmetric_ar_type
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
         self.name = name

@@ -1630,6 +1638,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.ub_bulk_wgrad,
                 self.ub_bulk_dgrad,
                 self.ub_name,
+                self.fine_grained_activation_offloading,
                 self.fsdp_group,
                 self,
                 skip_fp8_weight_update,
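Both this file and linear.py below keep the convention of tagging tensors with an `offloading_activation` attribute: the (contiguous) input activation is marked `True`, while the weight, its quantized copy, the LayerNorm weight, and the bias are explicitly marked `False` so the offloading hooks leave them on device. A minimal sketch of that convention; the helper name `tag_for_offload` is illustrative only, and the marker is just a dynamic Python attribute on the tensor object:

    import torch

    def tag_for_offload(inputmat: torch.Tensor, *params) -> torch.Tensor:
        # Mark the activation for offloading and explicitly exclude parameters.
        if not inputmat.is_contiguous():
            inputmat = inputmat.contiguous()
        inputmat.offloading_activation = True
        for p in params:
            if p is not None:
                p.offloading_activation = False
        return inputmat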
transformer_engine/pytorch/module/linear.py

@@ -111,6 +111,7 @@ class _Linear(torch.autograd.Function):
         ub_bulk_dgrad: bool,
         ub_bulk_wgrad: bool,
         ub_name: str,
+        fine_grained_activation_offloading: bool,
         fp8_output: bool,  # pylint: disable=unused-argument
         fsdp_group: Union[dist_group_type, None],
         module: torch.nn.Module,

@@ -404,25 +405,25 @@ class _Linear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

         offload_activation = False
         if get_activation_offloading():
             offload_activation = True
             if not saved_inputmat.is_contiguous():
                 saved_inputmat = saved_inputmat.contiguous()
             saved_inputmat.offloading_activation = True
         ctx.offload_activation = offload_activation
+        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
-        if offload_activation and cpu_offloading:
+        if fine_grained_activation_offloading and cpu_offloading:
             raise ValueError(
-                f"Do not use offload_activation and cpu_offloading at the same time."
+                f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
             )
-        if offload_activation and weight.requires_grad and fuse_wgrad_accumulation:
-            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
-            if ctx.has_grad_added_to_main_grad:
+        if (
+            fine_grained_activation_offloading
+            and weight.requires_grad
+            and fuse_wgrad_accumulation
+        ):
+            if hasattr(weight, "grad_added_to_main_grad"):
+                ctx.has_grad_added_to_main_grad = True
                 ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
-                ctx.weight_object = weight
                 weight.grad_added_to_main_grad = True
+                ctx.weight_object = weight
         else:
             ctx.has_grad_added_to_main_grad = False
             if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
                 ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")

@@ -435,6 +436,12 @@ class _Linear(torch.autograd.Function):
             # weights if weights are externally touched outside this module
             ctx.weight_object = weight

+        # Do not offload weights and biases
+        weight.offloading_activation = False
+        weightmat.offloading_activation = False
+        if bias is not None:
+            bias.offloading_activation = False
+
         # TODO(ksivamani): Check memory usage
         tensors_to_save, tensor_objects = prepare_for_saving(
             saved_inputmat,

@@ -522,10 +529,10 @@ class _Linear(torch.autograd.Function):
             else None
         )
-        if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
+        if ctx.cpu_offloading or ctx.fine_grained_activation_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
             if ctx.has_grad_added_to_main_grad:
                 weight = ctx.weight_object
-                if ctx.offload_activation:
+                if ctx.fine_grained_activation_offloading:
                     weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
             if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                 weight.main_grad = main_grad

@@ -1009,6 +1016,7 @@ class _Linear(torch.autograd.Function):
             None,  # ub_bulk_dgrad
             None,  # ub_bulk_wgrad
             None,  # ub_name
+            None,  # fine_grained_activation_offloading
             None,  # fp8_output
             None,  # fsdp_group
             None,  # module

@@ -1131,6 +1139,7 @@ class Linear(TransformerEngineBaseModule):
         symmetric_ar_type: Optional[str] = None,
         save_original_input: bool = False,
         name: Optional[str] = None,
+        fine_grained_activation_offloading: bool = False,
     ) -> None:
         super().__init__()

@@ -1146,6 +1155,7 @@ class Linear(TransformerEngineBaseModule):
         self.symmetric_ar_type = symmetric_ar_type
         self.save_original_input = save_original_input
         self.name = name
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

@@ -1493,6 +1503,7 @@ class Linear(TransformerEngineBaseModule):
                 self.ub_bulk_dgrad,
                 self.ub_bulk_wgrad,
                 self.ub_name,
+                self.fine_grained_activation_offloading,
                 fp8_output,
                 self.fsdp_group,
                 self,
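From the user's side, the commit exposes the feature as a constructor flag on `Linear`, `LayerNormLinear`, `GroupedLinear`, and `BatchedLinear`. A hedged usage sketch, assuming the modules' usual `in_features`/`out_features` arguments; only the new keyword itself is confirmed by the diff, and combining it with CPU offloading raises the `ValueError` shown above:

    import torch
    import transformer_engine.pytorch as te

    # Enable fine-grained activation offloading for a single layer.
    layer = te.Linear(
        in_features=1024,
        out_features=1024,
        fine_grained_activation_offloading=True,  # must not be combined with cpu_offloading
    )

    out = layer(torch.randn(8, 1024, device="cuda"))
    out.sum().backward()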