Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
9711d439
Commit
9711d439
authored
Oct 17, 2025
by
dongcl
Browse files
Update activation offload code to align with the official version
parent
712d526a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
105 additions
and
70 deletions
+105
-70
transformer_engine/pytorch/module/batched_linear.py
transformer_engine/pytorch/module/batched_linear.py
+31
-22
transformer_engine/pytorch/module/grouped_linear.py
transformer_engine/pytorch/module/grouped_linear.py
+27
-21
transformer_engine/pytorch/module/layernorm_linear.py
transformer_engine/pytorch/module/layernorm_linear.py
+21
-12
transformer_engine/pytorch/module/linear.py
transformer_engine/pytorch/module/linear.py
+26
-15
No files found.
transformer_engine/pytorch/module/batched_linear.py
View file @
9711d439
...
@@ -13,6 +13,7 @@ import transformer_engine_torch as tex
...
@@ -13,6 +13,7 @@ import transformer_engine_torch as tex
from
.base
import
(
from
.base
import
(
get_multi_stream_cublas_batchgemm_workspace
,
get_multi_stream_cublas_batchgemm_workspace
,
get_dummy_wgrad
,
TransformerEngineBaseModule
,
TransformerEngineBaseModule
,
_2X_ACC_FPROP
,
_2X_ACC_FPROP
,
_2X_ACC_DGRAD
,
_2X_ACC_DGRAD
,
...
@@ -94,6 +95,7 @@ class _BatchLinear(torch.autograd.Function):
...
@@ -94,6 +95,7 @@ class _BatchLinear(torch.autograd.Function):
activation_dtype
:
torch
.
dtype
,
activation_dtype
:
torch
.
dtype
,
parallel_mode
:
Union
[
str
,
None
],
parallel_mode
:
Union
[
str
,
None
],
is_grad_enabled
:
bool
,
is_grad_enabled
:
bool
,
fine_grained_activation_offloading
,
*
weights_and_biases
:
Union
[
Float8Tensor
,
torch
.
Tensor
,
None
],
*
weights_and_biases
:
Union
[
Float8Tensor
,
torch
.
Tensor
,
None
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
batch_num
=
int
(
os
.
getenv
(
"NVTE_MOE_BATCHCOUNT"
,
"2"
))
batch_num
=
int
(
os
.
getenv
(
"NVTE_MOE_BATCHCOUNT"
,
"2"
))
...
@@ -158,19 +160,24 @@ class _BatchLinear(torch.autograd.Function):
...
@@ -158,19 +160,24 @@ class _BatchLinear(torch.autograd.Function):
if
t
is
not
None
:
if
t
is
not
None
:
t
.
activation_offloading
=
True
t
.
activation_offloading
=
True
offload_activation
=
False
for
i
in
range
(
num_gemms
):
if
hasattr
(
inp
,
"offloading_activation"
):
weights
[
i
].
offloading_activation
=
False
offload_activation
=
True
weights
[
i
].
main_grad
.
offloading_activation
=
False
for
i
in
range
(
num_gemms
):
if
weights_fp8
[
i
]
is
not
None
:
saved_inputmats
[
i
].
offloading_activation
=
inp
.
offloading_activation
weights_fp8
[
i
].
offloading_activation
=
False
ctx
.
offload_activation
=
offload_activation
if
offload_activation
and
cpu_offloading
:
ctx
.
fine_grained_activation_offloading
=
fine_grained_activation_offloading
if
fine_grained_activation_offloading
and
cpu_offloading
:
raise
ValueError
(
raise
ValueError
(
f
"Do not use
offloa
d_activation and cpu_offloading at the same time."
f
"Do not use
fine_graine
d_activation
_offloading
and cpu_offloading at the same time."
)
)
if
offload_activation
and
weights
[
0
].
requires_grad
and
fuse_wgrad_accumulation
:
if
(
fine_grained_activation_offloading
and
weights
[
0
].
requires_grad
and
fuse_wgrad_accumulation
):
grad_added_to_main_grad_list
=
[]
grad_added_to_main_grad_list
=
[]
for
weight
in
weights
:
for
weight
in
weights
:
if
weight
.
requires_grad
and
hasattr
(
weight
,
"grad_added_to_main_grad"
):
if
weight
.
requires_grad
and
hasattr
(
weight
,
"grad_added_to_main_grad"
):
...
@@ -187,7 +194,7 @@ class _BatchLinear(torch.autograd.Function):
...
@@ -187,7 +194,7 @@ class _BatchLinear(torch.autograd.Function):
*
weights
,
*
weights
,
*
weights_fp8
,
*
weights_fp8
,
*
[
*
[
w
.
main_grad
if
(
cpu_offloading
or
offloa
d_activation
)
and
fuse_wgrad_accumulation
else
None
w
.
main_grad
if
(
cpu_offloading
or
fine_graine
d_activation
_offloading
)
and
fuse_wgrad_accumulation
else
None
for
w
in
weights
for
w
in
weights
],
],
)
)
...
@@ -226,12 +233,12 @@ class _BatchLinear(torch.autograd.Function):
...
@@ -226,12 +233,12 @@ class _BatchLinear(torch.autograd.Function):
weights
=
saved_tensors
[
2
*
ctx
.
num_gemms
:
3
*
ctx
.
num_gemms
]
weights
=
saved_tensors
[
2
*
ctx
.
num_gemms
:
3
*
ctx
.
num_gemms
]
weights_fp8
=
saved_tensors
[
3
*
ctx
.
num_gemms
:
4
*
ctx
.
num_gemms
]
weights_fp8
=
saved_tensors
[
3
*
ctx
.
num_gemms
:
4
*
ctx
.
num_gemms
]
main_grads
=
saved_tensors
[
4
*
ctx
.
num_gemms
:]
main_grads
=
saved_tensors
[
4
*
ctx
.
num_gemms
:]
if
(
ctx
.
cpu_offloading
or
ctx
.
offloa
d_activation
)
and
ctx
.
fuse_wgrad_accumulation
:
if
(
ctx
.
cpu_offloading
or
ctx
.
fine_graine
d_activation
_offloading
)
and
ctx
.
fuse_wgrad_accumulation
:
for
i
in
range
(
ctx
.
num_gemms
):
for
i
in
range
(
ctx
.
num_gemms
):
w
=
torch
.
nn
.
Parameter
(
weights
[
i
],
weights
[
i
].
requires_grad
)
w
=
torch
.
nn
.
Parameter
(
weights
[
i
],
weights
[
i
].
requires_grad
)
w
.
main_grad
=
main_grads
[
i
]
w
.
main_grad
=
main_grads
[
i
]
weights
[
i
]
=
w
weights
[
i
]
=
w
if
ctx
.
offloa
d_activation
and
weights
[
i
].
requires_grad
:
if
ctx
.
fine_graine
d_activation
_offloading
and
weights
[
i
].
requires_grad
:
weights
[
i
].
grad_added_to_main_grad
=
ctx
.
grad_added_to_main_grad_list
[
i
]
weights
[
i
].
grad_added_to_main_grad
=
ctx
.
grad_added_to_main_grad_list
[
i
]
global
_GEMM_INPUT
,
_GEMM_WEIGHT
,
_GRAD_OUTPUT
global
_GEMM_INPUT
,
_GEMM_WEIGHT
,
_GRAD_OUTPUT
...
@@ -304,18 +311,15 @@ class _BatchLinear(torch.autograd.Function):
...
@@ -304,18 +311,15 @@ class _BatchLinear(torch.autograd.Function):
if
ctx
.
fuse_wgrad_accumulation
and
hasattr
(
w
,
"grad_added_to_main_grad"
):
if
ctx
.
fuse_wgrad_accumulation
and
hasattr
(
w
,
"grad_added_to_main_grad"
):
w
.
grad_added_to_main_grad
=
True
w
.
grad_added_to_main_grad
=
True
if
getattr
(
w
,
"zero_out_wgrad"
,
False
):
if
getattr
(
w
,
"zero_out_wgrad"
,
False
):
wgrad
=
torch
.
zeros
(
wgrad
=
get_dummy_wgrad
(
w
.
main_grad
.
shape
,
list
(
w
.
main_grad
.
shape
),
dtype
=
w
.
dtype
,
w
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
zero
=
True
,
requires_grad
=
False
,
)
)
else
:
else
:
wgrad
=
torch
.
empty
(
wgrad
=
get_dummy_wgrad
(
w
.
main_grad
.
shape
,
list
(
w
.
main_grad
.
shape
),
dtype
=
w
.
dtype
,
w
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
)
)
elif
ctx
.
fuse_wgrad_accumulation
:
elif
ctx
.
fuse_wgrad_accumulation
:
wgrad
=
None
wgrad
=
None
...
@@ -367,6 +371,7 @@ class _BatchLinear(torch.autograd.Function):
...
@@ -367,6 +371,7 @@ class _BatchLinear(torch.autograd.Function):
None
,
# activation_dtype
None
,
# activation_dtype
None
,
# parallel_mode
None
,
# parallel_mode
None
,
# is_grad_enabled
None
,
# is_grad_enabled
None
,
# fine_grained_activation_offloading
*
wgrad_list
,
*
wgrad_list
,
*
([
None
]
*
ctx
.
num_gemms
),
# weights_fp8
*
([
None
]
*
ctx
.
num_gemms
),
# weights_fp8
*
grad_biases
,
*
grad_biases
,
...
@@ -457,6 +462,7 @@ class BatchedLinear(TransformerEngineBaseModule):
...
@@ -457,6 +462,7 @@ class BatchedLinear(TransformerEngineBaseModule):
device
:
Union
[
torch
.
device
,
str
]
=
"cuda"
,
device
:
Union
[
torch
.
device
,
str
]
=
"cuda"
,
ub_overlap_rs
:
bool
=
False
,
ub_overlap_rs
:
bool
=
False
,
ub_overlap_ag
:
bool
=
False
,
ub_overlap_ag
:
bool
=
False
,
fine_grained_activation_offloading
:
bool
=
False
,
ub_name
:
Optional
[
str
]
=
None
,
ub_name
:
Optional
[
str
]
=
None
,
delay_wgrad_compute
:
bool
=
False
,
delay_wgrad_compute
:
bool
=
False
,
)
->
None
:
)
->
None
:
...
@@ -480,6 +486,8 @@ class BatchedLinear(TransformerEngineBaseModule):
...
@@ -480,6 +486,8 @@ class BatchedLinear(TransformerEngineBaseModule):
self
.
get_rng_state_tracker
=
get_rng_state_tracker
self
.
get_rng_state_tracker
=
get_rng_state_tracker
self
.
rng_tracker_name
=
rng_tracker_name
self
.
rng_tracker_name
=
rng_tracker_name
self
.
fine_grained_activation_offloading
=
fine_grained_activation_offloading
self
.
wgrad_store
=
WeightGradStore
(
delay_wgrad_compute
)
self
.
wgrad_store
=
WeightGradStore
(
delay_wgrad_compute
)
global
_GEMM_INPUT
,
_GEMM_WEIGHT
,
_GEMM_OUTPUT
global
_GEMM_INPUT
,
_GEMM_WEIGHT
,
_GEMM_OUTPUT
...
@@ -657,6 +665,7 @@ class BatchedLinear(TransformerEngineBaseModule):
...
@@ -657,6 +665,7 @@ class BatchedLinear(TransformerEngineBaseModule):
self
.
activation_dtype
,
self
.
activation_dtype
,
self
.
parallel_mode
,
self
.
parallel_mode
,
torch
.
is_grad_enabled
(),
torch
.
is_grad_enabled
(),
self
.
fine_grained_activation_offloading
,
*
weight_tensors
,
*
weight_tensors
,
*
weight_tensors_fp8
,
*
weight_tensors_fp8
,
*
bias_tensors
,
*
bias_tensors
,
...
...
transformer_engine/pytorch/module/grouped_linear.py
View file @
9711d439
...
@@ -15,6 +15,7 @@ import transformer_engine_torch as tex
...
@@ -15,6 +15,7 @@ import transformer_engine_torch as tex
from
transformer_engine.common.recipe
import
Recipe
from
transformer_engine.common.recipe
import
Recipe
from
.base
import
(
from
.base
import
(
get_multi_stream_cublas_workspace
,
get_multi_stream_cublas_workspace
,
get_dummy_wgrad
,
TransformerEngineBaseModule
,
TransformerEngineBaseModule
,
_2X_ACC_FPROP
,
_2X_ACC_FPROP
,
_2X_ACC_DGRAD
,
_2X_ACC_DGRAD
,
...
@@ -82,6 +83,7 @@ class _GroupedLinear(torch.autograd.Function):
...
@@ -82,6 +83,7 @@ class _GroupedLinear(torch.autograd.Function):
module
,
module
,
skip_fp8_weight_update
,
skip_fp8_weight_update
,
save_original_input
,
save_original_input
,
fine_grained_activation_offloading
,
*
weights_and_biases
,
*
weights_and_biases
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# pylint: disable=missing-function-docstring
# pylint: disable=missing-function-docstring
...
@@ -211,19 +213,22 @@ class _GroupedLinear(torch.autograd.Function):
...
@@ -211,19 +213,22 @@ class _GroupedLinear(torch.autograd.Function):
if
isinstance
(
weight
,
QuantizedTensorBase
):
if
isinstance
(
weight
,
QuantizedTensorBase
):
weight
.
update_usage
(
columnwise_usage
=
True
)
weight
.
update_usage
(
columnwise_usage
=
True
)
offload_activation
=
False
for
i
in
range
(
num_gemms
):
if
hasattr
(
inp
,
"offloading_activation"
):
weights
[
i
].
offloading_activation
=
False
offload_activation
=
True
weights_fp8
[
i
].
offloading_activation
=
False
for
i
in
range
(
num_gemms
):
biases
[
i
].
offloading_activation
=
False
inputmats
[
i
].
offloading_activation
=
inp
.
offloading_activation
ctx
.
fine_grained_activation_offloading
=
fine_grained_activation_offloading
ctx
.
offload_activation
=
offload_activation
if
offloa
d_activation
and
cpu_offloading
:
if
fine_graine
d_activation
_offloading
and
cpu_offloading
:
raise
ValueError
(
raise
ValueError
(
f
"Do not use
offloa
d_activation and cpu_offloading at the same time."
f
"Do not use
fine_graine
d_activation
_offloading
and cpu_offloading at the same time."
)
)
if
offload_activation
and
weights
[
0
].
requires_grad
and
fuse_wgrad_accumulation
:
if
(
fine_grained_activation_offloading
and
weights
[
0
].
requires_grad
and
fuse_wgrad_accumulation
):
grad_added_to_main_grad_list
=
[]
grad_added_to_main_grad_list
=
[]
for
weight
in
weights
:
for
weight
in
weights
:
if
weight
.
requires_grad
and
hasattr
(
weight
,
"grad_added_to_main_grad"
):
if
weight
.
requires_grad
and
hasattr
(
weight
,
"grad_added_to_main_grad"
):
...
@@ -295,12 +300,12 @@ class _GroupedLinear(torch.autograd.Function):
...
@@ -295,12 +300,12 @@ class _GroupedLinear(torch.autograd.Function):
biases
=
saved_tensors
[
3
*
N
:
4
*
N
]
biases
=
saved_tensors
[
3
*
N
:
4
*
N
]
main_grads
=
[
main_grad_func
()
for
main_grad_func
in
ctx
.
main_grad_funcs
]
main_grads
=
[
main_grad_func
()
for
main_grad_func
in
ctx
.
main_grad_funcs
]
if
(
ctx
.
cpu_offloading
or
ctx
.
offloa
d_activation
)
and
ctx
.
fuse_wgrad_accumulation
:
if
(
ctx
.
cpu_offloading
or
ctx
.
fine_graine
d_activation
_offloading
)
and
ctx
.
fuse_wgrad_accumulation
:
for
i
in
range
(
ctx
.
num_gemms
):
for
i
in
range
(
ctx
.
num_gemms
):
w
=
torch
.
nn
.
Parameter
(
weights
[
i
],
weights
[
i
].
requires_grad
)
w
=
torch
.
nn
.
Parameter
(
weights
[
i
],
weights
[
i
].
requires_grad
)
w
.
main_grad
=
main_grads
[
i
]
w
.
main_grad
=
main_grads
[
i
]
weights
[
i
]
=
w
weights
[
i
]
=
w
if
ctx
.
offloa
d_activation
and
weights
[
0
].
requires_grad
:
if
ctx
.
fine_graine
d_activation
_offloading
and
weights
[
0
].
requires_grad
:
weights
[
i
].
grad_added_to_main_grad
=
ctx
.
grad_added_to_main_grad_list
[
i
]
weights
[
i
].
grad_added_to_main_grad
=
ctx
.
grad_added_to_main_grad_list
[
i
]
# Preprocess grad output
# Preprocess grad output
...
@@ -452,18 +457,15 @@ class _GroupedLinear(torch.autograd.Function):
...
@@ -452,18 +457,15 @@ class _GroupedLinear(torch.autograd.Function):
):
):
weight
.
grad_added_to_main_grad
=
True
weight
.
grad_added_to_main_grad
=
True
if
getattr
(
weight
,
"zero_out_wgrad"
,
False
):
if
getattr
(
weight
,
"zero_out_wgrad"
,
False
):
wgrad
=
torch
.
zeros
(
wgrad
=
get_dummy_wgrad
(
weight
.
main_grad
.
shape
,
list
(
weight
.
main_grad
.
shape
),
dtype
=
weight
.
dtype
,
weight
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
zero
=
True
,
requires_grad
=
False
,
)
)
else
:
else
:
wgrad
=
torch
.
empty
(
wgrad
=
get_dummy_wgrad
(
weight
.
main_grad
.
shape
,
list
(
weight
.
main_grad
.
shape
),
dtype
=
weight
.
dtype
,
weight
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
)
)
elif
ctx
.
fuse_wgrad_accumulation
:
elif
ctx
.
fuse_wgrad_accumulation
:
wgrad
=
None
wgrad
=
None
...
@@ -514,6 +516,7 @@ class _GroupedLinear(torch.autograd.Function):
...
@@ -514,6 +516,7 @@ class _GroupedLinear(torch.autograd.Function):
None
,
None
,
None
,
None
,
None
,
None
,
None
,
*
wgrad_list
,
*
wgrad_list
,
*
grad_biases
,
*
grad_biases
,
)
)
...
@@ -595,6 +598,7 @@ class GroupedLinear(TransformerEngineBaseModule):
...
@@ -595,6 +598,7 @@ class GroupedLinear(TransformerEngineBaseModule):
ub_overlap_rs
:
bool
=
False
,
ub_overlap_rs
:
bool
=
False
,
ub_overlap_ag
:
bool
=
False
,
ub_overlap_ag
:
bool
=
False
,
ub_name
:
Optional
[
str
]
=
None
,
ub_name
:
Optional
[
str
]
=
None
,
fine_grained_activation_offloading
:
bool
=
False
,
delay_wgrad_compute
:
bool
=
False
,
delay_wgrad_compute
:
bool
=
False
,
save_original_input
:
bool
=
False
,
save_original_input
:
bool
=
False
,
)
->
None
:
)
->
None
:
...
@@ -617,6 +621,7 @@ class GroupedLinear(TransformerEngineBaseModule):
...
@@ -617,6 +621,7 @@ class GroupedLinear(TransformerEngineBaseModule):
),
"GroupedLinear doesn't support Userbuffer overlap."
),
"GroupedLinear doesn't support Userbuffer overlap."
self
.
get_rng_state_tracker
=
get_rng_state_tracker
self
.
get_rng_state_tracker
=
get_rng_state_tracker
self
.
rng_tracker_name
=
rng_tracker_name
self
.
rng_tracker_name
=
rng_tracker_name
self
.
fine_grained_activation_offloading
=
fine_grained_activation_offloading
self
.
wgrad_store
=
WeightGradStore
(
delay_wgrad_compute
)
self
.
wgrad_store
=
WeightGradStore
(
delay_wgrad_compute
)
...
@@ -836,6 +841,7 @@ class GroupedLinear(TransformerEngineBaseModule):
...
@@ -836,6 +841,7 @@ class GroupedLinear(TransformerEngineBaseModule):
self
,
self
,
skip_fp8_weight_update
,
skip_fp8_weight_update
,
self
.
save_original_input
,
self
.
save_original_input
,
self
.
fine_grained_activation_offloading
,
*
weight_tensors
,
*
weight_tensors
,
*
bias_tensors
,
*
bias_tensors
,
)
)
...
...
transformer_engine/pytorch/module/layernorm_linear.py
View file @
9711d439
...
@@ -130,6 +130,7 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -130,6 +130,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_bulk_wgrad
:
bool
,
ub_bulk_wgrad
:
bool
,
ub_bulk_dgrad
:
bool
,
ub_bulk_dgrad
:
bool
,
ub_name
:
str
,
ub_name
:
str
,
fine_grained_activation_offloading
:
bool
,
fsdp_group
:
Union
[
dist_group_type
,
None
],
fsdp_group
:
Union
[
dist_group_type
,
None
],
module
:
torch
.
nn
.
Module
,
module
:
torch
.
nn
.
Module
,
skip_fp8_weight_update
:
bool
,
skip_fp8_weight_update
:
bool
,
...
@@ -435,21 +436,25 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -435,21 +436,25 @@ class _LayerNormLinear(torch.autograd.Function):
)
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_scatter"
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_scatter"
)
offload_activation
=
False
# Do not offload weights and biases
if
get_activation_offloading
():
weight
.
offloading_activation
=
False
offload_activation
=
True
weightmat
.
offloading_activation
=
False
if
not
inputmat
.
is_contiguous
():
if
bias
is
not
None
:
inputmat
=
inputmat
.
contiguous
()
bias
.
offloading_activation
=
False
inputmat
.
offloading_activation
=
True
ln_weight
.
offloading_activation
=
False
ctx
.
fine_grained_activation_offloading
=
fine_grained_activation_offloading
ctx
.
offload_activation
=
offload_activation
if
offloa
d_activation
and
cpu_offloading
:
if
fine_graine
d_activation
_offloading
and
cpu_offloading
:
raise
ValueError
(
raise
ValueError
(
f
"Do not use
offloa
d_activation and cpu_offloading at the same time."
f
"Do not use
fine_graine
d_activation
_offloading
and cpu_offloading at the same time."
)
)
if
offload_activation
and
weight
.
requires_grad
and
fuse_wgrad_accumulation
:
if
(
fine_grained_activation_offloading
and
weight
.
requires_grad
and
fuse_wgrad_accumulation
):
if
hasattr
(
weight
,
"grad_added_to_main_grad"
):
if
hasattr
(
weight
,
"grad_added_to_main_grad"
):
ctx
.
has_grad_added_to_main_grad
=
True
ctx
.
has_grad_added_to_main_grad
=
True
ctx
.
grad_added_to_main_grad
=
weight
.
grad_added_to_main_grad
ctx
.
grad_added_to_main_grad
=
weight
.
grad_added_to_main_grad
...
@@ -594,10 +599,10 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -594,10 +599,10 @@ class _LayerNormLinear(torch.autograd.Function):
# For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# we need to connect them into one.
# we need to connect them into one.
if
ctx
.
cpu_offloading
or
ctx
.
offloa
d_activation
or
int
(
os
.
getenv
(
"NVTE_SWAP_OVERLAP_GRAD"
,
"0"
)):
if
ctx
.
cpu_offloading
or
ctx
.
fine_graine
d_activation
_offloading
or
int
(
os
.
getenv
(
"NVTE_SWAP_OVERLAP_GRAD"
,
"0"
)):
if
ctx
.
has_grad_added_to_main_grad
:
if
ctx
.
has_grad_added_to_main_grad
:
origin_weight
=
ctx
.
weight_object
origin_weight
=
ctx
.
weight_object
if
ctx
.
offloa
d_activation
:
if
ctx
.
fine_graine
d_activation
_offloading
:
origin_weight
.
grad_added_to_main_grad
=
ctx
.
grad_added_to_main_grad
origin_weight
.
grad_added_to_main_grad
=
ctx
.
grad_added_to_main_grad
if
ctx
.
requires_wgrad
and
ctx
.
fuse_wgrad_accumulation
:
if
ctx
.
requires_wgrad
and
ctx
.
fuse_wgrad_accumulation
:
origin_weight
.
main_grad
=
main_grad
origin_weight
.
main_grad
=
main_grad
...
@@ -1074,6 +1079,7 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -1074,6 +1079,7 @@ class _LayerNormLinear(torch.autograd.Function):
None
,
# ub_bulk_dgrad
None
,
# ub_bulk_dgrad
None
,
# ub_bulk_wgrad
None
,
# ub_bulk_wgrad
None
,
# ub_name
None
,
# ub_name
None
,
# fine_grained_activation_offloading
None
,
# fsdp_group
None
,
# fsdp_group
None
,
# debug
None
,
# debug
None
,
# module
None
,
# module
...
@@ -1209,6 +1215,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
...
@@ -1209,6 +1215,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
delay_wgrad_compute
:
bool
=
False
,
delay_wgrad_compute
:
bool
=
False
,
symmetric_ar_type
:
Optional
[
str
]
=
None
,
symmetric_ar_type
:
Optional
[
str
]
=
None
,
name
:
str
=
None
,
name
:
str
=
None
,
fine_grained_activation_offloading
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -1227,6 +1234,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
...
@@ -1227,6 +1234,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
)
)
self
.
zero_centered_gamma
=
zero_centered_gamma
self
.
zero_centered_gamma
=
zero_centered_gamma
self
.
symmetric_ar_type
=
symmetric_ar_type
self
.
symmetric_ar_type
=
symmetric_ar_type
self
.
fine_grained_activation_offloading
=
fine_grained_activation_offloading
self
.
wgrad_store
=
WeightGradStore
(
delay_wgrad_compute
,
ub_bulk_wgrad
)
self
.
wgrad_store
=
WeightGradStore
(
delay_wgrad_compute
,
ub_bulk_wgrad
)
self
.
name
=
name
self
.
name
=
name
...
@@ -1630,6 +1638,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
...
@@ -1630,6 +1638,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self
.
ub_bulk_wgrad
,
self
.
ub_bulk_wgrad
,
self
.
ub_bulk_dgrad
,
self
.
ub_bulk_dgrad
,
self
.
ub_name
,
self
.
ub_name
,
self
.
fine_grained_activation_offloading
,
self
.
fsdp_group
,
self
.
fsdp_group
,
self
,
self
,
skip_fp8_weight_update
,
skip_fp8_weight_update
,
...
...
transformer_engine/pytorch/module/linear.py
View file @
9711d439
...
@@ -111,6 +111,7 @@ class _Linear(torch.autograd.Function):
...
@@ -111,6 +111,7 @@ class _Linear(torch.autograd.Function):
ub_bulk_dgrad
:
bool
,
ub_bulk_dgrad
:
bool
,
ub_bulk_wgrad
:
bool
,
ub_bulk_wgrad
:
bool
,
ub_name
:
str
,
ub_name
:
str
,
fine_grained_activation_offloading
:
bool
,
fp8_output
:
bool
,
# pylint: disable=unused-argument
fp8_output
:
bool
,
# pylint: disable=unused-argument
fsdp_group
:
Union
[
dist_group_type
,
None
],
fsdp_group
:
Union
[
dist_group_type
,
None
],
module
:
torch
.
nn
.
Module
,
module
:
torch
.
nn
.
Module
,
...
@@ -404,25 +405,25 @@ class _Linear(torch.autograd.Function):
...
@@ -404,25 +405,25 @@ class _Linear(torch.autograd.Function):
)
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_scatter"
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_scatter"
)
offload_activation
=
False
ctx
.
fine_grained_activation_offloading
=
fine_grained_activation_offloading
if
get_activation_offloading
():
offload_activation
=
True
if
not
saved_inputmat
.
is_contiguous
():
saved_inputmat
=
saved_inputmat
.
contiguous
()
saved_inputmat
.
offloading_activation
=
True
ctx
.
offload_activation
=
offload_activation
if
offloa
d_activation
and
cpu_offloading
:
if
fine_graine
d_activation
_offloading
and
cpu_offloading
:
raise
ValueError
(
raise
ValueError
(
f
"Do not use
offloa
d_activation and cpu_offloading at the same time."
f
"Do not use
fine_graine
d_activation
_offloading
and cpu_offloading at the same time."
)
)
if
offload_activation
and
weight
.
requires_grad
and
fuse_wgrad_accumulation
:
if
(
ctx
.
has_grad_added_to_main_grad
=
hasattr
(
weight
,
"grad_added_to_main_grad"
)
fine_grained_activation_offloading
if
ctx
.
has_grad_added_to_main_grad
:
and
weight
.
requires_grad
and
fuse_wgrad_accumulation
):
if
hasattr
(
weight
,
"grad_added_to_main_grad"
):
ctx
.
has_grad_added_to_main_grad
=
True
ctx
.
grad_added_to_main_grad
=
weight
.
grad_added_to_main_grad
ctx
.
grad_added_to_main_grad
=
weight
.
grad_added_to_main_grad
ctx
.
weight_object
=
weight
weight
.
grad_added_to_main_grad
=
True
weight
.
grad_added_to_main_grad
=
True
ctx
.
weight_object
=
weight
else
:
ctx
.
has_grad_added_to_main_grad
=
False
if
cpu_offloading
or
int
(
os
.
getenv
(
"NVTE_SWAP_OVERLAP_GRAD"
,
"0"
)):
if
cpu_offloading
or
int
(
os
.
getenv
(
"NVTE_SWAP_OVERLAP_GRAD"
,
"0"
)):
ctx
.
has_grad_added_to_main_grad
=
hasattr
(
weight
,
"grad_added_to_main_grad"
)
ctx
.
has_grad_added_to_main_grad
=
hasattr
(
weight
,
"grad_added_to_main_grad"
)
...
@@ -435,6 +436,12 @@ class _Linear(torch.autograd.Function):
...
@@ -435,6 +436,12 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module
# weights if weights are externally touched outside this module
ctx
.
weight_object
=
weight
ctx
.
weight_object
=
weight
# Do not offload weights and biases
weight
.
offloading_activation
=
False
weightmat
.
offloading_activation
=
False
if
bias
is
not
None
:
bias
.
offloading_activation
=
False
# TODO(ksivamani): Check memory usage
# TODO(ksivamani): Check memory usage
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
saved_inputmat
,
saved_inputmat
,
...
@@ -522,10 +529,10 @@ class _Linear(torch.autograd.Function):
...
@@ -522,10 +529,10 @@ class _Linear(torch.autograd.Function):
else
None
else
None
)
)
if
ctx
.
cpu_offloading
or
ctx
.
offloa
d_activation
or
int
(
os
.
getenv
(
"NVTE_SWAP_OVERLAP_GRAD"
,
"0"
)):
if
ctx
.
cpu_offloading
or
ctx
.
fine_graine
d_activation
_offloading
or
int
(
os
.
getenv
(
"NVTE_SWAP_OVERLAP_GRAD"
,
"0"
)):
if
ctx
.
has_grad_added_to_main_grad
:
if
ctx
.
has_grad_added_to_main_grad
:
weight
=
ctx
.
weight_object
weight
=
ctx
.
weight_object
if
ctx
.
offloa
d_activation
:
if
ctx
.
fine_graine
d_activation
_offloading
:
weight
.
grad_added_to_main_grad
=
ctx
.
grad_added_to_main_grad
weight
.
grad_added_to_main_grad
=
ctx
.
grad_added_to_main_grad
if
ctx
.
requires_wgrad
and
ctx
.
fuse_wgrad_accumulation
:
if
ctx
.
requires_wgrad
and
ctx
.
fuse_wgrad_accumulation
:
weight
.
main_grad
=
main_grad
weight
.
main_grad
=
main_grad
...
@@ -1009,6 +1016,7 @@ class _Linear(torch.autograd.Function):
...
@@ -1009,6 +1016,7 @@ class _Linear(torch.autograd.Function):
None
,
# ub_bulk_dgrad
None
,
# ub_bulk_dgrad
None
,
# ub_bulk_wgrad
None
,
# ub_bulk_wgrad
None
,
# ub_name
None
,
# ub_name
None
,
# fine_grained_activation_offloading
None
,
# fp8_output
None
,
# fp8_output
None
,
# fsdp_group
None
,
# fsdp_group
None
,
# module
None
,
# module
...
@@ -1131,6 +1139,7 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1131,6 +1139,7 @@ class Linear(TransformerEngineBaseModule):
symmetric_ar_type
:
Optional
[
str
]
=
None
,
symmetric_ar_type
:
Optional
[
str
]
=
None
,
save_original_input
:
bool
=
False
,
save_original_input
:
bool
=
False
,
name
:
Optional
[
str
]
=
None
,
name
:
Optional
[
str
]
=
None
,
fine_grained_activation_offloading
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -1146,6 +1155,7 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1146,6 +1155,7 @@ class Linear(TransformerEngineBaseModule):
self
.
symmetric_ar_type
=
symmetric_ar_type
self
.
symmetric_ar_type
=
symmetric_ar_type
self
.
save_original_input
=
save_original_input
self
.
save_original_input
=
save_original_input
self
.
name
=
name
self
.
name
=
name
self
.
fine_grained_activation_offloading
=
fine_grained_activation_offloading
self
.
wgrad_store
=
WeightGradStore
(
delay_wgrad_compute
,
ub_bulk_wgrad
)
self
.
wgrad_store
=
WeightGradStore
(
delay_wgrad_compute
,
ub_bulk_wgrad
)
...
@@ -1493,6 +1503,7 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1493,6 +1503,7 @@ class Linear(TransformerEngineBaseModule):
self
.
ub_bulk_dgrad
,
self
.
ub_bulk_dgrad
,
self
.
ub_bulk_wgrad
,
self
.
ub_bulk_wgrad
,
self
.
ub_name
,
self
.
ub_name
,
self
.
fine_grained_activation_offloading
,
fp8_output
,
fp8_output
,
self
.
fsdp_group
,
self
.
fsdp_group
,
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment