OpenDAS / TransformerEngine

Commit 96a104d5, authored Jan 21, 2026 by wenjh

Merge develop_v2.10

Signed-off-by: wenjh <wenjh@sugon.com>

Parents: abec28e8, 0fce42f7
Showing 14 changed files, with 38 additions and 138 deletions (+38, -138). The changes fall into two groups: a HIP-platform include guard added near the top of each CUDA source, and removal of the fine_grained_activation_offloading path from the PyTorch modules.
tests/cpp/operator/test_cast_float8blockwise.cu          +3  -0
tests/cpp/operator/test_cast_mxfp8.cu                    +3  -0
tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu       +3  -0
tests/cpp/operator/test_dequantize_mxfp8.cu              +3  -1
tests/cpp/test_common.h                                  +3  -0
transformer_engine/common/multi_tensor/adam.cu           +3  -0
transformer_engine/common/multi_tensor/compute_scale.cu  +3  -0
transformer_engine/common/multi_tensor/scale.cu          +3  -0
transformer_engine/common/multi_tensor/sgd.cu            +3  -0
transformer_engine/pytorch/module/batched_linear.py      +2  -37
transformer_engine/pytorch/module/grouped_linear.py      +1  -33
transformer_engine/pytorch/module/layernorm_linear.py    +3  -9
transformer_engine/pytorch/module/linear.py              +5  -31
transformer_engine/pytorch/utils.py                      +0  -27
tests/cpp/operator/test_cast_float8blockwise.cu

@@ -4,6 +4,9 @@
  * See LICENSE for license information.
  ************************************************************************/
+#ifdef __HIP_PLATFORM_AMD__
+#include <hip/hip_runtime.h>
+#endif
 #include <cuda_bf16.h>
 #include <cuda_fp8.h>
 #include <cuda_runtime.h>
tests/cpp/operator/test_cast_mxfp8.cu

@@ -4,6 +4,9 @@
  * See LICENSE for license information.
  ************************************************************************/
+#ifdef __HIP_PLATFORM_AMD__
+#include <hip/hip_runtime.h>
+#endif
 #include <cuda_bf16.h>
 #include <cuda_fp8.h>
 #include <cuda_runtime.h>
tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu

@@ -4,6 +4,9 @@
  * See LICENSE for license information.
  ************************************************************************/
+#ifdef __HIP_PLATFORM_AMD__
+#include <hip/hip_runtime.h>
+#endif
 #include <cuda_bf16.h>
 #include <cuda_fp8.h>
 #include <cuda_runtime.h>
tests/cpp/operator/test_dequantize_mxfp8.cu

@@ -10,7 +10,9 @@
 #include <memory>
 #include <random>
 #include <limits>
+#ifdef __HIP_PLATFORM_AMD__
+#include <hip/hip_runtime.h>
+#endif
 #include <cuda_bf16.h>
 #include <cuda_fp8.h>
 #include <cuda_runtime.h>
tests/cpp/test_common.h

@@ -15,6 +15,9 @@
 #endif
 #define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
+#ifdef __HIP_PLATFORM_AMD__
+#include <hip/hip_runtime.h>
+#endif
 #include <cuda_runtime_api.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
transformer_engine/common/multi_tensor/adam.cu

@@ -5,6 +5,9 @@
  ************************************************************************/
 #include <assert.h>
+#ifdef __HIP_PLATFORM_AMD__
+#include <hip/hip_runtime.h>
+#endif
 #include <cuda_fp8.h>
 #include <transformer_engine/multi_tensor.h>
 #include <transformer_engine/transformer_engine.h>
transformer_engine/common/multi_tensor/compute_scale.cu

@@ -7,6 +7,9 @@
 #include <limits>
 // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
 #include <assert.h>
+#ifdef __HIP_PLATFORM_AMD__
+#include <hip/hip_runtime.h>
+#endif
 #include <cuda_fp8.h>
 #include <transformer_engine/multi_tensor.h>
 #include <transformer_engine/transformer_engine.h>
transformer_engine/common/multi_tensor/scale.cu

@@ -5,6 +5,9 @@
  ************************************************************************/
 #include <assert.h>
+#ifdef __HIP_PLATFORM_AMD__
+#include <hip/hip_runtime.h>
+#endif
 #include <cuda_fp8.h>
 // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
 #include <transformer_engine/multi_tensor.h>
transformer_engine/common/multi_tensor/sgd.cu

@@ -5,6 +5,9 @@
  ************************************************************************/
 #include <assert.h>
+#ifdef __HIP_PLATFORM_AMD__
+#include <hip/hip_runtime.h>
+#endif
 #include <cuda_fp8.h>
 #include <transformer_engine/multi_tensor.h>
 #include <transformer_engine/transformer_engine.h>
transformer_engine/pytorch/module/batched_linear.py

@@ -95,7 +95,6 @@ class _BatchLinear(torch.autograd.Function):
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         is_grad_enabled: bool,
-        fine_grained_activation_offloading,
         *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
     ) -> torch.Tensor:
         batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))

@@ -160,33 +159,6 @@ class _BatchLinear(torch.autograd.Function):
                 if t is not None:
                     t.activation_offloading = True
-            for i in range(num_gemms):
-                weights[i].offloading_activation = False
-                weights[i].main_grad.offloading_activation = False
-                if weights_fp8[i] is not None:
-                    weights_fp8[i].offloading_activation = False
-
-        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
-        if fine_grained_activation_offloading and cpu_offloading:
-            raise ValueError(
-                f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
-            )
-        if (
-            fine_grained_activation_offloading
-            and weights[0].requires_grad
-            and fuse_wgrad_accumulation
-        ):
-            grad_added_to_main_grad_list = []
-            for weight in weights:
-                if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
-                    grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
-                    weight.grad_added_to_main_grad = True
-                else:
-                    grad_added_to_main_grad_list.append(None)
-            ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
         ctx.save_for_backward(
             None,
             *saved_inputmats,

@@ -194,7 +166,7 @@ class _BatchLinear(torch.autograd.Function):
             *weights,
             *weights_fp8,
             *[
-                w.main_grad if (cpu_offloading or fine_grained_activation_offloading) and fuse_wgrad_accumulation else None
+                w.main_grad if cpu_offloading and fuse_wgrad_accumulation else None
                 for w in weights
             ],
         )

@@ -233,13 +205,11 @@ class _BatchLinear(torch.autograd.Function):
         weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms]
         weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
         main_grads = saved_tensors[4 * ctx.num_gemms :]
-        if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation:
+        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
                 w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                 w.main_grad = main_grads[i]
                 weights[i] = w
-                if ctx.fine_grained_activation_offloading and weights[i].requires_grad:
-                    weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
         global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
         grad_output = grad_output.contiguous()

@@ -371,7 +341,6 @@ class _BatchLinear(torch.autograd.Function):
             None,  # activation_dtype
             None,  # parallel_mode
             None,  # is_grad_enabled
-            None,  # fine_grained_activation_offloading
             *wgrad_list,
             *([None] * ctx.num_gemms),  # weights_fp8
             *grad_biases,

@@ -462,7 +431,6 @@ class BatchedLinear(TransformerEngineBaseModule):
         device: Union[torch.device, str] = "cuda",
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
-        fine_grained_activation_offloading: bool = False,
         ub_name: Optional[str] = None,
         delay_wgrad_compute: bool = False,
     ) -> None:

@@ -486,8 +454,6 @@ class BatchedLinear(TransformerEngineBaseModule):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
-        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute)
         global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT

@@ -665,7 +631,6 @@ class BatchedLinear(TransformerEngineBaseModule):
             self.activation_dtype,
             self.parallel_mode,
             torch.is_grad_enabled(),
-            self.fine_grained_activation_offloading,
             *weight_tensors,
             *weight_tensors_fp8,
             *bias_tensors,
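Aside: the removed bookkeeping above follows a save-and-restore pattern around grad_added_to_main_grad, the flag that Megatron-style fused wgrad accumulation uses to mark a gradient as already accumulated into main_grad. Forward stashed each weight's flag and forced it to True; backward put the saved values back. A minimal self-contained sketch of that pattern (illustrative helper names, not TE API):

def stash_grad_flags(weights):
    """Save each weight's grad_added_to_main_grad and force it True."""
    saved = []
    for w in weights:
        if w.requires_grad and hasattr(w, "grad_added_to_main_grad"):
            saved.append(w.grad_added_to_main_grad)
            w.grad_added_to_main_grad = True
        else:
            saved.append(None)
    return saved

def restore_grad_flags(weights, saved):
    """Restore the original flags once backward has used them."""
    for w, flag in zip(weights, saved):
        if flag is not None:
            w.grad_added_to_main_grad = flag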
transformer_engine/pytorch/module/grouped_linear.py

@@ -90,7 +90,6 @@ class _GroupedLinear(torch.autograd.Function):
             module,
             skip_fp8_weight_update,
             save_original_input,
-            fine_grained_activation_offloading,
         ) = non_tensor_args
         num_gemms = len(m_splits)

@@ -225,16 +224,6 @@ class _GroupedLinear(torch.autograd.Function):
         else:
             inputmats = [None] * num_gemms
-            for i in range(num_gemms):
-                weights[i].offloading_activation = False
-                weights_fp8[i].offloading_activation = False
-                biases[i].offloading_activation = False
-
-        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
-        if fine_grained_activation_offloading and cpu_offloading:
-            raise ValueError(
-                f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
-            )
         if cpu_offloading:
             ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")

@@ -247,21 +236,6 @@ class _GroupedLinear(torch.autograd.Function):
                 ctx.weight_objects = []
                 for weight in weights:
                     ctx.weight_objects.append(weight)
-            if (
-                fine_grained_activation_offloading
-                and weights[0].requires_grad
-                and fuse_wgrad_accumulation
-            ):
-                grad_added_to_main_grad_list = []
-                ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
-                for weight in weights:
-                    if ctx.grad_added_to_main_grad:
-                        grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
-                        weight.grad_added_to_main_grad = True
-                        ctx.weight_objects.append(weight)
-                    else:
-                        grad_added_to_main_grad_list.append(None)
-                ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
         tensors_to_save, tensor_objects = prepare_for_saving(
             *inputmats,

@@ -325,15 +299,12 @@ class _GroupedLinear(torch.autograd.Function):
         biases = saved_tensors[3 * N : 4 * N]
         main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
-        if ctx.cpu_offloading or ctx.fine_grained_activation_offloading:
+        if ctx.cpu_offloading:
             if ctx.grad_added_to_main_grad:
                 for i, weight in enumerate(ctx.weight_objects):
                     origin_weights[i] = ctx.weight_objects[i]
                     ctx.weight_objects[i] = None
-                    if ctx.fine_grained_activation_offloading:
-                        origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
             if ctx.fuse_wgrad_accumulation:
                 for i in range(N):
                     origin_weights[i].main_grad = main_grads[i]

@@ -614,7 +585,6 @@ class GroupedLinear(TransformerEngineBaseModule):
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
-        fine_grained_activation_offloading: bool = False,
         delay_wgrad_compute: bool = False,
         save_original_input: bool = False,
     ) -> None:

@@ -637,7 +607,6 @@ class GroupedLinear(TransformerEngineBaseModule):
         ), "GroupedLinear doesn't support Userbuffer overlap."
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
-        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute)

@@ -850,7 +819,6 @@ class GroupedLinear(TransformerEngineBaseModule):
             self,
             None,  # skip_fp8_weight_update
             self.save_original_input,
-            self.fine_grained_activation_offloading,
         )
         out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
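Both backward implementations above rely on the same convention: forward flattens the per-GEMM tensor lists into one sequence for saving, and backward slices them back out in multiples of the group count. A sketch of the idea, using the slot layout from the batched_linear.py hunks above (other slots omitted):

def split_saved(saved_tensors, num_gemms):
    """Recover per-role tensor groups from the flat saved list."""
    weights = saved_tensors[2 * num_gemms : 3 * num_gemms]
    weights_fp8 = saved_tensors[3 * num_gemms : 4 * num_gemms]
    main_grads = saved_tensors[4 * num_gemms :]
    return weights, weights_fp8, main_grads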
transformer_engine/pytorch/module/layernorm_linear.py

@@ -40,7 +40,6 @@ from ..utils import (
     requires_grad,
     needs_quantized_gemm,
     get_nvtx_range_context,
-    get_activation_offloading,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,

@@ -144,7 +143,6 @@ class _LayerNormLinear(torch.autograd.Function):
         ub_bulk_wgrad,
         ub_bulk_dgrad,
         ub_name,
-        fine_grained_activation_offloading,
         fsdp_group,
         module,
         skip_fp8_weight_update,

@@ -598,11 +596,10 @@ class _LayerNormLinear(torch.autograd.Function):
         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
         # we need to connect them into one.
-        if ctx.cpu_offloading or ctx.fine_grained_activation_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
+        if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
-            if ctx.has_grad_added_to_main_grad:
+            if ctx.grad_added_to_main_grad:
                 origin_weight = ctx.weight_object
-                if ctx.fine_grained_activation_offloading:
-                    origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
             if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                 origin_weight.main_grad = main_grad

@@ -1180,7 +1177,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
         delay_wgrad_compute: bool = False,
         symmetric_ar_type: Optional[str] = None,
         name: str = None,
-        fine_grained_activation_offloading: bool = False,
     ) -> None:
         super().__init__()

@@ -1199,7 +1195,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
         )
         self.zero_centered_gamma = zero_centered_gamma
         self.symmetric_ar_type = symmetric_ar_type
-        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
         self.name = name

@@ -1600,7 +1595,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.ub_bulk_wgrad,
             self.ub_bulk_dgrad,
             self.ub_name,
-            self.fine_grained_activation_offloading,
             self.fsdp_group,
             self,
             skip_fp8_weight_update,
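The comment kept in the backward hunk above ("we offloaded weight and weight.main_grad to different tensors, we need to connect them into one") names the re-wrapping step that the batched backward spells out: once the offloaded copies return, the raw weight tensor is wrapped back into a Parameter and its main_grad attribute is pointed at the restored buffer. A minimal sketch, assuming plain tensors:

import torch

def reconnect(weight_tensor, main_grad):
    # Re-wrap restored data as a Parameter and re-attach the
    # fused-accumulation buffer, as in the backward code above.
    w = torch.nn.Parameter(weight_tensor, weight_tensor.requires_grad)
    w.main_grad = main_grad
    return w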
transformer_engine/pytorch/module/linear.py

@@ -39,7 +39,6 @@ from ..utils import (
     nvtx_range_pop,
     nvtx_range_push,
     get_nvtx_range_context,
-    get_activation_offloading,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,

@@ -123,7 +122,6 @@ class _Linear(torch.autograd.Function):
         ub_bulk_dgrad,
         ub_bulk_wgrad,
         ub_name,
-        fine_grained_activation_offloading,
         fp8_output,  # pylint: disable=unused-variable
         fsdp_group,
         module,

@@ -420,30 +418,10 @@ class _Linear(torch.autograd.Function):
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
-        ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
-        if fine_grained_activation_offloading and cpu_offloading:
-            raise ValueError(
-                f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
-            )
-        if (
-            fine_grained_activation_offloading
-            and weight.requires_grad
-            and fuse_wgrad_accumulation
-        ):
-            if hasattr(weight, "grad_added_to_main_grad"):
-                ctx.has_grad_added_to_main_grad = True
-                ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
-                weight.grad_added_to_main_grad = True
-                ctx.weight_object = weight
-            else:
-                ctx.has_grad_added_to_main_grad = False
         if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
-            ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+            ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
-            if ctx.has_grad_added_to_main_grad:
+            if ctx.grad_added_to_main_grad:
                 # If you are passing torch.nn.Parameter through the Torch hooks, you will
                 # get back torch.Tensor. Torch rips off the Parameter wrapper.
                 # You need to preserve the weight object to have all the attributes user

@@ -540,11 +518,10 @@ class _Linear(torch.autograd.Function):
             else None
         )
-        if ctx.cpu_offloading or ctx.fine_grained_activation_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
+        if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
-            if ctx.has_grad_added_to_main_grad:
+            if ctx.grad_added_to_main_grad:
                 weight = ctx.weight_object
-                if ctx.fine_grained_activation_offloading:
-                    weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
                 if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                     weight.main_grad = main_grad

@@ -1124,7 +1101,6 @@ class Linear(TransformerEngineBaseModule):
         symmetric_ar_type: Optional[str] = None,
         save_original_input: bool = False,
         name: Optional[str] = None,
-        fine_grained_activation_offloading: bool = False,
     ) -> None:
         super().__init__()

@@ -1140,7 +1116,6 @@ class Linear(TransformerEngineBaseModule):
         self.symmetric_ar_type = symmetric_ar_type
         self.save_original_input = save_original_input
         self.name = name
-        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

@@ -1487,7 +1462,6 @@ class Linear(TransformerEngineBaseModule):
             self.ub_bulk_dgrad,
             self.ub_bulk_wgrad,
             self.ub_name,
-            self.fine_grained_activation_offloading,
             fp8_output,
             self.fsdp_group,
             self,
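Note the environment gate retained in both branches above: NVTE_SWAP_OVERLAP_GRAD is parsed with int(os.getenv(..., "0")), so the path is enabled by any non-zero integer string; a plain truthiness test on the string would treat "0" as enabled. An equivalent standalone check (hypothetical helper name):

import os

def swap_overlap_grad_enabled() -> bool:
    # "0" (the default) disables; any non-zero integer string enables.
    # Non-integer values raise ValueError, just like the inline check.
    return bool(int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")))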
transformer_engine/pytorch/utils.py

@@ -824,30 +824,3 @@ def make_weak_ref(x):
     if x is None:
         return None
     raise TypeError(f"Invalid type {type(x)} to make weak ref")
-
-
-ActivationOffloadEnabled = False
-
-
-def get_activation_offloading():
-    global ActivationOffloadEnabled
-    return ActivationOffloadEnabled
-
-
-def set_activation_offloading(activation_offloading):
-    global ActivationOffloadEnabled
-    ActivationOffloadEnabled = activation_offloading
-
-
-class ActivationOffloadContextManager:
-    """A reusable context manager for switch ActivationOffloadEnabled"""
-
-    def __init__(self, activation_offloading):
-        self.activation_offloading = activation_offloading
-
-    def __enter__(self):
-        self.origin_cpu_offloading = get_activation_offloading()
-        set_activation_offloading(self.activation_offloading)
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        set_activation_offloading(self.origin_cpu_offloading)
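The deleted helpers amount to a module-level on/off flag plus a context manager with save-and-restore semantics. The same pattern in compact, self-contained form (illustrative names, not the removed TE API):

from contextlib import contextmanager

_OFFLOAD_ENABLED = False  # stands in for ActivationOffloadEnabled

@contextmanager
def activation_offloading(enabled):
    # Save the current flag, set the new value, restore on exit,
    # mirroring __enter__/__exit__ of the removed class above.
    global _OFFLOAD_ENABLED
    previous = _OFFLOAD_ENABLED
    _OFFLOAD_ENABLED = enabled
    try:
        yield
    finally:
        _OFFLOAD_ENABLED = previous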