OpenDAS / TransformerEngine · Commits

Commit 9711d439, authored Oct 17, 2025 by dongcl
Parent: 712d526a

Update activation offload code to align with the official version
Showing 4 changed files with 105 additions and 70 deletions (+105 -70):

    transformer_engine/pytorch/module/batched_linear.py    +31 -22
    transformer_engine/pytorch/module/grouped_linear.py    +27 -21
    transformer_engine/pytorch/module/layernorm_linear.py  +21 -12
    transformer_engine/pytorch/module/linear.py             +26 -15
File: transformer_engine/pytorch/module/batched_linear.py  (view file @ 9711d439)
@@ -13,6 +13,7 @@ import transformer_engine_torch as tex
 from .base import (
     get_multi_stream_cublas_batchgemm_workspace,
+    get_dummy_wgrad,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
@@ -94,6 +95,7 @@ class _BatchLinear(torch.autograd.Function):
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         is_grad_enabled: bool,
+        fine_grained_activation_offloading,
         *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
     ) -> torch.Tensor:
         batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
@@ -158,19 +160,24 @@ class _BatchLinear(torch.autograd.Function):
                 if t is not None:
                     t.activation_offloading = True
         offload_activation = False
         if hasattr(inp, "offloading_activation"):
             offload_activation = True
             for i in range(num_gemms):
                 saved_inputmats[i].offloading_activation = inp.offloading_activation
                 weights[i].offloading_activation = False
                 weights[i].main_grad.offloading_activation = False
                 if weights_fp8[i] is not None:
                     weights_fp8[i].offloading_activation = False
         ctx.offload_activation = offload_activation
+        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
-        if offload_activation and cpu_offloading:
+        if fine_grained_activation_offloading and cpu_offloading:
             raise ValueError(
-                f"Do not use offload_activation and cpu_offloading at the same time."
+                f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
             )
-        if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation:
+        if (
+            fine_grained_activation_offloading
+            and weights[0].requires_grad
+            and fuse_wgrad_accumulation
+        ):
             grad_added_to_main_grad_list = []
             for weight in weights:
                 if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
@@ -187,7 +194,7 @@ class _BatchLinear(torch.autograd.Function):
             *weights,
             *weights_fp8,
             *[
-                w.main_grad if (cpu_offloading or offload_activation) and fuse_wgrad_accumulation else None
+                w.main_grad if (cpu_offloading or fine_grained_activation_offloading) and fuse_wgrad_accumulation else None
                 for w in weights
             ],
         )
@@ -226,12 +233,12 @@ class _BatchLinear(torch.autograd.Function):
         weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms]
         weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
         main_grads = saved_tensors[4 * ctx.num_gemms :]
-        if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
+        if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
                 w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                 w.main_grad = main_grads[i]
                 weights[i] = w
-                if ctx.offload_activation and weights[i].requires_grad:
+                if ctx.fine_grained_activation_offloading and weights[i].requires_grad:
                     weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
         global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
@@ -304,18 +311,15 @@ class _BatchLinear(torch.autograd.Function):
                 if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
                     w.grad_added_to_main_grad = True
                     if getattr(w, "zero_out_wgrad", False):
-                        wgrad = torch.zeros(
-                            w.main_grad.shape,
-                            dtype=w.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(w.main_grad.shape),
+                            w.dtype,
+                            zero=True,
                         )
                     else:
-                        wgrad = torch.empty(
-                            w.main_grad.shape,
-                            dtype=w.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(w.main_grad.shape),
+                            w.dtype,
                         )
                 elif ctx.fuse_wgrad_accumulation:
                     wgrad = None
@@ -367,6 +371,7 @@ class _BatchLinear(torch.autograd.Function):
             None,  # activation_dtype
             None,  # parallel_mode
             None,  # is_grad_enabled
+            None,  # fine_grained_activation_offloading
             *wgrad_list,
             *([None] * ctx.num_gemms),  # weights_fp8
             *grad_biases,
@@ -457,6 +462,7 @@ class BatchedLinear(TransformerEngineBaseModule):
         device: Union[torch.device, str] = "cuda",
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
+        fine_grained_activation_offloading: bool = False,
         ub_name: Optional[str] = None,
         delay_wgrad_compute: bool = False,
     ) -> None:
@@ -480,6 +486,8 @@ class BatchedLinear(TransformerEngineBaseModule):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute)
         global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
@@ -657,6 +665,7 @@ class BatchedLinear(TransformerEngineBaseModule):
             self.activation_dtype,
             self.parallel_mode,
             torch.is_grad_enabled(),
+            self.fine_grained_activation_offloading,
             *weight_tensors,
             *weight_tensors_fp8,
             *bias_tensors,
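Note on the wgrad change: across all four files the hand-rolled torch.zeros / torch.empty placeholder gradients in backward are replaced by calls to get_dummy_wgrad imported from .base. The snippet below is only an illustration of what such a helper plausibly returns, inferred from the call sites in this commit (a shape list, a dtype, and an optional zero flag); the real helper in .base may cache or reuse the buffer.

    import torch

    def dummy_wgrad_sketch(shape, dtype, zero=False):
        # Hypothetical stand-in for get_dummy_wgrad(), based only on the call
        # sites visible in this diff: return a throwaway wgrad buffer for the
        # case where the real gradient was already fused into weight.main_grad.
        device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
        if zero:
            # zero_out_wgrad case: the caller expects an all-zero buffer
            return torch.zeros(shape, dtype=dtype, device=device, requires_grad=False)
        # Otherwise the contents are never read, so an uninitialized buffer suffices
        return torch.empty(shape, dtype=dtype, device=device, requires_grad=False)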
File: transformer_engine/pytorch/module/grouped_linear.py  (view file @ 9711d439)
@@ -15,6 +15,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from .base import (
     get_multi_stream_cublas_workspace,
+    get_dummy_wgrad,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
@@ -82,6 +83,7 @@ class _GroupedLinear(torch.autograd.Function):
         module,
         skip_fp8_weight_update,
         save_original_input,
+        fine_grained_activation_offloading,
         *weights_and_biases,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
@@ -211,19 +213,22 @@ class _GroupedLinear(torch.autograd.Function):
                 if isinstance(weight, QuantizedTensorBase):
                     weight.update_usage(columnwise_usage=True)
         offload_activation = False
         if hasattr(inp, "offloading_activation"):
             offload_activation = True
             for i in range(num_gemms):
                 inputmats[i].offloading_activation = inp.offloading_activation
                 weights[i].offloading_activation = False
                 weights_fp8[i].offloading_activation = False
                 biases[i].offloading_activation = False
         ctx.offload_activation = offload_activation
+        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
-        if offload_activation and cpu_offloading:
+        if fine_grained_activation_offloading and cpu_offloading:
             raise ValueError(
-                f"Do not use offload_activation and cpu_offloading at the same time."
+                f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
             )
-        if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation:
+        if (
+            fine_grained_activation_offloading
+            and weights[0].requires_grad
+            and fuse_wgrad_accumulation
+        ):
             grad_added_to_main_grad_list = []
             for weight in weights:
                 if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
@@ -295,12 +300,12 @@ class _GroupedLinear(torch.autograd.Function):
         biases = saved_tensors[3 * N : 4 * N]
         main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
-        if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
+        if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
                 w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                 w.main_grad = main_grads[i]
                 weights[i] = w
-                if ctx.offload_activation and weights[0].requires_grad:
+                if ctx.fine_grained_activation_offloading and weights[0].requires_grad:
                     weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
         # Preprocess grad output
@@ -452,18 +457,15 @@ class _GroupedLinear(torch.autograd.Function):
                 ):
                     weight.grad_added_to_main_grad = True
                     if getattr(weight, "zero_out_wgrad", False):
-                        wgrad = torch.zeros(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
+                            zero=True,
                         )
                     else:
-                        wgrad = torch.empty(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
                         )
                 elif ctx.fuse_wgrad_accumulation:
                     wgrad = None
@@ -514,6 +516,7 @@ class _GroupedLinear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
             *wgrad_list,
             *grad_biases,
         )
@@ -595,6 +598,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
+        fine_grained_activation_offloading: bool = False,
         delay_wgrad_compute: bool = False,
         save_original_input: bool = False,
     ) -> None:
@@ -617,6 +621,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         ), "GroupedLinear doesn't support Userbuffer overlap."
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute)
@@ -836,6 +841,7 @@ class GroupedLinear(TransformerEngineBaseModule):
             self,
             skip_fp8_weight_update,
             self.save_original_input,
+            self.fine_grained_activation_offloading,
             *weight_tensors,
             *bias_tensors,
         )
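Note on the tensor tagging: as in batched_linear.py, the forward pass tags the tensors it saves, giving input matrices offloading_activation = True while weights, FP8 weights, and biases are explicitly marked False so only activations are offloaded. A self-contained sketch of that tagging pattern follows; it is illustrative only and not the Transformer Engine API, with the attribute name taken from this diff.

    import torch

    def tag_for_fine_grained_offload(activations, params):
        # Mark saved activations as candidates for host offload and make sure
        # parameters are never offloaded (mirrors the flags set in this commit).
        for t in activations:
            if t is not None:
                t.offloading_activation = True
        for p in params:
            if p is not None:
                p.offloading_activation = False

    inp = torch.randn(8, 16)
    weight = torch.nn.Parameter(torch.randn(16, 16))
    tag_for_fine_grained_offload([inp], [weight])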
File: transformer_engine/pytorch/module/layernorm_linear.py  (view file @ 9711d439)
@@ -130,6 +130,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ub_bulk_wgrad: bool,
         ub_bulk_dgrad: bool,
         ub_name: str,
+        fine_grained_activation_offloading: bool,
         fsdp_group: Union[dist_group_type, None],
         module: torch.nn.Module,
         skip_fp8_weight_update: bool,
@@ -435,21 +436,25 @@ class _LayerNormLinear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
         offload_activation = False
         if get_activation_offloading():
             offload_activation = True
             if not inputmat.is_contiguous():
                 inputmat = inputmat.contiguous()
             inputmat.offloading_activation = True
             # Do not offload weights and biases
             weight.offloading_activation = False
             weightmat.offloading_activation = False
             if bias is not None:
                 bias.offloading_activation = False
             ln_weight.offloading_activation = False
+        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
         ctx.offload_activation = offload_activation
-        if offload_activation and cpu_offloading:
+        if fine_grained_activation_offloading and cpu_offloading:
             raise ValueError(
-                f"Do not use offload_activation and cpu_offloading at the same time."
+                f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
             )
-        if offload_activation and weight.requires_grad and fuse_wgrad_accumulation:
+        if (
+            fine_grained_activation_offloading
+            and weight.requires_grad
+            and fuse_wgrad_accumulation
+        ):
             if hasattr(weight, "grad_added_to_main_grad"):
                 ctx.has_grad_added_to_main_grad = True
                 ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
@@ -594,10 +599,10 @@ class _LayerNormLinear(torch.autograd.Function):
         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
         # we need to connect them into one.
-        if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
+        if ctx.cpu_offloading or ctx.fine_grained_activation_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
             if ctx.has_grad_added_to_main_grad:
                 origin_weight = ctx.weight_object
-                if ctx.offload_activation:
+                if ctx.fine_grained_activation_offloading:
                     origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
             if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                 origin_weight.main_grad = main_grad
@@ -1074,6 +1079,7 @@ class _LayerNormLinear(torch.autograd.Function):
             None,  # ub_bulk_dgrad
             None,  # ub_bulk_wgrad
             None,  # ub_name
+            None,  # fine_grained_activation_offloading
             None,  # fsdp_group
             None,  # debug
             None,  # module
@@ -1209,6 +1215,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         delay_wgrad_compute: bool = False,
         symmetric_ar_type: Optional[str] = None,
         name: str = None,
+        fine_grained_activation_offloading: bool = False,
     ) -> None:
         super().__init__()
@@ -1227,6 +1234,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         )
         self.zero_centered_gamma = zero_centered_gamma
         self.symmetric_ar_type = symmetric_ar_type
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
         self.name = name
@@ -1630,6 +1638,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.ub_bulk_wgrad,
             self.ub_bulk_dgrad,
             self.ub_name,
+            self.fine_grained_activation_offloading,
             self.fsdp_group,
             self,
             skip_fp8_weight_update,
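Note on usage: the new constructor argument defaults to False in every module, so existing callers are unaffected. A hedged usage sketch follows; the standard Transformer Engine constructor arguments are assumed, and only fine_grained_activation_offloading comes from this commit (per the checks above it must not be combined with CPU offloading).

    import torch
    import transformer_engine.pytorch as te

    layer = te.LayerNormLinear(
        in_features=1024,
        out_features=1024,
        bias=True,
        fine_grained_activation_offloading=True,  # flag added in this commit
    )
    x = torch.randn(8, 1024, device="cuda")
    y = layer(x)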
File: transformer_engine/pytorch/module/linear.py  (view file @ 9711d439)
@@ -111,6 +111,7 @@ class _Linear(torch.autograd.Function):
         ub_bulk_dgrad: bool,
         ub_bulk_wgrad: bool,
         ub_name: str,
+        fine_grained_activation_offloading: bool,
         fp8_output: bool,  # pylint: disable=unused-argument
         fsdp_group: Union[dist_group_type, None],
         module: torch.nn.Module,
@@ -404,25 +405,25 @@ class _Linear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
         offload_activation = False
         if get_activation_offloading():
             offload_activation = True
             if not saved_inputmat.is_contiguous():
                 saved_inputmat = saved_inputmat.contiguous()
             saved_inputmat.offloading_activation = True
         ctx.offload_activation = offload_activation
+        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
-        if offload_activation and cpu_offloading:
+        if fine_grained_activation_offloading and cpu_offloading:
             raise ValueError(
-                f"Do not use offload_activation and cpu_offloading at the same time."
+                f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
             )
-        if offload_activation and weight.requires_grad and fuse_wgrad_accumulation:
-            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
-            if ctx.has_grad_added_to_main_grad:
+        if (
+            fine_grained_activation_offloading
+            and weight.requires_grad
+            and fuse_wgrad_accumulation
+        ):
+            if hasattr(weight, "grad_added_to_main_grad"):
+                ctx.has_grad_added_to_main_grad = True
                 ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
-                ctx.weight_object = weight
+                weight.grad_added_to_main_grad = True
+                ctx.weight_object = weight
             else:
                 ctx.has_grad_added_to_main_grad = False
         if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
             ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
@@ -435,6 +436,12 @@ class _Linear(torch.autograd.Function):
             # weights if weights are externally touched outside this module
             ctx.weight_object = weight
+        # Do not offload weights and biases
+        weight.offloading_activation = False
+        weightmat.offloading_activation = False
+        if bias is not None:
+            bias.offloading_activation = False
+
         # TODO(ksivamani): Check memory usage
         tensors_to_save, tensor_objects = prepare_for_saving(
             saved_inputmat,
@@ -522,10 +529,10 @@ class _Linear(torch.autograd.Function):
             else None
         )
-        if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
+        if ctx.cpu_offloading or ctx.fine_grained_activation_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
             if ctx.has_grad_added_to_main_grad:
                 weight = ctx.weight_object
-                if ctx.offload_activation:
+                if ctx.fine_grained_activation_offloading:
                     weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
             if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                 weight.main_grad = main_grad
@@ -1009,6 +1016,7 @@ class _Linear(torch.autograd.Function):
             None,  # ub_bulk_dgrad
             None,  # ub_bulk_wgrad
             None,  # ub_name
+            None,  # fine_grained_activation_offloading
             None,  # fp8_output
             None,  # fsdp_group
             None,  # module
@@ -1131,6 +1139,7 @@ class Linear(TransformerEngineBaseModule):
symmetric_ar_type
:
Optional
[
str
]
=
None
,
save_original_input
:
bool
=
False
,
name
:
Optional
[
str
]
=
None
,
fine_grained_activation_offloading
:
bool
=
False
,
)
->
None
:
super
().
__init__
()
...
@@ -1146,6 +1155,7 @@ class Linear(TransformerEngineBaseModule):
         self.symmetric_ar_type = symmetric_ar_type
         self.save_original_input = save_original_input
         self.name = name
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
@@ -1493,6 +1503,7 @@ class Linear(TransformerEngineBaseModule):
self
.
ub_bulk_dgrad
,
self
.
ub_bulk_wgrad
,
self
.
ub_name
,
self
.
fine_grained_activation_offloading
,
fp8_output
,
self
.
fsdp_group
,
self
,
...
...
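Note on the ctx bookkeeping: in linear.py and layernorm_linear.py the forward now records weight.grad_added_to_main_grad on the autograd ctx when fine-grained offloading is active, and the backward re-attaches it to the original weight object. Below is a reduced, hypothetical autograd.Function showing only that bookkeeping, not the actual _Linear implementation.

    import torch

    class OffloadAwareLinearFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp, weight, fine_grained_activation_offloading):
            ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
            ctx.has_grad_added_to_main_grad = False
            if fine_grained_activation_offloading and hasattr(weight, "grad_added_to_main_grad"):
                # Stash the flag and the original parameter object before the
                # saved activation is offloaded.
                ctx.has_grad_added_to_main_grad = True
                ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
                ctx.weight_object = weight
            ctx.save_for_backward(inp, weight)
            return inp @ weight.t()

        @staticmethod
        def backward(ctx, grad_out):
            inp, weight = ctx.saved_tensors
            if ctx.fine_grained_activation_offloading and ctx.has_grad_added_to_main_grad:
                # Restore the flag onto the original parameter object
                ctx.weight_object.grad_added_to_main_grad = ctx.grad_added_to_main_grad
            return grad_out @ weight, grad_out.t() @ inp, None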