OpenDAS / TransformerEngine / Commits

Commit 87e3e56e, authored Aug 27, 2025 by yuguo

    Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine

Parents: 2f11bd2e, 734bcedd
Changes: 217
Showing 17 changed files with 486 additions and 358 deletions (+486, -358)
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py          +30  -20
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py           +19  -23
transformer_engine/pytorch/ops/fuser.py                                      +131 -77
transformer_engine/pytorch/ops/linear.py                                     +69  -29
transformer_engine/pytorch/ops/op.py                                         +135 -143
transformer_engine/pytorch/ops/sequential.py                                 +19  -32
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py  +9   -3
transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py            +9   -3
transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py             +9   -3
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py                 +1   -1
transformer_engine/pytorch/tensor/float8_tensor.py                           +16  -5
transformer_engine/pytorch/tensor/mxfp8_tensor.py                            +3   -3
transformer_engine/pytorch/tensor/quantized_tensor.py                        +4   -0
transformer_engine/pytorch/transformer.py                                    +22  -10
transformer_engine/pytorch/triton/cross_entropy.py                           +7   -3
transformer_engine/pytorch/triton/permutation.py                             +1   -1
transformer_engine/pytorch/utils.py                                          +2   -2
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py

@@ -10,9 +10,9 @@ import warnings
 import torch

-from transformer_engine_torch import CommOverlapType
+from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_external_gemm
 from ...cpp_extensions import general_gemm
-from ...distributed import gather_along_first_dim, get_distributed_world_size
+from ...distributed import get_distributed_world_size
 from ...module.base import (
     fill_userbuffers_buffer_for_all_gather,
     get_ub,

@@ -48,14 +48,14 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Basic operations that comprise this fused operation
         op_idxs = {"linear": None, "bias": None, "reduce_scatter": None}
         ops = []
-        if reduce_scatter is not None:
-            op_idxs["reduce_scatter"] = len(ops)
-            ops.append(reduce_scatter)
+        op_idxs["linear"] = len(ops)
+        ops.append(linear)
         if bias is not None:
             op_idxs["bias"] = len(ops)
             ops.append(bias)
-        op_idxs["linear"] = len(ops)
-        ops.append(linear)
+        if reduce_scatter is not None:
+            op_idxs["reduce_scatter"] = len(ops)
+            ops.append(reduce_scatter)

         # Initialize base class
         super().__init__(ops)

@@ -398,26 +398,35 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Initialize grad output
         if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
-            # UB does not support overlapping grad output
+            # UB does not support pipelined overlapping grad output
             # all-gather with wgrad GEMM. Also, we can't
             # convert row-scaled MXFP8 to column-scaled, so we
             # can't reuse the grad output that was gathered
             # for the dgrad GEMM. We work around by explicitly
-            # overlapping the NCCL operation with the dgrad GEMM.
+            # overlapping the AG operation with the dgrad GEMM.
+            # Get the communication stream from the dgrad GEMM to use for the AG
+            dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream()
+            ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad")
             grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
-            # Get the communication stream from the dgrad GEMM and set it as the current torch stream
-            dgrad_comm_stream = ub_comm_dgrad.get_communication_stream()
-            with torch.cuda.stream(dgrad_comm_stream):
-                # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
-                # This ensures that we don't start until all communication for the dgrad GEMM is complete
-                dy, dy_work = gather_along_first_dim(
-                    dy_local,
-                    tensor_parallel_group,
-                    async_op=True,
-                    quantizer=grad_output_quantizer,
-                )
-            # Synchronize with the main stream
-            dy_work.wait()
+            # We use the send stream to copy into the userbuffers.
+            # This is the same stream that we will use to access the data in the AG,
+            # so we dont need to add any syncs yet.
+            with torch.cuda.stream(dgrad_send_stream):
+                dy, _ = fill_userbuffers_buffer_for_all_gather(
+                    ub_obj_overlap_wgrad,
+                    dy_local,
+                    grad_output_quantizer,
+                    tensor_parallel_group,
+                )
+            # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm
+            bulk_overlap_ag_with_external_gemm(
+                ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
+            )
         if tensor_parallel_mode == "column":
             dy = dy_local

@@ -495,7 +504,7 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Get basic operations
         idx = self._op_idxs["linear"]
         linear_op = self.basic_ops[idx]
-        linear_op_ctx = basic_op_ctxs[idx]
+        linear_op_ctx = basic_op_ctxs[-1]
         bias_op = None
         if self._op_idxs["bias"] is not None:
             idx = self._op_idxs["bias"]

@@ -556,6 +565,7 @@ class UserbuffersBackwardLinear(FusedOperation):
         grad_params[self._op_idxs["linear"]] = (grad_weight,)
         if bias_op is not None:
             grad_params[self._op_idxs["bias"]] = (grad_bias,)
+        grad_params.reverse()
         grad_extra_inputs = [() for _ in range(len(self.basic_ops))]
         return grad_input, grad_params, grad_extra_inputs
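The reworked MXFP8 path above hides the wgrad all-gather behind the dgrad GEMM by copying into the userbuffers on the dgrad send stream and then calling bulk_overlap_ag_with_external_gemm. A minimal, self-contained sketch of the underlying stream-overlap pattern (plain PyTorch, not this commit's userbuffers machinery; assumes a CUDA device is available):

import torch

def overlapped_copy_and_gemm(a, b, src, dst):
    """Run a GEMM on the default stream while a copy runs on a side stream."""
    side_stream = torch.cuda.Stream()
    # The side stream waits for work already queued on the current stream,
    # mirroring the "no extra syncs yet" idea in the diff above.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        dst.copy_(src, non_blocking=True)
    # GEMM proceeds concurrently on the default stream.
    out = a @ b
    # Consumers of `dst` on the default stream wait for the side stream.
    torch.cuda.current_stream().wait_stream(side_stream)
    return out, dst

if __name__ == "__main__" and torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    src = torch.randn(1 << 20, device="cuda")
    dst = torch.empty_like(src)
    overlapped_copy_and_gemm(a, b, src, dst)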
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py

@@ -182,7 +182,7 @@ class UserbuffersForwardLinear(FusedOperation):
             if weight_quantizer is None:
                 raise ValueError("Missing quantizer for weight tensor")
             if output_quantizer is not None:
-                raise ValueError("FP8 output is not supported")
+                raise ValueError("Quantized output is not supported")
         else:
             input_quantizer = None
             weight_quantizer = None

@@ -282,7 +282,7 @@ class UserbuffersForwardLinear(FusedOperation):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:

@@ -307,21 +307,17 @@ class UserbuffersForwardLinear(FusedOperation):
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

         # Quantization metadata
+        input_quantizer = linear_op.get_quantizer("forward", 0)
+        weight_quantizer = linear_op.get_quantizer("forward", 1)
+        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
+        grad_input_quantizer = prev_op_grad_output_quantizer
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        input_quantizer = None
-        weight_quantizer = None
-        grad_output_quantizer = None
-        grad_input_quantizer = None
         if with_quantized_compute:
             recipe = FP8GlobalStateManager.get_fp8_recipe()
             if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())):
                 raise RuntimeError(
                     f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})"
                 )
-            input_quantizer = linear_op.get_quantizer("forward", 0)
-            weight_quantizer = linear_op.get_quantizer("forward", 1)
-            grad_output_quantizer = linear_op.get_quantizer("backward", 0)
-            grad_input_quantizer = prev_op_grad_input_quantizer

         # Get autocast dtype if needed
         if torch.is_autocast_enabled():

@@ -356,6 +352,7 @@ class UserbuffersForwardLinear(FusedOperation):
         w = extra_outputs["weight"]

         # Save state for backward pass
+        if linear_op_ctx.requires_grad:
             linear_op_ctx.save_for_backward(x_local, w)
             linear_op_ctx.with_quantized_compute = with_quantized_compute
             linear_op_ctx.input_quantizer = input_quantizer

@@ -366,9 +363,8 @@ class UserbuffersForwardLinear(FusedOperation):
             linear_op_ctx.input_dims = input_.size()
             linear_op_ctx.input_requires_grad = input_requires_grad
             linear_op_ctx.weight_requires_grad = weight_requires_grad
-        if bias_op is not None:
-            bias_op_ctx.with_quantized_compute = with_quantized_compute
-            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
+        if bias_op is not None and bias_op_ctx.requires_grad:
+            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()

         return output, [() for _ in range(len(self.basic_ops))]
transformer_engine/pytorch/ops/fuser.py

@@ -5,22 +5,25 @@
 """Manager class for a pipeline of fusible operations."""

 from __future__ import annotations
-from collections.abc import Callable
+from collections.abc import Callable, Iterable
 from typing import Any, Optional
+import itertools

 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     FusibleOperation,
     OperationContext,
 )
 from transformer_engine.pytorch.ops.fused import (
-    fuse_backward_bias_activation,
+    fuse_backward_activation_bias,
     fuse_backward_linear_add,
+    fuse_backward_linear_scale,
     fuse_forward_linear_bias_activation,
     fuse_forward_linear_bias_add,
+    fuse_forward_linear_scale_add,
     fuse_userbuffers_backward_linear,
     fuse_userbuffers_forward_linear,
 )

@@ -68,8 +71,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         input_: torch.Tensor,
         fuser: OperationFuser,
         basic_op_kwargs: list[dict[str, Any]],
-        is_grad_enabled: bool,
-        *params_and_extra_inputs: torch.nn.Parameter,
+        *params_and_extra_inputs: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
         """Forward pass

@@ -83,8 +85,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             Container for the pipeline of operations to run
         basic_op_kwargs: list of dict
             Keyword arguments to BasicOperation
-        is_grad_enabled: bool
-            Should context be saved for backward
         *params_and_extra_inputs: torch.Tensor
             Other tensor inputs to include in autograd graph. Consists
             of parameter tensors, followed by extra operation inputs.

@@ -103,10 +103,10 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         # Mark input tensors as not deletable in backward
         for tensor in (input_,) + params_and_extra_inputs:
-            tensor.do_not_clear = True
+            tensor._do_not_clear = True

         # Unflatten list of parameters and extra tensor inputs
-        extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs:]
+        extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs:]
         basic_op_extra_inputs = []
         for op in fuser._basic_ops:
             xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)

@@ -114,44 +114,37 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         # Apply forward ops
         x = input_
-        requires_grad = is_grad_enabled and x.requires_grad
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         extra_outputs = [None] * fuser._num_basic_ops
         for op, basic_op_idxs in fuser._forward_ops:

-            # Check if backward op is required
-            if is_grad_enabled:
-                if not requires_grad:
-                    requires_grad = any(param.requires_grad for param in op.parameters())
-                if not requires_grad:
-                    requires_grad = any(
-                        any(x.requires_grad for x in xs) for xs in extra_inputs
-                    )
+            # Set if backward op is required
             for idx in basic_op_idxs:
-                basic_op_ctxs[idx].requires_grad = requires_grad
+                basic_op_ctxs[idx].requires_grad = idx >= fuser.first_op_requiring_backward

             # Forward op
             extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
             prev_op_idx = basic_op_idxs[0] - 1
             prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None
-            prev_op_grad_input_quantizer = None
-            if prev_op is not None and with_quantized_compute:
-                prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer()
+            prev_op_grad_output_quantizer = None
+            if prev_op is not None:
+                prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer()
             next_op_idx = basic_op_idxs[-1] + 1
             next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None
             next_op_input_quantizer = None
-            if next_op is not None and with_quantized_compute:
+            if next_op is not None:
                 next_op_input_quantizer = next_op.get_input_quantizer()
             x, fused_op_extra_outputs = op.fuser_forward(
                 [basic_op_ctxs[idx] for idx in basic_op_idxs],
                 x,
                 basic_op_extra_inputs=extra_inputs,
-                prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
+                prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
                 next_op_input_quantizer=next_op_input_quantizer,
                 basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
             )
             for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
                 for y in ys:
-                    y.requires_grad_(requires_grad)
+                    y.requires_grad_(idx >= fuser.first_op_requiring_backward)
                 extra_outputs[idx] = ys

@@ -168,7 +161,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
                 extra_outputs_flat.extend(ys)

         # Save context for backward pass
-        if is_grad_enabled:
+        if func_ctx is not None:
             # Flatten list of saved tensors
             to_save = []

@@ -181,24 +174,29 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
                 ctx._saved_tensors_range = (range_start, range_end)

             # Save tensors for backward
             if with_quantized_compute:
                 tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
                 func_ctx.save_for_backward(*tensors_to_save)
                 func_ctx.tensor_objects = tensor_objects
             else:
                 func_ctx.save_for_backward(*to_save)

+            # Whether to perform recipe update in backward pass
+            is_first_module = False
+            if fuser.first_op_requiring_backward < fuser._num_basic_ops:
+                is_first_module = FP8GlobalStateManager.is_first_fp8_module()
+
             # Other context
             func_ctx.backward_ops = fuser._backward_ops
             func_ctx.basic_ops = fuser._basic_ops
             func_ctx.basic_op_ctxs = basic_op_ctxs
-            func_ctx.basic_op_num_params = fuser._num_list_basic_op_params
-            func_ctx.num_extra_inputs = fuser._num_extra_inputs
+            func_ctx.basic_op_num_params = fuser._basic_op_num_params
+            func_ctx.num_extra_inputs = fuser.num_extra_inputs
             func_ctx.num_extra_outputs = len(extra_outputs_flat)
-            func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
             func_ctx.with_quantized_compute = with_quantized_compute
+            func_ctx.is_first_module = is_first_module

-        x.requires_grad_(requires_grad)
+        # Mark output tensors as not deletable in backward
+        for tensor in [x] + extra_outputs_flat:
+            tensor._do_not_clear = True
+        x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)
         if extra_outputs_flat:
             return x, *extra_outputs_flat

@@ -220,10 +218,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
         basic_op_ctxs = func_ctx.basic_op_ctxs

         # Restore saved tensors
-        if func_ctx.with_quantized_compute:
             saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
-        else:
-            saved_tensors = func_ctx.saved_tensors

         # Unflatten list of saved tensors
         for ctx in basic_op_ctxs:

@@ -304,7 +299,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             dx,  # input_
             None,  # fuser
             None,  # basic_op_kwargs
-            None,  # is_grad_enabled
             *grad_params_flat,
             *grad_extra_inputs_flat,
         )

@@ -317,19 +311,12 @@ class OperationFuser:
     ----------
     ops: list of FusibleOperation
         Pipeline of operations
-    fuse_ops: bool
-        Whether to attempt fusing operations
-    recipe: Recipe, optional
-        Quantization recipe to use when fusing and executing operations.
-        Note: certain fusions may depend on what kind of recipe is being used.

     """

     def __init__(
         self,
         ops: list[FusibleOperation],
-        fuse_ops: bool,
-        recipe: Optional[Recipe],
     ) -> None:

         # Get list of basic operations

@@ -343,25 +330,22 @@ class OperationFuser:
         self._basic_ops: list[BasicOperation] = basic_ops

         # Number of extra tensor inputs
-        self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops)
+        self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops)
+        self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs)

-        # Ops for forward and backward pass
+        # Ops for forward and backward pass, will be populated in fuse_ops
         self._forward_ops: list[tuple[FusibleOperation, list[int]]]
         self._backward_ops: list[tuple[FusibleOperation, list[int]]]
         self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
         self._backward_ops = list(reversed(self._forward_ops))

-        # Flag for checking if this is the first iteration
-        self._is_first_forward = True
-
-        # Fuse ops if needed
-        self.recipe = recipe
-        if fuse_ops:
-            self.fuse_ops()
+        # Cache and detect change of state relevant for fusing operations
+        self.recipe_type = None
+        self.first_op_requiring_backward = 0
+        self._last_amax_history_len = 0

         # Flatten list of parameters
-        self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()]
-        self._num_list_basic_op_params = [
-            sum(1 for _ in op.parameters()) for op in self._basic_ops
-        ]
+        self._basic_op_params = [list(op.parameters()) for op in self._basic_ops]
+        self._basic_op_num_params = list(map(len, self._basic_op_params))
+        self._flat_basic_op_params = sum(self._basic_op_params, [])

     @classmethod
     def _fuse_forward_ops(

@@ -373,6 +357,7 @@ class OperationFuser:
         ops = fuse_userbuffers_forward_linear(ops)
         ops = fuse_forward_linear_bias_add(ops)
         ops = fuse_forward_linear_bias_activation(ops)
+        ops = fuse_forward_linear_scale_add(ops)
         return ops

     @classmethod

@@ -384,13 +369,74 @@ class OperationFuser:
         """Attempt to fuse operations in backward pass"""
         ops = fuse_userbuffers_backward_linear(ops)
         ops = fuse_backward_linear_add(ops)
-        ops = fuse_backward_bias_activation(ops, recipe)
+        ops = fuse_backward_linear_scale(ops)
+        ops = fuse_backward_activation_bias(ops, recipe)
         return ops

-    def fuse_ops(self) -> None:
-        """Attempt to fuse operations"""
-        self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe)
-        self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe)
+    def maybe_fuse_ops(
+        self,
+        is_grad_enabled: bool,
+        recipe: Optional[Recipe],
+        input_: torch.Tensor,
+        extra_inputs: list[Iterable[torch.Tensor]],
+    ):
+        """Attempt to fuse operations if neccesary"""
+
+        # Determine which basic ops require backward
+        if not is_grad_enabled:
+            first_op_requiring_backward = self._num_basic_ops
+        elif input_.requires_grad:
+            first_op_requiring_backward = 0
+        else:
+            first_op_requiring_backward = self._num_basic_ops
+            for op_idx in range(self._num_basic_ops):
+                op_inputs = itertools.chain(self._basic_op_params[op_idx], extra_inputs[op_idx])
+                if any(tensor.requires_grad for tensor in op_inputs):
+                    first_op_requiring_backward = op_idx
+                    break
+
+        # Early exit if fusion parameters haven't changed
+        need_reset = False
+        recipe_type = type(recipe)
+        fusion_params = (recipe_type, first_op_requiring_backward)
+        if fusion_params != (self.recipe_type, self.first_op_requiring_backward):
+            # Recipe type or grad requirmenets have changed
+            need_reset = True
+        elif (
+            recipe is not None
+            and recipe.delayed()
+            and self._last_amax_history_len != recipe.amax_history_len
+        ):
+            # FP8 delayed scaling has changed amax history length
+            need_reset = True
+        if not need_reset:
+            return
+
+        # Reset recipe state
+        for op in self._basic_ops:
+            op.reset_recipe_state(recipe=recipe)
+
+        # Check if this is the first iteration
+        if self.recipe_type is None:
+            for op in self._basic_ops:
+                op.pre_first_fuser_forward()
+
+        # Prepare basic op lists for fusions
+        forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
+        backward_ops = list(reversed(forward_ops[first_op_requiring_backward:]))
+
+        # Fuse ops
+        self._forward_ops = self._fuse_forward_ops(forward_ops, recipe)
+        self._backward_ops = self._fuse_backward_ops(backward_ops, recipe)
+
+        # Save current fusion params
+        self.recipe_type, self.first_op_requiring_backward = fusion_params
+
+        # Save amax history length
+        if isinstance(recipe, DelayedScaling):
+            self._last_amax_history_len = recipe.amax_history_len
+        else:
+            self._last_amax_history_len = 0

     def __call__(
         self,

@@ -399,23 +445,32 @@ class OperationFuser:
         basic_op_kwargs: Optional[list[dict[str, Any]]] = None,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:

         # Verify extra input count
-        if len(extra_inputs) != self._num_extra_inputs:
+        if len(extra_inputs) != self.num_extra_inputs:
             raise ValueError(
-                f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}"
+                f"Expected {self.num_extra_inputs} extra inputs but got {len(extra_inputs)}"
             )

-        # Initialization before forward pass
-        if self._is_first_forward:
-            for op in self._basic_ops:
-                op.pre_first_forward(recipe=self.recipe)
-            self._is_first_forward = False
-
         # Canonicalize op kwargs
         if basic_op_kwargs is None:
             basic_op_kwargs = [{}] * self._num_basic_ops

-        # Fuser forward pass
+        # Unflatten list of extra tensor inputs
+        extra_inputs_copy = list(extra_inputs)
+        basic_op_extra_inputs = []
+        for op in self._basic_ops:
+            xs, extra_inputs_copy = _split_tuple(extra_inputs_copy, op.num_extra_inputs)
+            basic_op_extra_inputs.append(xs)
+
+        # Get environment state
+        recipe = None
+        if FP8GlobalStateManager.is_fp8_enabled():
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+        is_grad_enabled = torch.is_grad_enabled()
+
+        # Attempt to fuse operations if neccesary
+        self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)
+
+        # Fuser forward pass
         if is_grad_enabled:
             forward_func = _OperationFuserAutogradFunction.apply
             args = []

@@ -426,8 +481,7 @@ class OperationFuser:
             input,
             self,
             basic_op_kwargs,
-            is_grad_enabled,
-            *self._basic_op_params,
+            *self._flat_basic_op_params,
             *extra_inputs,
         )
         return forward_func(*args)
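The new maybe_fuse_ops re-runs operation fusion only when fusion-relevant state changes: the recipe type, the first op requiring backward, or the delayed-scaling amax history length. A stripped-down, stand-alone sketch of that caching pattern (generic names; fuse is any callable producing a plan, not the TE internals):

from typing import Any, Callable, Optional

class FusionCache:
    """Re-fuse only when fusion-relevant state changes (illustrative sketch)."""

    def __init__(self) -> None:
        self.recipe_type: Optional[type] = None
        self.first_op_requiring_backward: int = 0
        self._last_amax_history_len: int = 0
        self.fused_plan: Any = None

    def maybe_refuse(self, recipe: Any, first_op_requiring_backward: int, fuse: Callable[[], Any]) -> Any:
        fusion_params = (type(recipe), first_op_requiring_backward)
        need_reset = fusion_params != (self.recipe_type, self.first_op_requiring_backward)
        # Delayed-scaling-style recipes also invalidate the plan when the history length changes.
        amax_len = getattr(recipe, "amax_history_len", 0)
        if not need_reset and amax_len != self._last_amax_history_len:
            need_reset = True
        if need_reset:
            self.fused_plan = fuse()
            self.recipe_type, self.first_op_requiring_backward = fusion_params
            self._last_amax_history_len = amax_len
        return self.fused_plan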
transformer_engine/pytorch/ops/linear.py

@@ -6,7 +6,7 @@
 from __future__ import annotations
 from collections.abc import Callable
-from typing import Optional
+from typing import Any, Optional

 import torch

@@ -91,6 +91,8 @@ class Linear(FusedOperation):
         # Construct basic ops
         ops = []
+        linear_idx = None
+        bias_idx = None
         linear_kwargs = {
             "in_features": in_features,
             "out_features": out_features,

@@ -111,14 +113,16 @@ class Linear(FusedOperation):
         }
         if tensor_parallel_mode == "row":
             # Row TP: GEMM + bias + reduction
+            linear_idx = len(ops)
             linear_kwargs["in_features"] = local_in_features
             linear_kwargs["out_features"] = local_out_features
             linear_kwargs["tensor_parallel_mode"] = None
             linear_kwargs["tensor_parallel_group"] = None
             linear_kwargs["sequence_parallel"] = False
-            bias_kwargs["size"] *= tensor_parallel_size
             ops.append(BasicLinear(**linear_kwargs))
             if bias:
+                bias_idx = len(ops)
+                bias_kwargs["size"] *= tensor_parallel_size
                 ops.append(Bias(**bias_kwargs))
             if sequence_parallel:
                 ops.append(ReduceScatter(tensor_parallel_group))

@@ -126,45 +130,81 @@ class Linear(FusedOperation):
                 ops.append(AllReduce(tensor_parallel_group))
         else:
             # Column TP or no TP: (gather + GEMM) + bias
+            linear_idx = len(ops)
             ops.append(BasicLinear(**linear_kwargs))
             if bias:
+                bias_idx = len(ops)
                 ops.append(Bias(**bias_kwargs))

         # Initialize base class
         super().__init__(ops)
-        self._has_bias: bool = bias
-
-    @property
-    def weight(self) -> torch.nn.Parameter:
-        """Weight tensor
-
-        Parameter is owned by `BasicLinear` operation.
-
-        """
-        return self.basic_ops[0].weight
-
-    @weight.setter
-    def weight(self, value: Optional[torch.nn.Parameter]) -> None:
-        self.basic_ops[0].weight = value
-
-    @property
-    def bias(self) -> Optional[torch.nn.Parameter]:
-        """Bias tensor
-
-        Parameter is owned by `Bias` operation.
-
-        """
-        if self._has_bias:
-            return self.basic_ops[1].bias
-        return None
-
-    @bias.setter
-    def bias(self, value: Optional[torch.nn.Parameter]) -> None:
-        if self._has_bias:
-            self.basic_ops[1].bias = value
-        elif value is not None:
-            raise ValueError(
-                "Attempted to set bias parameter in Linear operation "
-                "that does not have bias enabled"
-            )
+
+        # Register parameters
+        self._linear_idx: Optional[int] = linear_idx
+        self._bias_idx: Optional[int] = bias_idx
+        self.register_parameter("weight", self.basic_ops[self._linear_idx].weight)
+        bias = None
+        if self._bias_idx is not None:
+            bias = self.basic_ops[self._bias_idx].bias
+        self.register_parameter("bias", bias)
+
+    def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None:
+        """Add a parameter to the module
+
+        Also updates the basic operation that owns the parameter.
+
+        """
+        if name == "bias" and self._bias_idx is None and param is not None:
+            raise ValueError(
+                "Attempted to set bias parameter in Linear operation "
+                "that does not have bias enabled"
+            )
+        super().register_parameter(name, param)
+        if name == "weight":
+            self.basic_ops[self._linear_idx].weight = param
+        elif name == "bias" and self._bias_idx is not None:
+            self.basic_ops[self._bias_idx].bias = param
+
+    def state_dict(self, *, prefix: str = "", **kwargs) -> dict[str, Any]:
+        """Save state"""
+        state_dict = super().state_dict(prefix=prefix, **kwargs)
+
+        # Remove basic op params from state dict
+        # Note: Logically, basic ops own params and fused ops are
+        # considered as stateless. However, we register weight and
+        # bias params in the linear op for convenience. We remove
+        # these redudant params from the checkpoint for backward
+        # compatibility.
+        if f"{prefix}weight" in state_dict:
+            del state_dict[f"{prefix}weight"]
+        if f"{prefix}bias" in state_dict:
+            del state_dict[f"{prefix}bias"]
+
+        return state_dict
+
+    def _load_from_state_dict(
+        self,
+        state_dict: dict[str, Any],
+        prefix: str,
+        *args,
+        **kwargs,
+    ) -> None:
+
+        # Add basic op params to state dict
+        # Note: Logically, basic ops own params and fused ops are
+        # considered as stateless. However, we register weight and
+        # bias params in the linear op for convenience. We remove
+        # these redudant params from the checkpoint for backward
+        # compatibility.
+        if f"{prefix}weight" not in state_dict:
+            state_dict[f"{prefix}weight"] = state_dict[f"{prefix}basic_ops.{self._linear_idx}.weight"]
+        if f"{prefix}bias" not in state_dict:
+            if self._bias_idx is None:
+                state_dict[f"{prefix}bias"] = None
+            else:
+                state_dict[f"{prefix}bias"] = state_dict[f"{prefix}basic_ops.{self._bias_idx}.bias"]

+        # Load state dict
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
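The new Linear exposes weight and bias at the fused-op level while the basic ops remain the owners, then strips the aliased keys from checkpoints. A small stand-alone sketch of that pattern with plain torch.nn modules (Inner/Outer are hypothetical names, not this commit's classes):

import torch

class Inner(torch.nn.Module):
    def __init__(self, n: int) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(n, n))

class Outer(torch.nn.Module):
    """Expose a child-owned parameter at the top level without duplicating it in checkpoints."""

    def __init__(self, n: int) -> None:
        super().__init__()
        self.inner = Inner(n)
        # Registered alias: `outer.weight` and `outer.inner.weight` are the same tensor.
        self.register_parameter("weight", self.inner.weight)

    def state_dict(self, *, prefix: str = "", **kwargs):
        state = super().state_dict(prefix=prefix, **kwargs)
        # Drop the aliased key so the checkpoint only stores the owning module's copy.
        state.pop(f"{prefix}weight", None)
        return state

m = Outer(4)
assert "weight" not in m.state_dict() and "inner.weight" in m.state_dict()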
transformer_engine/pytorch/ops/op.py

@@ -15,9 +15,6 @@ import torch
 from transformer_engine.common.recipe import Recipe
 from ..fp8 import (
-    MXFP8BlockScalingRecipeState,
-    DelayedScalingRecipeState,
-    Float8BlockScalingRecipeState,
     FP8GlobalStateManager,
     RecipeState,
     fp8_autocast,

@@ -65,18 +62,14 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
     def is_fused_op(self) -> bool:
         """Whether this op is the fusion of one or more basic ops"""

-    def pre_first_forward(
-        self,
-        *,
-        recipe: Optional[Recipe],
-    ) -> None:
-        """Preprocessing before forward pass"""
+    def pre_first_fuser_forward(self) -> None:
+        """Preprocessing before first fuser forward pass"""

     def get_input_quantizer(self) -> Optional[Quantizer]:
         """Get builder class for quantized input tensor"""

-    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
-        """Get builder class for quantized input's grad tensor"""
+    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
+        """Get builder class for quantized output's grad tensor"""

     def fuser_forward(
         self,

@@ -84,7 +77,7 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:

@@ -104,8 +97,8 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
             Input tensor
         basic_op_extra_inputs: list of torch.Tensor
             Extra tensor inputs to basic operations
-        prev_op_grad_input_quantizer: Quantizer, optional
-            The grad_input_quantizer of the preceeding operation
+        prev_op_grad_output_quantizer: Quantizer, optional
+            The grad_output_quantizer of the preceeding operation
         next_op_input_quantizer: Quantizer, optional
             The input_quantizer of the following operation
         basic_op_kwargs: list of dict

@@ -186,8 +179,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         super().__init__()

         # Objects for quantization
-        self._quantizers: Optional[dict[str, list[Quantizer]]] = None
         self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
+        self._quantizers: Optional[dict[str, list[Quantizer]]] = None
+        with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
+        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
+        self.reset_recipe_state(recipe=recipe)

     @property
     def is_fused_op(self) -> bool:

@@ -214,19 +210,47 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             return self.get_quantizer("forward", 0)
         return None

-    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
+    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
         if self.num_quantizers("backward") > 0:
             return self.get_quantizer("backward", 0)
         return None

-    def _reset_quantization_recipe_state(
+    def reset_recipe_state(
         self,
         *,
-        recipe: Recipe,
+        recipe: Optional[Recipe],
     ) -> None:
         """Construct state for quantization recipe"""

-        # Quantization recipe state for forward and backward pass
+        # Clear quantization state if necessary
+        if recipe is None:
+            self._fp8_metas = None
+            self._quantizers = None
+            return
+
+        # Communication group for FP8 amax reductions
+        fp8_group = FP8GlobalStateManager.get_fp8_group()
+
+        # Skip resetting recipe type if it did not actually change.
+        # This could happen for example if calling BasicOperation.forward directly, as in that
+        # case, the OperationFuser is not persistent, or when loading from a checkpoint
+        need_to_reset_recipe_state = False
+        if self._fp8_metas is None or self._quantizers is None:
+            need_to_reset_recipe_state = True
+        else:
+            for mode in ("forward", "backward"):
+                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
+                    forward=(mode == "forward"),
+                )
+                if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
+                    continue
+                recipe_state = self._fp8_metas[mode][fp8_meta_key]
+                if not isinstance(recipe, type(recipe_state.recipe)):
+                    need_to_reset_recipe_state = True
+                    break
+
+        if need_to_reset_recipe_state:
+            # Construct quantization recipe states
             self._fp8_metas = {"forward": None, "backward": None}
             self._quantizers = {"forward": [], "backward": []}
             for mode in ("forward", "backward"):

@@ -251,83 +275,76 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
                 self._fp8_metas[mode] = {
                     fp8_meta_key: recipe_state,
                     "recipe": recipe,
-                    "fp8_group": FP8GlobalStateManager.get_fp8_group(),
+                    "fp8_group": fp8_group,
                 }

                 # Construct builder class for quantized tensors
                 self._quantizers[mode] = recipe_state.make_quantizers()

-    def _update_quantization_recipe_state(
-        self,
-        *,
-        recipe: Recipe,
-    ) -> None:
-        """Make sure quantizer state matches quantization recipe"""
-
-        # Reset quantization state if needed
-        if self._fp8_metas is None or self._quantizers is None:
-            self._reset_quantization_recipe_state(recipe=recipe)
-            return
-        for mode in ("forward", "backward"):
-            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
-                forward=(mode == "forward"),
-            )
-            if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]:
-                continue
-            recipe_state = self._fp8_metas[mode][fp8_meta_key]
-            need_to_reset_recipe_state = (
-                (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState))
-                or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
-                or (
-                    recipe.float8_block_scaling()
-                    and not isinstance(recipe_state, Float8BlockScalingRecipeState)
-                )
-            )
-            if need_to_reset_recipe_state:
-                self._reset_quantization_recipe_state(recipe=recipe)
-                return
-
-        # Quantization recipe state for forward and backward pass
-        for mode in ("forward", "backward"):
-            num_quantizers = self.num_quantizers(mode)
-            if num_quantizers == 0:
-                continue
-
-            # Update FP8 metadata
-            fp8_meta = self._fp8_metas[mode]
-            fp8_meta["recipe"] = recipe
-            fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
-
-            # Get recipe state
-            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
-                forward=(mode == "forward"),
-            )
-            recipe_state = fp8_meta[fp8_meta_key]
-
-            # Reallocate amax history if needed
-            if not recipe.delayed():
-                continue
-            current_length = recipe_state.amax_history.size(0)
-            target_length = recipe.amax_history_len
-            if current_length != target_length:
-                with torch.no_grad():
-                    if target_length < current_length:
-                        recipe_state.amax_history = recipe_state.amax_history[:target_length].clone()
-                    else:
-                        recipe_state.amax_history = torch.nn.functional.pad(
-                            recipe_state.amax_history,
-                            pad=(0, 0, 0, target_length - current_length),
-                        )
-
-            # Update quantizers with new amax pointers
-            self._quantizers[mode] = recipe_state.make_quantizers()
-
-            # Update the global buffers with new amax pointers
-            if FP8GlobalStateManager.get_buffer_info() in self._fp8_metas[mode]:
-                pos, buffer_key = self._fp8_metas[mode][FP8GlobalStateManager.get_buffer_info()]
-                if buffer_key in FP8GlobalStateManager.global_amax_buffer:
-                    assert (
-                        buffer_key in FP8GlobalStateManager.global_amax_history_buffer
-                    ), "TE internal error during amax history change."
-                    FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
-                        recipe_state.amax_history[0]
-                    )
-                    FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
-                        recipe_state.amax_history
-                    )
+        else:
+            # Update quantization recipe states
+            for mode in ("forward", "backward"):
+                if self._fp8_metas[mode] is None:
+                    continue
+                self._fp8_metas[mode]["recipe"] = recipe
+                self._fp8_metas[mode]["fp8_group"] = fp8_group
+
+                # Update amax history for FP8 delayed scaling
+                if recipe.delayed():
+                    fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
+                        forward=(mode == "forward"),
+                    )
+                    recipe_state = self._fp8_metas[mode][fp8_meta_key]
+
+                    # Reallocate amax history if needed
+                    current_length = recipe_state.amax_history.size(0)
+                    target_length = recipe.amax_history_len
+                    if target_length < current_length:
+                        with torch.no_grad():
+                            recipe_state.amax_history = recipe_state.amax_history[
+                                :target_length
+                            ].clone()
+                    elif target_length > current_length:
+                        with torch.no_grad():
+                            recipe_state.amax_history = torch.nn.functional.pad(
+                                recipe_state.amax_history,
+                                pad=(0, 0, 0, target_length - current_length),
+                            )
+
+                    # Update quantizers with new amax pointers
+                    self._quantizers[mode] = recipe_state.make_quantizers()
+
+                    # Update the global buffers with new amax pointers
+                    if FP8GlobalStateManager.get_buffer_info() in self._fp8_metas[mode]:
+                        pos, buffer_key = self._fp8_metas[mode][
+                            FP8GlobalStateManager.get_buffer_info()
+                        ]
+                        if buffer_key in FP8GlobalStateManager.global_amax_buffer:
+                            assert (
+                                buffer_key in FP8GlobalStateManager.global_amax_history_buffer
+                            ), "TE internal error during amax history change."
+                            FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
+                                recipe_state.amax_history[0]
+                            )
+                            FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
+                                recipe_state.amax_history
+                            )
+
+        # Add meta tensors to global buffer to participate in reduction
+        for mode in ("forward", "backward"):
+            if (
+                FP8GlobalStateManager.is_fp8_enabled()
+                and self.num_quantizers(mode)
+                and not FP8GlobalStateManager.fp8_graph_capturing()
+            ):
+                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
+                    self._fp8_metas[mode],
+                )

     def get_quantizer(
         self,
         mode: str,
         index: int,
-    ) -> Quantizer:
+    ) -> Optional[Quantizer]:
         """Get builder class for quantized tensor

         Parameters

@@ -337,7 +354,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         """
         if self._quantizers is None:
-            self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
+            return None
         return self._quantizers[mode][index]

     @torch.no_grad()

@@ -388,33 +405,13 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
             self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)

-    def pre_first_forward(
-        self,
-        *,
-        recipe: Optional[Recipe],
-    ) -> None:
-        """Preprocessing before forward pass"""
-
-        # Initialize FP8 metadata if needed
-        if recipe is not None:
-            self._update_quantization_recipe_state(recipe=recipe)
-            if not FP8GlobalStateManager.fp8_graph_capturing():
-                if self.num_quantizers("forward"):
-                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
-                        self._fp8_metas["forward"],
-                    )
-                if self.num_quantizers("backward"):
-                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
-                        self._fp8_metas["backward"],
-                    )
-
     @abc.abstractmethod
     def op_forward(
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
         *,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         **kwargs: Any,
     ) -> torch.Tensor:

@@ -426,8 +423,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             Context to coordinate between forward and backward passes
         input_: torch.Tensor
             Input tensor
-        prev_op_grad_input_quantizer: Quantizer, optional
-            The grad_input_quantizer of the preceeding operation
+        prev_op_grad_output_quantizer: Quantizer, optional
+            The grad_output_quantizer of the preceeding operation
         next_op_input_quantizer: Quantizer, optional
             The input_quantizer of the following operation

@@ -468,7 +465,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, list[tuple[()]]]:

@@ -482,7 +479,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         output = self.op_forward(
             basic_op_ctxs[0],
             input_,
-            prev_op_grad_input_quantizer=prev_op_grad_input_quantizer,
+            prev_op_grad_output_quantizer=prev_op_grad_output_quantizer,
             next_op_input_quantizer=next_op_input_quantizer,
             **basic_op_kwargs[0],
         )

@@ -518,9 +515,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         """Apply operation"""
         from .fuser import OperationFuser

-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
-        return OperationFuser([self], fuse_ops=False, recipe=recipe)(
+        return OperationFuser([self])(
             input,
             *extra_inputs,
             basic_op_kwargs=[kwargs],

@@ -630,7 +625,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         # Get op's quantizer state, initializing if needed
         if self._fp8_metas is None or self._fp8_metas[mode] is None:
             with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
-                self._reset_quantization_recipe_state(recipe=state[mode]["recipe"])
+                self.reset_recipe_state(recipe=state[mode]["recipe"])
         fp8_meta = self._fp8_metas[mode]

         # Load extra items

@@ -708,13 +703,12 @@ class FusedOperation(FusibleOperation):
     def get_input_quantizer(self) -> Optional[Quantizer]:
         return self.basic_ops[0].get_input_quantizer()

-    def get_grad_input_quantizer(self) -> Optional[Quantizer]:
-        return self.basic_ops[-1].get_grad_input_quantizer()
+    def get_grad_output_quantizer(self) -> Optional[Quantizer]:
+        return self.basic_ops[-1].get_grad_output_quantizer()

-    def pre_first_forward(self, *args, **kwargs) -> None:
-        """Preprocessing before forward pass"""
+    def pre_first_fuser_forward(self) -> None:
         for op in self.basic_ops:
-            op.pre_first_forward(*args, **kwargs)
+            op.pre_first_fuser_forward()

     def forward(
         self,

@@ -727,9 +721,7 @@ class FusedOperation(FusibleOperation):
             basic_op_kwargs = [{} for _ in range(len(self.basic_ops))]

         from .fuser import OperationFuser

-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
-        return OperationFuser([self], fuse_ops=False, recipe=recipe)(
+        return OperationFuser([self])(
             input,
             *extra_inputs,
             basic_op_kwargs=basic_op_kwargs,
transformer_engine/pytorch/ops/sequential.py

@@ -10,7 +10,6 @@ from typing import Optional
 import torch

-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe
 from transformer_engine.pytorch.ops.op import FusibleOperation
 from transformer_engine.pytorch.ops.fuser import OperationFuser

@@ -147,7 +146,6 @@ class Sequential(torch.nn.Module):
     def _make_module_groups(
         cls,
         modules: Iterable[torch.nn.Module],
-        recipe: Optional[Recipe],
     ) -> list[OperationFuser | torch.nn.Module]:
         """Make list of modules, with fusible operations grouped together"""

@@ -162,24 +160,7 @@ class Sequential(torch.nn.Module):
                 groups.append(module)
         for idx, group in enumerate(groups):
             if isinstance(group, list):
-                groups[idx] = OperationFuser(group, fuse_ops=True, recipe=recipe)
-
-        # Check if operations expect extra input or output tensors
-        # Note: If any op has extra inputs or outputs, then the entire
-        # Sequential must be made up of TE ops.
-        if len(groups) > 1:
-            ops = []
-            for group in groups:
-                if isinstance(group, OperationFuser):
-                    ops.extend(group._basic_ops)
-            num_extra_inputs = sum(op.num_extra_inputs for op in ops)
-            num_extra_outputs = sum(op.num_extra_outputs for op in ops)
-            if num_extra_inputs > 0 or num_extra_outputs > 0:
-                raise RuntimeError(
-                    f"`Sequential` expects {num_extra_inputs} extra inputs "
-                    f"and {num_extra_outputs} extra outputs, "
-                    "but it contains non-fusible operations"
-                )
+                groups[idx] = OperationFuser(group)
         return groups

@@ -190,22 +171,28 @@ class Sequential(torch.nn.Module):
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
         """Forward pass"""

-        # Get current global state
-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
-        global_state = (with_quantized_compute, type(recipe))
-
-        # Reset module groups is global state changed
-        if self._last_global_state != global_state:
-            self._module_groups = None
-            self._last_global_state = global_state
-
         # Create module groups if needed
         if self._module_groups is None:
-            self._module_groups = self._make_module_groups(self._modules.values(), recipe)
+            self._module_groups = self._make_module_groups(self._modules.values())

         # Forward pass for each module group
         x = input
+        extra_outputs: list[torch.Tensor] = []
         for module_group in self._module_groups:
-            x = module_group(x, *extra_inputs)
+            if isinstance(module_group, OperationFuser):
+                xs, extra_inputs = (
+                    (x,) + extra_inputs[: module_group.num_extra_inputs],
+                    extra_inputs[module_group.num_extra_inputs :],
+                )
+                xs = module_group(*xs)
+                if isinstance(xs, tuple):
+                    x, ys = xs[0], xs[1:]
+                    extra_outputs.extend(ys)
+                else:
+                    x = xs
+            else:
+                x = module_group(x)
+        if extra_outputs:
+            return (x,) + tuple(extra_outputs)
         return x
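The new Sequential.forward threads a flat tuple of extra inputs through the module groups, handing each fuser group only the slice it declares and collecting any extra outputs, as shown above. A generic stand-alone sketch of that threading pattern (hypothetical helper name, not the TE class):

from typing import Any, Sequence

def run_groups(x: Any, groups: Sequence[Any], extra_inputs: tuple) -> tuple:
    """Thread a flat tuple of extra inputs through groups that each consume a known number."""
    extra_outputs: list = []
    for group in groups:
        n = getattr(group, "num_extra_inputs", 0)
        args, extra_inputs = (x,) + tuple(extra_inputs[:n]), tuple(extra_inputs[n:])
        out = group(*args) if n else group(x)
        if isinstance(out, tuple):
            x, ys = out[0], out[1:]
            extra_outputs.extend(ys)
        else:
            x = out
    return (x, *extra_outputs) if extra_outputs else (x,)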
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py

@@ -60,7 +60,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         instance = super().__new__(cls, *args, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
-        instance._quantizer = quantizer
+        instance._quantizer = quantizer.copy() if quantizer is not None else None
         instance._fp8_dtype = fp8_dtype
         instance._rowwise_scale_inv = rowwise_scale_inv
         instance._columnwise_scale_inv = columnwise_scale_inv

@@ -125,9 +125,15 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         self._columnwise_scale_inv = tensors[3]
         return tensors[4:]

-    def get_data_tensors(self):
+    def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
         """Get this Tensor's data."""
+        if rowwise_data and columnwise_data:
             return self._rowwise_data, self._columnwise_data
+        if rowwise_data:
+            return self._rowwise_data
+        if columnwise_data:
+            return self._columnwise_data
+        raise ValueError("No data to get, both rowwise_data and columnwise_data are False")

     def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:
         """Takes dequantized columnwise data and permutes to a rowwise shape"""
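The same get_data_tensors signature change is applied to the three internal tensor base classes below. A hedged usage sketch, assuming q is one of these quantized tensor instances:

# Assuming `q` is a Float8BlockwiseQTensorBase / Float8TensorBase / MXFP8TensorBase instance:
rowwise, columnwise = q.get_data_tensors()                 # both buffers (default)
rowwise_only = q.get_data_tensors(columnwise_data=False)   # just the rowwise buffer
columnwise_only = q.get_data_tensors(rowwise_data=False)   # just the columnwise buffer
# Passing rowwise_data=False, columnwise_data=False raises ValueError.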
transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py

@@ -86,7 +86,7 @@ class Float8TensorBase(QuantizedTensorBase):
         else:
             instance = super().__new__(cls, *args, **kwargs)
         instance._data = data
-        instance._quantizer = quantizer
+        instance._quantizer = quantizer.copy() if quantizer is not None else None
         instance._fp8_dtype = fp8_dtype
         instance._scale_inv = fp8_scale_inv
         instance._transpose = data_transpose

@@ -128,9 +128,15 @@ class Float8TensorBase(QuantizedTensorBase):
         self._scale_inv = tensors[2]
         return tensors[3:]

-    def get_data_tensors(self):
+    def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
         """Get this Tensor's data."""
+        if rowwise_data and columnwise_data:
             return self._data, self._transpose
+        if rowwise_data:
+            return self._data
+        if columnwise_data:
+            return self._transpose
+        raise ValueError("No data to get, both rowwise_data and columnwise_data are False")

     def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
         """Dequantize to a higher precision."""
transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py

@@ -83,7 +83,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
         instance = super().__new__(cls, *args, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
-        instance._quantizer = quantizer
+        instance._quantizer = quantizer.copy() if quantizer is not None else None
         instance._fp8_dtype = fp8_dtype
         instance._rowwise_scale_inv = rowwise_scale_inv
         instance._columnwise_scale_inv = columnwise_scale_inv

@@ -136,9 +136,15 @@ class MXFP8TensorBase(QuantizedTensorBase):
         self._columnwise_scale_inv = tensors[3]
         return tensors[4:]

-    def get_data_tensors(self):
+    def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
         """Get this Tensor's data."""
+        if rowwise_data and columnwise_data:
             return self._rowwise_data, self._columnwise_data
+        if rowwise_data:
+            return self._rowwise_data
+        if columnwise_data:
+            return self._columnwise_data
+        raise ValueError("No data to get, both rowwise_data and columnwise_data are False")

     def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
         """Dequantize to a higher precision."""
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

@@ -524,7 +524,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
         def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor):
             dst._rowwise_data = src._rowwise_data
             dst._columnwise_data = src._columnwise_data
-            dst._quantizer = src._quantizer
+            dst._quantizer = src._quantizer.copy()
             dst._fp8_dtype = src._fp8_dtype
             dst._rowwise_scale_inv = src._rowwise_scale_inv
             dst._columnwise_scale_inv = src._columnwise_scale_inv
transformer_engine/pytorch/tensor/float8_tensor.py

@@ -109,10 +109,9 @@ class Float8Quantizer(Quantizer):
         # Allocate FP8 data transpose if needed
         data_transpose = None
         if self.columnwise_usage:
-            inner_dim = data.size(-1)
+            transpose_shape = [data.size(-1)] + list(data.shape[:-1])
             data_transpose = torch.empty(
-                inner_dim,
-                data.numel() // inner_dim,
+                transpose_shape,
                 dtype=torch.uint8,
                 device=device,
             )

@@ -186,6 +185,12 @@ class Float8Quantizer(Quantizer):
     def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
         return DelayedScaling

+    def supports_only_rowwise_all_gather(self) -> bool:
+        """
+        Float8Quantizer supports only rowwise all-gather
+        """
+        return True
+

 class Float8CurrentScalingQuantizer(Quantizer):
     """Builder class for FP8 tensors with per-tensor current scaling

@@ -231,7 +236,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
         amax_epsilon: float = 0.0,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
-        self.scale = torch.ones(1, dtype=torch.float32, device=device)
+        self.scale = torch.empty(1, dtype=torch.float32, device=device)
         self.amax = torch.empty(1, dtype=torch.float32, device=device)
         self.dtype = tex.DType.kInt8 if int8_simulation_fp8_tensorwise else fp8_dtype
         self.with_amax_reduction = with_amax_reduction

@@ -363,6 +368,12 @@ class Float8CurrentScalingQuantizer(Quantizer):
     def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
         return Float8CurrentScaling

+    def supports_only_rowwise_all_gather(self) -> bool:
+        """
+        Float8CurrentScalingQuantizer supports only rowwise all-gather
+        """
+        return True
+

 class Float8Tensor(Float8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data

@@ -691,7 +702,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         # Float8Tensor attributes
         self._data = tensor._data
-        self._quantizer = tensor._quantizer
+        self._quantizer = tensor._quantizer.copy()
         self._fp8_dtype = tensor._fp8_dtype
         self._scale_inv = tensor._scale_inv
         self._transpose = tensor._transpose
transformer_engine/pytorch/tensor/mxfp8_tensor.py

@@ -100,7 +100,7 @@ class MXFP8Quantizer(Quantizer):
         # Allocate FP8 data
         data = torch.empty(shape, dtype=torch.uint8, device=device)
-        scale_inv = torch.zeros(
+        scale_inv = torch.empty(
             round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
             round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
             dtype=torch.uint8,

@@ -112,7 +112,7 @@ class MXFP8Quantizer(Quantizer):
         columnwise_scale_inv = None
         if self.columnwise_usage:
             columnwise_data = torch.empty_like(data)
-            columnwise_scale_inv = torch.zeros(
+            columnwise_scale_inv = torch.empty(
                 round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
                 round_up_to_nearest_multiple(shape[-1], 128),
                 dtype=torch.uint8,

@@ -433,7 +433,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
         super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor)
         self._rowwise_data = tensor._rowwise_data
         self._columnwise_data = tensor._columnwise_data
-        self._quantizer = tensor._quantizer
+        self._quantizer = tensor._quantizer.copy()
         self._fp8_dtype = tensor._fp8_dtype
         self._rowwise_scale_inv = tensor._rowwise_scale_inv
         self._columnwise_scale_inv = tensor._columnwise_scale_inv
transformer_engine/pytorch/tensor/quantized_tensor.py

@@ -260,6 +260,10 @@ class Quantizer(abc.ABC):
     def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
         """Returns recipe class that is compatible with this quantizer"""

+    def supports_only_rowwise_all_gather(self) -> bool:
+        """Returns True if the quantizer supports only rowwise all-gather"""
+        return False
+

 class _QuantizeFunc(torch.autograd.Function):
     """Cast to FP8 from other dtype"""
transformer_engine/pytorch/transformer.py

@@ -236,14 +236,23 @@ class TransformerLayer(torch.nn.Module):
                       parameter for query-key-value. This enables optimizations such as QKV
                       fusion without concatentations/splits and also enables the argument
                       `fuse_wgrad_accumulation`.
-    use_qk_norm: bool, default = 'False'
-                if set to `True`, L2 normalization is applied to query and key tensors
-                after RoPE (if applicable) but before attention computation.
-                This follows the Llama4 approach for QK normalization to improve
-                training stability and model performance.
+    qk_norm_type: Optional[str], default = None
+                type of normalization to apply to query and key tensors.
+                Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied.
+                When 'L2Normalization', L2 normalization is applied to query and key tensors.
+                When 'RMSNorm', RMS normalization is applied to query and key tensors.
+                When 'LayerNorm', layer normalization is applied to query and key tensors.
+                Normalization is applied after RoPE (if applicable) but before attention computation
+                when `qk_norm_before_rope` is False. This follows the e.g. Llama4 approach for
+                QK normalization to improve training stability and model performance.
     qk_norm_eps: float, default = 1e-6
-                epsilon value for L2 normalization of query and key tensors.
-                Only used when `use_qk_norm` is True.
+                epsilon value for normalization of query and key tensors.
+                Only used when `qk_norm_type` is not None.
     qk_norm_before_rope: bool, default = `False`
                 if set to `True`, query and key normalization is applied before rotary position
                 embedding. When `False` (default), normalization is applied after RoPE.
                 This parameter allows supporting different architectural variants that apply
                 QK normalization at different points.
     """

     def __init__(

@@ -293,8 +302,9 @@ class TransformerLayer(torch.nn.Module):
         device: Union[torch.device, str] = "cuda",
         attn_input_format: str = "sbhd",
         name: str = None,
-        use_qk_norm: bool = False,
+        qk_norm_type: Optional[str] = None,
         qk_norm_eps: float = 1e-6,
         qk_norm_before_rope: bool = False,
     ) -> None:
         super().__init__()

@@ -397,8 +407,9 @@ class TransformerLayer(torch.nn.Module):
             return_bias=not self.parallel_attention_mlp,
             normalization=normalization,
             device=device,
-            use_qk_norm=use_qk_norm,
+            qk_norm_type=qk_norm_type,
             qk_norm_eps=qk_norm_eps,
             qk_norm_before_rope=qk_norm_before_rope,
             name=name + ".self_attention" if name is not None else None,
         )

@@ -413,8 +424,9 @@ class TransformerLayer(torch.nn.Module):
             return_bias=True,
             normalization=normalization,
             device=device,
-            use_qk_norm=use_qk_norm,
+            qk_norm_type=qk_norm_type,
             qk_norm_eps=qk_norm_eps,
             qk_norm_before_rope=qk_norm_before_rope,
             name=name + ".inter_attention" if name is not None else None,
         )
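With use_qk_norm replaced by qk_norm_type, callers now pick the query/key normalization by name. A hedged construction sketch (argument values are illustrative; hidden_size, ffn_hidden_size and num_attention_heads are TransformerLayer's usual constructor arguments):

import transformer_engine.pytorch as te

# QK RMSNorm applied after RoPE (the default placement)
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    qk_norm_type="RMSNorm",
    qk_norm_eps=1e-6,
    qk_norm_before_rope=False,
)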
transformer_engine/pytorch/triton/cross_entropy.py

@@ -341,13 +341,17 @@ def cross_entropy_forward(
     return loss, _input


-def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
+def cross_entropy_backward(
+    _input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
+):
     """Backward implementation of cross entropy loss kernel"""

     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+    # Only check torch.equal when not in CUDA graph capturable mode
+    if not is_cg_capturable and torch.equal(
+        grad_output, torch.tensor(1.0, device=grad_output.device)
+    ):
         pass
     else:
         B, SQ, V = _input.shape
         n_rows = B * SQ
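The new is_cg_capturable flag skips the torch.equal fast path, which compares tensor values and therefore cannot run while a CUDA graph is being captured. A hedged call sketch (same arguments as in the diff above):

# Outside CUDA graph capture: keep the fast path that skips the multiply when grad_output == 1.0
cross_entropy_backward(_input, grad_output)

# During CUDA graph capture: avoid the data-dependent torch.equal check
cross_entropy_backward(_input, grad_output, is_cg_capturable=True)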
transformer_engine/pytorch/triton/permutation.py

@@ -359,7 +359,7 @@ def _permute_kernel(
         if prob == 0.0:
             # for routing_map padding
             # dst_row != -1 and prob == 0.0 means that this slot is padded
-            tl.store(output_ptr + output_off, 0, mask=mask)
+            tl.store(output_ptr + output_off, 0.0, mask=mask)
         else:
             tl.store(output_ptr + output_off, inp, mask=mask)
     else:
transformer_engine/pytorch/utils.py

@@ -45,10 +45,10 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     for t in tensors:
         if t is not None:
             # Workaround for double buffering in cpu offload
-            if hasattr(t, "do_not_clear"):
+            if hasattr(t, "_do_not_clear"):
                 continue
             if hasattr(t, "get_data_tensors"):
-                if any(hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()):
+                if any(hasattr(tensor, "_do_not_clear") for tensor in t.get_data_tensors()):
                     continue
             if hasattr(t, "clear"):
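clear_tensor_data now looks for the private _do_not_clear attribute, matching the renamed flag set in the fuser. A short sketch of opting a tensor out of clearing (attribute name taken from the diff; the clear_tensor_data call is illustrative):

import torch

t = torch.empty(1024)
t._do_not_clear = True   # mark as protected, e.g. for CPU-offload double buffering
# clear_tensor_data(t) would now skip this tensor instead of freeing its storage.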