OpenDAS / TransformerEngine · Commits

Commit 87e3e56e, authored Aug 27, 2025 by yuguo
Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
Parents: 2f11bd2e, 734bcedd

Changes: 217 · Showing 20 changed files with 672 additions and 242 deletions (+672, −242)
transformer_engine/pytorch/ops/basic/all_reduce.py                        +1    −1
transformer_engine/pytorch/ops/basic/basic_linear.py                      +104  −93
transformer_engine/pytorch/ops/basic/bias.py                              +10   −27
transformer_engine/pytorch/ops/basic/constant_scale.py                    +40   −0
transformer_engine/pytorch/ops/basic/dropout.py                           +67   −0
transformer_engine/pytorch/ops/basic/identity.py                          +1    −1
transformer_engine/pytorch/ops/basic/l2normalization.py                   +10   −4
transformer_engine/pytorch/ops/basic/layer_norm.py                        +6    −16
transformer_engine/pytorch/ops/basic/make_extra_output.py                 +23   −9
transformer_engine/pytorch/ops/basic/quantize.py                          +3    −2
transformer_engine/pytorch/ops/basic/reduce_scatter.py                    +1    −1
transformer_engine/pytorch/ops/basic/reshape.py                           +3    −2
transformer_engine/pytorch/ops/basic/rmsnorm.py                           +6    −16
transformer_engine/pytorch/ops/fused/__init__.py                          +11   −3
transformer_engine/pytorch/ops/fused/backward_activation_bias.py          +8    −13
transformer_engine/pytorch/ops/fused/backward_linear_add.py               +5    −3
transformer_engine/pytorch/ops/fused/backward_linear_scale.py             +155  −0
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py    +18   −24
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py           +24   −27
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py          +176  −0
transformer_engine/pytorch/ops/basic/all_reduce.py

@@ -42,7 +42,7 @@ class AllReduce(BasicOperation):
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
...
transformer_engine/pytorch/ops/basic/basic_linear.py

@@ -22,9 +22,7 @@ from ...distributed import (
from ...fp8 import FP8GlobalStateManager, Recipe
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor
...
@@ -291,10 +289,19 @@ class BasicLinear(BasicOperation):
        # Quantize if needed
        if self._with_quantized_weight:
            quantizer = self.get_quantizer("forward", 1)
            if quantizer is None:
                raise RuntimeError(
                    "Tried to quantize weight with deferred initialization "
                    "due to meta device, but no quantizer was available. "
                    "This is most likely because the weight was initialized "
                    "within fp8_model_init, but the forward pass was not "
                    "performed within fp8_autocast."
                )
            quantizer.set_usage(
                rowwise=True,
                columnwise=torch.is_grad_enabled(),
            )
            quantizer.internal = False
            with torch.no_grad():
                weight = quantizer(weight)
...
@@ -303,72 +310,52 @@ class BasicLinear(BasicOperation):
            weight = torch.nn.Parameter(weight)
        self.weight = weight

    def pre_first_forward(
        self,
        *,
        recipe: Optional[Recipe],
    ) -> None:
        super().pre_first_forward(recipe=recipe)

        # Initialize weights if needed
        weight = self.weight
        if weight.device.type == "meta":
    def pre_first_fuser_forward(self) -> None:
        super().pre_first_fuser_forward()
        if self.weight.device.type == "meta":
            self.reset_parameters()
            weight = self.weight

        # Configure quantizers
        if recipe is not None:
            input_quantizer = self.get_quantizer("forward", 0)
            weight_quantizer = self.get_quantizer("forward", 1)
            grad_output_quantizer = self.get_quantizer("backward", 0)
    def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
        super().reset_recipe_state(recipe=recipe)

        # Specify required tensor formats
        # Input/grad output quantizers use internal tensors
        input_quantizer = self.get_quantizer("forward", 0)
        grad_output_quantizer = self.get_quantizer("backward", 0)
        if input_quantizer is not None:
            input_quantizer.internal = True
            weight_quantizer.internal = True
        if grad_output_quantizer is not None:
            grad_output_quantizer.internal = True

            # Recipe-specific configuration
            if recipe.float8_current_scaling():
                if any(
                    not isinstance(q, Float8CurrentScalingQuantizer)
                    for q in (input_quantizer, weight_quantizer, grad_output_quantizer)
                ):
                    raise RuntimeError(
                        "FP8 current-scaling recipe is enabled, "
                        f"but input quantizer is {input_quantizer.__class__.__name__}, "
                        f"weight quantizer is {weight_quantizer.__class__.__name__}, "
                        f"grad output quantizer is {grad_output_quantizer.__class__.__name__}"
                    )
                input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
                input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
                weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
                weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
                grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
                grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
                if self.sequence_parallel and self.tensor_parallel_mode == "column":
                    input_quantizer.with_amax_reduction = True
                    input_quantizer.amax_reduction_group = self.tensor_parallel_group
                if self.sequence_parallel and self.tensor_parallel_mode == "row":
                    grad_output_quantizer.with_amax_reduction = True
                    grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group

            # Make sure weight tensor has correct quantizer
            # Note: Quantizer might have changed if quantization
            # recipe changed
            if isinstance(
                weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
            ) and isinstance(weight, Float8TensorBase):
                weight._quantizer = weight_quantizer

        # Handle weight quantizer
        # Note: This function may be called in base class constructor,
        # before any basic linear attrs have been set.
        weight_quantizer = self.get_quantizer("forward", 1)
        if weight_quantizer is None:
            pass
        elif is_quantized_tensor(getattr(self, "weight", None)):
            # Make sure weight param has correct quantizer
            weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
            weight_quantizer.internal = False
            self.weight.update_quantizer(weight_quantizer.copy())
        else:
            # Use internal tensors if quantized weights will not be
            # exposed externally
            weight_quantizer.internal = (
                not FP8GlobalStateManager.with_fp8_parameters()
                and not getattr(self, "_with_quantized_weight", False)
            )

    @staticmethod
    def _functional_forward(
        input: torch.Tensor,  # pylint: disable=redefined-builtin
        weight: torch.Tensor,
        *,
        alpha: float = 1.0,
        bias: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,  # pylint: disable=unused-argument
        dtype: Optional[torch.dtype] = None,
        out: Optional[torch.Tensor] = None,
        beta: Optional[float] = None,
        accumulate_into_out: bool = False,
        tensor_parallel_mode: Optional[str] = None,
        tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
...
@@ -388,6 +375,8 @@ class BasicLinear(BasicOperation):
        Input tensor
    weight: torch.Tensor
        Weight tensor
    alpha: float, default = 1.0
        Scaling factor applied to the result of the GEMM
    bias: torch.Tensor, optional
        Bias tensor
    device: torch.device, default = default CUDA device
...
@@ -396,6 +385,8 @@ class BasicLinear(BasicOperation):
        Tensor datatype
    out: torch.Tensor, optional
        Output tensor
    beta: float, optional
        Scaling factor applied to original value of out when accumulating into it
    accumulate_into_out: bool, default = `False`
        Add result to output tensor instead of overwriting
    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
...
@@ -441,7 +432,7 @@ class BasicLinear(BasicOperation):
        if dtype is None:
            if out is not None and isinstance(out, torch.Tensor):
                dtype = out.dtype
            elif weight is not None and isinstance(out, torch.Tensor):
            elif weight is not None and isinstance(weight, torch.Tensor):
                dtype = weight.dtype
            else:
                raise ValueError(
...
@@ -516,18 +507,11 @@ class BasicLinear(BasicOperation):
                raise ValueError(
                    "Output tensor is quantized, but quantizer was not provided"
                )
        else:
            output_quantizer = None
        if isinstance(output_quantizer, MXFP8Quantizer):
            raise RuntimeError(
                "Attempting to generate MXFP8 output tensor, "
                "but GEMM with MXFP8 output is not supported"
            )
        if isinstance(output_quantizer, Float8BlockQuantizer):
            raise RuntimeError(
                "Attempting to generate Float8BlockQuantized output tensor, "
                "but GEMM with Float8BlockQuantized output is not supported"
            )
        if output_quantizer is not None:
            if not isinstance(output_quantizer, Float8Quantizer):
                raise RuntimeError(
                    "Attempting to generate quantized output tensor with unsupported quantizer"
                )
            output_quantizer.set_usage(rowwise=True, columnwise=False)

        # Check if accumulating into output tensor
...
@@ -552,6 +536,8 @@ class BasicLinear(BasicOperation):
            get_workspace(),
            out_dtype=dtype,
            quantization_params=output_quantizer,
            alpha=alpha,
            beta=beta,
            accumulate=accumulate_into_out,
            out=y,
            bias=bias,
...
@@ -589,13 +575,17 @@ class BasicLinear(BasicOperation):
        input: Optional[torch.Tensor],  # pylint: disable=redefined-builtin
        weight: Optional[torch.Tensor],
        *,
        grad_input_alpha: Optional[float] = None,
        input_requires_grad: bool = True,
        grad_weight_alpha: Optional[float] = None,
        weight_requires_grad: bool = True,
        device: Optional[torch.device] = None,  # pylint: disable=unused-argument
        dtype: Optional[torch.dtype] = None,
        grad_weight: Optional[torch.Tensor] = None,
        grad_weight_beta: Optional[float] = None,
        accumulate_into_grad_weight: bool = False,
        grad_input: Optional[torch.Tensor] = None,
        grad_input_beta: Optional[float] = None,
        accumulate_into_grad_input: bool = False,
        tensor_parallel_mode: Optional[str] = None,
        tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
...
@@ -618,8 +608,12 @@ class BasicLinear(BasicOperation):
    weight: torch.Tensor, optional
        Weight tensor. Required to compute loss gradient w.r.t.
        input.
    grad_input_alpha: float, optional
        Scaling factor applied to the result of the dgrad GEMM
    input_requires_grad: bool
        Whether to compute loss gradient w.r.t. input tensor
    grad_weight_alpha: float, optional
        Scaling factor applied to the result of the wgrad GEMM
    weight_requires_grad: bool
        Whether to compute loss gradient w.r.t. weight tensor
    device: torch.device, default = default CUDA device
...
@@ -628,10 +622,14 @@ class BasicLinear(BasicOperation):
        Tensor datatype
    grad_weight: torch.Tensor, optional
        Loss gradient w.r.t. weight tensor
    grad_weight_beta: float, optional
        Scaling factor applied to original value of grad_weight when accumulating into it
    accumulate_into_grad_weight: bool, default = `False`
        Add result to weight grad instead of overwriting
    grad_input: torch.Tensor, optional
        Loss gradient w.r.t. input tensor
    grad_input_beta: float, optional
        Scaling factor applied to original value of grad_input when accumulating into it
    accumulate_into_grad_input: bool, default = `False`
        Add result to input grad instead of overwriting
    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
...
@@ -801,11 +799,12 @@ class BasicLinear(BasicOperation):
                )
        else:
            grad_input_quantizer = None
        if isinstance(grad_input_quantizer, MXFP8Quantizer):
            raise RuntimeError(
                "Attempting to generate MXFP8 grad input tensor, "
                "but GEMM with MXFP8 output is not supported"
            )
        if grad_input_quantizer is not None:
            if not isinstance(grad_input_quantizer, Float8Quantizer):
                raise RuntimeError(
                    "Attempting to generate quantized grad input tensor "
                    "with unsupported quantizer"
                )

        # Check if accumulating into grad input tensor
        if accumulate_into_grad_input:
...
@@ -827,6 +826,8 @@ class BasicLinear(BasicOperation):
            get_workspace(),
            out_dtype=dtype,
            quantization_params=grad_input_quantizer,
            alpha=grad_input_alpha,
            beta=grad_input_beta,
            accumulate=accumulate_into_grad_input,
            layout="NN",
            out=dx,
...
@@ -877,6 +878,8 @@ class BasicLinear(BasicOperation):
            dy,
            get_workspace(),
            out_dtype=dw_dtype,
            alpha=grad_weight_alpha,
            beta=grad_weight_beta,
            accumulate=accumulate_into_grad_weight,
            layout="NT",
            out=dw,
...
@@ -894,7 +897,7 @@ class BasicLinear(BasicOperation):
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
...
@@ -903,27 +906,34 @@ class BasicLinear(BasicOperation):
        weight_requires_grad = ctx.requires_grad and self.weight.requires_grad

        # FP8 metadata
        input_quantizer = self.get_quantizer("forward", 0)
        weight_quantizer = self.get_quantizer("forward", 1)
        output_quantizer = next_op_input_quantizer
        grad_output_quantizer = self.get_quantizer("backward", 0)
        grad_input_quantizer = prev_op_grad_output_quantizer
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        input_quantizer = None
        weight_quantizer = None
        output_quantizer = None
        grad_output_quantizer = None
        grad_input_quantizer = None
        if with_quantized_compute:
            # Get quantizers
            input_quantizer = self.get_quantizer("forward", 0)
            weight_quantizer = self.get_quantizer("forward", 1)
            output_quantizer = next_op_input_quantizer
            grad_output_quantizer = self.get_quantizer("backward", 0)
            grad_input_quantizer = prev_op_grad_input_quantizer

            # Configure quantizers
            # Note: We cache the quantized input for backward pass,
            # but discard the quantized weights.
            input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
            weight_quantizer.set_usage(rowwise=True, columnwise=False)
            recipe = FP8GlobalStateManager.get_fp8_recipe()
            if recipe.float8_current_scaling():
                input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
                input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
                weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
                weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
                grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
                grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
                if self.sequence_parallel and self.tensor_parallel_mode == "column":
                    input_quantizer.with_amax_reduction = True
                    input_quantizer.amax_reduction_group = self.tensor_parallel_group
                if self.sequence_parallel and self.tensor_parallel_mode == "row":
                    grad_output_quantizer.with_amax_reduction = True
                    grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group

        # Get autocast dtype if needed
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_dtype("cuda")
...
@@ -947,15 +957,16 @@ class BasicLinear(BasicOperation):
        )

        # Save state for backward pass
        ctx.save_for_backward(x_local, w)
        ctx.with_quantized_compute = with_quantized_compute
        ctx.input_quantizer = input_quantizer
        ctx.weight_quantizer = weight_quantizer
        ctx.grad_output_quantizer = grad_output_quantizer
        ctx.grad_input_quantizer = grad_input_quantizer
        ctx.dtype = dtype
        ctx.input_requires_grad = input_requires_grad
        ctx.weight_requires_grad = weight_requires_grad
        if ctx.requires_grad:
            ctx.save_for_backward(x_local, w)
            ctx.with_quantized_compute = with_quantized_compute
            ctx.input_quantizer = input_quantizer
            ctx.weight_quantizer = weight_quantizer
            ctx.grad_output_quantizer = grad_output_quantizer
            ctx.grad_input_quantizer = grad_input_quantizer
            ctx.dtype = dtype
            ctx.input_requires_grad = input_requires_grad
            ctx.weight_requires_grad = weight_requires_grad

        return output
...
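For reference, the alpha/beta arguments added to the basic_linear GEMM paths above follow the usual GEMM accumulate convention. A minimal PyTorch sketch of that convention (illustration only; it ignores quantization, bias, and tensor parallelism and is not the Transformer Engine kernel):

    import torch

    def gemm_like(x, w, out=None, alpha=1.0, beta=None, accumulate_into_out=False):
        # y = alpha * (x @ w.T); optionally accumulated as out := beta * out + y
        y = alpha * (x @ w.t())
        if out is not None and accumulate_into_out:
            scale = 1.0 if beta is None else beta
            out.mul_(scale).add_(y)
            return out
        return y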
transformer_engine/pytorch/ops/basic/bias.py

@@ -10,15 +10,8 @@ from typing import Optional
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    OperationContext,
)
from ...utils import (
    canonicalize_device,
    canonicalize_dtype,
)
from ...fp8 import FP8GlobalStateManager
from ..op import BasicOperation, OperationContext
from ...utils import canonicalize_device, canonicalize_dtype
from ...tensor import Quantizer
...
@@ -114,8 +107,8 @@ class Bias(BasicOperation):
            bias = torch.nn.Parameter(bias)
        self.bias = bias

    def pre_first_forward(self, *args, **kwargs) -> None:
        super().pre_first_forward(*args, **kwargs)
    def pre_first_fuser_forward(self) -> None:
        super().pre_first_fuser_forward()
        if self.bias.device.type == "meta":
            self.reset_parameters()
...
@@ -123,24 +116,14 @@ class Bias(BasicOperation):
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        x = input_
        b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])

        # Check if backward pass is needed
        requires_grad = ctx.requires_grad

        # Check if previous op quantizes its output's gradient
        grad_input_quantizer = None
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        if with_quantized_compute:
            grad_input_quantizer = prev_op_grad_input_quantizer
        if requires_grad:
            ctx.with_quantized_compute = with_quantized_compute
            ctx.grad_input_quantizer = grad_input_quantizer
        if ctx.requires_grad:
            ctx.grad_input_quantizer = prev_op_grad_output_quantizer

        return x + b
...
@@ -152,10 +135,10 @@ class Bias(BasicOperation):
        dy = grad_output
        if dy.dim() > 1:
            quantizer = ctx.grad_input_quantizer
            if ctx.with_quantized_compute and quantizer is not None:
                db, dy = tex.bgrad_quantize(dy, quantizer)
            else:
            if quantizer is None:
                db = dy.sum(tuple(range(dy.dim() - 1)))
            else:
                db, dy = tex.bgrad_quantize(dy, quantizer)
        else:
            db = dy
        return dy, (db,)
transformer_engine/pytorch/ops/basic/constant_scale.py (new file)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for constant scaling."""

from __future__ import annotations
from typing import Optional

import torch

from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    OperationContext,
)
from ...tensor import Quantizer


class ConstantScale(BasicOperation):
    """Multiply by a constant"""

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        return input_ * self.scale

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        return grad_output * self.scale, ()
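For orientation, the new ConstantScale op multiplies by the same constant in both the forward and backward directions. A standalone stand-in using torch.autograd (not the Transformer Engine operation or its fuser), useful for checking that intuition:

    import torch

    class _ConstantScaleFn(torch.autograd.Function):
        # Forward multiplies by `scale`; backward scales the incoming gradient identically.
        @staticmethod
        def forward(ctx, x, scale):
            ctx.scale = scale
            return x * scale

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output * ctx.scale, None

    x = torch.randn(4, 8, requires_grad=True)
    y = _ConstantScaleFn.apply(x, 0.5)
    y.sum().backward()
    assert torch.allclose(x.grad, torch.full_like(x, 0.5))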
transformer_engine/pytorch/ops/basic/dropout.py (new file)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for dropout."""

from __future__ import annotations
from typing import Optional

import torch

from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    OperationContext,
)
from ...tensor import Quantizer


class Dropout(BasicOperation):
    """Randomly zero out tensor entries during training

    During training, tensor entries are randomly set to zero with
    probability :math:`p` and remaining entries are scaled by
    :math:`1/(1-p)`.

    """

    def __init__(self, p: float) -> None:
        super().__init__()
        self.dropout_probability = p

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:

        # Compute dropout if training
        out = input_
        is_training = self.training
        mask = None
        if is_training:
            keep_prob = 1 - self.dropout_probability
            mask = torch.empty_like(input_)
            mask.bernoulli_(keep_prob)
            mask *= 1 / keep_prob
            out = out * mask

        # Save context for backward
        if ctx.requires_grad:
            ctx.save_for_backward(mask)
            ctx.is_training = is_training

        return out

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        (mask,) = ctx.saved_tensors
        grad_input = grad_output
        if ctx.is_training:
            grad_input = grad_input * mask
        return grad_input, ()
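The masking math above preserves the expected value of the activations: each entry is kept with probability 1 - p and the mask itself is pre-scaled by 1/(1 - p), so the very same mask can be reused for the gradient in the backward pass. A plain-PyTorch sketch of that math (illustration only):

    import torch

    def dropout_mask(x: torch.Tensor, p: float):
        # Keep each element with probability (1 - p) and rescale so E[out] == E[x]
        keep_prob = 1.0 - p
        mask = torch.empty_like(x).bernoulli_(keep_prob)
        mask *= 1.0 / keep_prob
        return x * mask, mask

    out, mask = dropout_mask(torch.randn(16, 32), p=0.1)
    # The backward pass reuses the saved mask: grad_input = grad_output * mask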
transformer_engine/pytorch/ops/basic/identity.py

@@ -23,7 +23,7 @@ class Identity(BasicOperation):
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        return input_
...
transformer_engine/pytorch/ops/basic/l2normalization.py

@@ -6,10 +6,12 @@
from __future__ import annotations
from typing import Optional
import os

import torch

from ...utils import clear_tensor_data
from ... import torch_version
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...jit import (
...
@@ -60,7 +62,11 @@ class L2Normalization(BasicOperation):
        # JIT warmup for L2Normalization fused operations
        if seq_length and micro_batch_size:
            if torch.cuda.is_available():
            if (
                torch.cuda.is_available()
                and torch_version() >= (2, 0, 0)
                and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1")))
            ):
                set_jit_fusion_options()
                # For L2Normalization, we don't know the hidden size until forward pass,
                # but we can warm up with common sizes. For QK normalization, this will be
...
@@ -74,7 +80,7 @@ class L2Normalization(BasicOperation):
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        # Use input directly - torch.compile can handle multi-dimensional tensors
...
@@ -86,7 +92,7 @@ class L2Normalization(BasicOperation):
        # Compute L2 normalization using fused implementation
        # L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps)
        if requires_grad:
            # Training: use version that returns both output and intermediate values
            # Training: use version that returns output and intermediate values for backward pass
            y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps)
        else:
            # Inference: use lightweight version that only returns output
...
@@ -110,7 +116,7 @@ class L2Normalization(BasicOperation):
        dy = maybe_dequantize(grad_output)

        # Compute L2 norm backward pass using fused implementation
        # Compute L2 norm backward pass using fused implementation - recalculates l2_norm_squared_eps
        dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)

        # Clear saved tensors if possible
...
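For reference, the fused kernels referred to above implement the normalization written in the comment, y = x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps). An unfused equivalent in plain PyTorch, normalizing along the innermost dimension in this sketch (illustration only, not the fused path):

    import torch

    def l2_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # rsqrt_norm is the intermediate the fused version saves for the backward pass
        rsqrt_norm = torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)
        return x * rsqrt_norm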
transformer_engine/pytorch/ops/basic/layer_norm.py

@@ -13,7 +13,6 @@ from typing import Optional
import torch

from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...constants import TE_DType
from ...utils import (
    canonicalize_device,
...
@@ -168,8 +167,8 @@ class LayerNorm(BasicOperation):
        self.weight = weight
        self.bias = bias

    def pre_first_forward(self, *args, **kwargs) -> None:
        super().pre_first_forward(*args, **kwargs)
    def pre_first_fuser_forward(self) -> None:
        super().pre_first_fuser_forward()
        if self.weight.device.type == "meta" or self.bias.device.type == "meta":
            self.reset_parameters()
...
@@ -177,7 +176,7 @@ class LayerNorm(BasicOperation):
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        if is_in_onnx_export_mode():
...
@@ -200,31 +199,22 @@ class LayerNorm(BasicOperation):
        w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
        b = maybe_dequantize(self.bias, dtype).view((inner_dim,))

        # Check if backward pass is needed
        requires_grad = ctx.requires_grad

        # Check if output is quantized
        output_quantizer = None
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        if with_quantized_compute:
            output_quantizer = next_op_input_quantizer

        # Compute layer norm
        sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
        sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
        y, means, rstdevs = layernorm_fwd(
            x,
            w,
            b,
            self.eps,
            None,
            output_quantizer,
            next_op_input_quantizer,
            TE_DType[dtype],
            sm_margin,
            self.zero_centered_gamma,
        )

        # Save state for backward pass
        if requires_grad:
        if ctx.requires_grad:
            ctx.save_for_backward(x, means, rstdevs)
            ctx.dtype = dtype
...
transformer_engine/pytorch/ops/basic/make_extra_output.py

@@ -22,14 +22,20 @@ class MakeExtraOutput(BasicOperation):
    If this operation is included in the operation fuser, then the
    operation fuser will return the intermediate tensor as an extra
    tensor output. In the backward pass, the gradient is directly
    accumulated into the gradient w.r.t. the extra output.
    tensor output.

    This operation is considered an advanced feature and most users
    are discouraged from using it. In-place operations break some
    autograd assumptions and they can result in subtle, esoteric bugs.

    In the backward pass, the gradient may be directly
    accumulated into the gradient w.r.t. the extra output. This is
    controlled by the in_place kwarg. Currently, the BackwardLinearAdd
    fusion is able to happen only with in_place=True.

    Compare to `AddInPlace`, which does a similar operation in the
    Using this operation with in_place=True is
    considered an advanced feature. Most users are discouraged
    from enabling in-place gradient accumulation, as in-place
    operations break some autograd assumptions and they can result
    in subtle, esoteric bugs.

    Compare to `AddExtraInput`, which does a similar operation in the
    backward pass.

    """
...
@@ -37,6 +43,10 @@ class MakeExtraOutput(BasicOperation):
    # Operation expects buffer for output tensor
    num_extra_outputs: int = 1

    def __init__(self, *, in_place: bool = False):
        super().__init__()
        self._in_place: bool = in_place

    def op_forward(self, *args, **kwargs) -> None:
        raise RuntimeError(
            "{self.__class__.__name__} operation has "
...
@@ -59,7 +69,7 @@ class MakeExtraOutput(BasicOperation):
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...
@@ -76,6 +86,10 @@ class MakeExtraOutput(BasicOperation):
        Iterable[Iterable[Optional[torch.Tensor]]],
        Iterable[Iterable[Optional[torch.Tensor]]],
    ]:
        grad_input = basic_op_grad_extra_outputs[0][0]
        grad_input += grad_output
        grad_extra_output = basic_op_grad_extra_outputs[0][0]
        if self._in_place:
            grad_extra_output += grad_output
            grad_input = grad_extra_output
        else:
            grad_input = grad_extra_output + grad_output
        return grad_input, [()], [()]
transformer_engine/pytorch/ops/basic/quantize.py

@@ -50,7 +50,7 @@ class Quantize(BasicOperation):
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
...
@@ -64,7 +64,8 @@ class Quantize(BasicOperation):
        if quantize_forward and not is_quantized_tensor(out):
            out = self.get_quantizer("forward", 0)(out)
        ctx.quantize_backward = quantize_backward
        if ctx.requires_grad:
            ctx.quantize_backward = quantize_backward
        return out

    def op_backward(
...
transformer_engine/pytorch/ops/basic/reduce_scatter.py

@@ -40,7 +40,7 @@ class ReduceScatter(BasicOperation):
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
...
transformer_engine/pytorch/ops/basic/reshape.py

@@ -38,10 +38,11 @@ class Reshape(BasicOperation):
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        ctx.input_shape = input_.size()
        if ctx.requires_grad:
            ctx.input_shape = input_.size()
        return input_.reshape(*self._shape)

    def op_backward(
...
transformer_engine/pytorch/ops/basic/rmsnorm.py

@@ -13,7 +13,6 @@ from typing import Optional
import torch

from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...constants import TE_DType
from ...utils import (
    canonicalize_device,
...
@@ -151,8 +150,8 @@ class RMSNorm(BasicOperation):
            weight = torch.nn.Parameter(weight)
        self.weight = weight

    def pre_first_forward(self, *args, **kwargs) -> None:
        super().pre_first_forward(*args, **kwargs)
    def pre_first_fuser_forward(self) -> None:
        super().pre_first_fuser_forward()
        if self.weight.device.type == "meta":
            self.reset_parameters()
...
@@ -160,7 +159,7 @@ class RMSNorm(BasicOperation):
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        if is_in_onnx_export_mode():
...
@@ -182,30 +181,21 @@ class RMSNorm(BasicOperation):
        x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim))
        w = maybe_dequantize(self.weight, dtype).view((inner_dim,))

        # Check if backward pass is needed
        requires_grad = ctx.requires_grad

        # Check if output is quantized
        output_quantizer = None
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        if with_quantized_compute:
            output_quantizer = next_op_input_quantizer

        # Compute RMSNorm
        sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
        sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
        y, _, rstdevs = rmsnorm_fwd(
            x,
            w,
            self.eps,
            None,
            output_quantizer,
            next_op_input_quantizer,
            TE_DType[dtype],
            sm_margin,
            self.zero_centered_gamma,
        )

        # Save state for backward pass
        if requires_grad:
        if ctx.requires_grad:
            ctx.save_for_backward(x, rstdevs)
            ctx.dtype = dtype
...
transformer_engine/pytorch/ops/fused/__init__.py

@@ -4,14 +4,18 @@
"""Compound tensor operation supported by the operation fuser."""

from .backward_bias_activation import (
    BackwardBiasActivation,
    fuse_backward_bias_activation,
from .backward_activation_bias import (
    BackwardActivationBias,
    fuse_backward_activation_bias,
)
from .backward_linear_add import (
    BackwardLinearAdd,
    fuse_backward_linear_add,
)
from .backward_linear_scale import (
    BackwardLinearScale,
    fuse_backward_linear_scale,
)
from .forward_linear_bias_activation import (
    ForwardLinearBiasActivation,
    fuse_forward_linear_bias_activation,
...
@@ -20,6 +24,10 @@ from .forward_linear_bias_add import (
    ForwardLinearBiasAdd,
    fuse_forward_linear_bias_add,
)
from .forward_linear_scale_add import (
    ForwardLinearScaleAdd,
    fuse_forward_linear_scale_add,
)
from .userbuffers_backward_linear import (
    UserbuffersBackwardLinear,
    fuse_userbuffers_backward_linear,
...
transformer_engine/pytorch/ops/fused/backward_bias_activation.py → transformer_engine/pytorch/ops/fused/backward_activation_bias.py

@@ -2,7 +2,7 @@
#
# See LICENSE for license information.

"""Fused backward dbias + dact + quantize."""
"""Fused backward dact + dbias + quantize."""

from __future__ import annotations
from typing import Optional
...
@@ -29,8 +29,8 @@ _fused_activations = {GELU: tex.dbias_dgelu, ReLU: tex.dbias_drelu}
_fusible_activations = tuple(_fused_activations.keys())


class BackwardBiasActivation(FusedOperation):
    """Fused backward dbias + dact + quantize
class BackwardActivationBias(FusedOperation):
    """Fused backward dact + dbias + quantize

    Uses the next operation's input quantizer.
...
@@ -66,15 +66,10 @@ class BackwardBiasActivation(FusedOperation):
        dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype)

        # Get previous op quantizer
        if not bias_op_ctx.with_quantized_compute:
            raise RuntimeError(
                "BackwardBiasActivation requires quantized compute, "
                "but Bias context has it disabled"
            )
        quantizer = bias_op_ctx.grad_input_quantizer
        if quantizer is None:
            raise RuntimeError(
                "BackwardBiasActivation requires previous op's grad output quantizer, "
                "BackwardActivationBias requires previous op's grad output quantizer, "
                "but Bias context has no quantizer"
            )
...
@@ -87,11 +82,11 @@ class BackwardBiasActivation(FusedOperation):
        return dx, [(), (db,)], [(), ()]


def fuse_backward_bias_activation(
def fuse_backward_activation_bias(
    ops: list[tuple[FusibleOperation, list[int]]],
    recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]:
    """Fused backward dbias + dact + quantize
    """Fused backward dact + dbias + quantize

    Parameters
    ----------
...
@@ -109,7 +104,7 @@ def fuse_backward_bias_activation(
    """

    # Check if recipe supports bias activation fusion
    if recipe is None or not (recipe.delayed() or recipe.mxfp8()):
    if recipe is None:
        return ops

    # Scan through ops, fusing if possible
...
@@ -138,7 +133,7 @@ def fuse_backward_bias_activation(
        ops = ops[1:]

        # Replace window with fused op
        op = BackwardBiasActivation(
        op = BackwardActivationBias(
            activation=window[0][0],
            bias=window[1][0],
        )
...
transformer_engine/pytorch/ops/fused/backward_linear_add.py

@@ -29,10 +29,10 @@ class BackwardLinearAdd(FusedOperation):
    def __init__(
        self,
        *,
        linear: BasicLinear,
        backward_add: MakeExtraOutput,
        linear: BasicLinear,
    ) -> None:
        super().__init__((linear, backward_add))
        super().__init__((backward_add, linear))

    def fuser_backward(
        self,
...
@@ -47,7 +47,7 @@ class BackwardLinearAdd(FusedOperation):
    ]:

        # Get basic operations
        linear_op = self.basic_ops[0]
        linear_op = self.basic_ops[1]
        linear_op_ctx = basic_op_ctxs[0]

        # Saved tensors from forward pass
...
@@ -139,6 +139,8 @@ def fuse_backward_linear_add(
        op, _ = ops[0]
        if not isinstance(op, MakeExtraOutput):
            continue
        if not op._in_place:
            continue
        window.extend(ops[:1])
        ops = ops[1:]
...
transformer_engine/pytorch/ops/fused/backward_linear_scale.py (new file)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fused backward dgrad GEMM + scale."""

from __future__ import annotations
from typing import Optional

import torch

from ..basic import BasicLinear, ConstantScale
from ..op import (
    FusedOperation,
    FusibleOperation,
    OperationContext,
)
from ...utils import clear_tensor_data


class BackwardLinearScale(FusedOperation):
    """Fused backward dgrad GEMM + scale

    Column tensor parallelism is not supported since that requires
    communication immediately after the dgrad GEMM.

    """

    def __init__(
        self,
        *,
        scale: ConstantScale,
        linear: BasicLinear,
    ) -> None:
        super().__init__((linear, scale))

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[tuple[Optional[torch.Tensor], ...]],
        list[tuple[()]],
    ]:

        # Get basic operations
        linear_op = self.basic_ops[0]
        linear_op_ctx = basic_op_ctxs[1]
        scale_op = self.basic_ops[1]

        # Saved tensors from forward pass
        (x_local, w) = linear_op_ctx.saved_tensors

        # wgrad fusion
        accumulate_into_main_grad = linear_op._accumulate_into_main_grad
        grad_weight = None
        if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
            if hasattr(linear_op.weight, "__fsdp_param__"):
                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
            if not hasattr(linear_op.weight, "main_grad"):
                raise RuntimeError(
                    "BasicLinear op is configured with "
                    "accumulate_into_main_grad=True, "
                    "but weight parameter does not have main_grad attribute"
                )
            grad_weight = linear_op.weight.main_grad.detach()
        else:
            accumulate_into_main_grad = False

        # Linear backward pass
        grad_input, grad_weight = BasicLinear._functional_backward(
            grad_output=grad_output,
            input=x_local,
            weight=w,
            input_requires_grad=linear_op_ctx.input_requires_grad,
            grad_input_alpha=scale_op.scale,
            weight_requires_grad=linear_op_ctx.weight_requires_grad,
            grad_weight_alpha=scale_op.scale,
            dtype=linear_op_ctx.dtype,
            grad_weight=grad_weight,
            accumulate_into_grad_weight=accumulate_into_main_grad,
            tensor_parallel_mode=linear_op.tensor_parallel_mode,
            tensor_parallel_group=linear_op.tensor_parallel_group,
            sequence_parallel=linear_op.sequence_parallel,
            with_quantized_compute=linear_op_ctx.with_quantized_compute,
            input_quantizer=linear_op_ctx.input_quantizer,
            weight_quantizer=linear_op_ctx.weight_quantizer,
            grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
            grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
        )
        if accumulate_into_main_grad:
            grad_weight = None

        # Clear input tensor if possible
        clear_tensor_data(x_local)

        return grad_input, [(), (grad_weight,)], [(), ()]


def fuse_backward_linear_scale(
    ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
    """Fused backward dgrad GEMM + constant scale

    Parameters
    ----------
    ops: list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.

    Returns
    -------
    ops: list of tuples
        Updated backward pass operations

    """

    # Scan through ops, fusing if possible
    out = []
    window = []
    while len(ops) >= 2:
        out.extend(window)

        # Check if first op is constant scale
        window, ops = ops[:1], ops[1:]
        op, _ = window[0]
        if not isinstance(op, ConstantScale):
            continue

        # Check if second op is linear
        op, _ = ops[0]
        if not isinstance(op, BasicLinear):
            continue
        if op.tensor_parallel_mode == "column":
            # Column tensor-parallelism requires communication after the dgrad GEMM
            continue
        window.extend(ops[:1])
        ops = ops[1:]

        # Replace window with fused op
        op = BackwardLinearScale(
            scale=window[0][0],
            linear=window[1][0],
        )
        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
        window = [(op, basic_op_idxs)]

    # Return list of ops
    out.extend(window)
    out.extend(ops)
    return out
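The fusion above folds the constant scale into the two backward GEMMs through the grad_input_alpha/grad_weight_alpha arguments added to _functional_backward: for a linear followed by a constant scale, y = s * (x @ W.T), the gradients are dL/dx = s * (dy @ W) and dL/dW = s * (dy.T @ x), so no separate elementwise multiply is needed. A small unfused sanity sketch in plain PyTorch (illustration only, not the fused path):

    import torch

    x = torch.randn(8, 16, requires_grad=True)
    w = torch.randn(32, 16, requires_grad=True)
    s = 0.5

    y = s * (x @ w.t())      # BasicLinear followed by ConstantScale
    y.sum().backward()
    dy = torch.ones(8, 32)

    assert torch.allclose(x.grad, s * (dy @ w))      # dgrad GEMM scaled by alpha
    assert torch.allclose(w.grad, s * (dy.t() @ x))  # wgrad GEMM scaled by alpha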
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py

@@ -59,7 +59,7 @@ class ForwardLinearBiasActivation(FusedOperation):
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...
@@ -89,18 +89,12 @@ class ForwardLinearBiasActivation(FusedOperation):
        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

        # FP8 metadata
        input_quantizer = linear_op.get_quantizer("forward", 0)
        weight_quantizer = linear_op.get_quantizer("forward", 1)
        output_quantizer = next_op_input_quantizer
        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
        grad_input_quantizer = prev_op_grad_output_quantizer
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        input_quantizer = None
        weight_quantizer = None
        output_quantizer = None
        grad_output_quantizer = None
        grad_input_quantizer = None
        if with_quantized_compute:
            input_quantizer = linear_op.get_quantizer("forward", 0)
            weight_quantizer = linear_op.get_quantizer("forward", 1)
            output_quantizer = next_op_input_quantizer
            grad_output_quantizer = linear_op.get_quantizer("backward", 0)
            grad_input_quantizer = prev_op_grad_input_quantizer

        # Get autocast dtype if needed
        if torch.is_autocast_enabled():
...
@@ -126,18 +120,18 @@ class ForwardLinearBiasActivation(FusedOperation):
        )

        # Save state for backward pass
        linear_op_ctx.save_for_backward(x_local, w)
        linear_op_ctx.with_quantized_compute = with_quantized_compute
        linear_op_ctx.input_quantizer = input_quantizer
        linear_op_ctx.weight_quantizer = weight_quantizer
        linear_op_ctx.grad_output_quantizer = grad_output_quantizer
        linear_op_ctx.grad_input_quantizer = grad_input_quantizer
        linear_op_ctx.dtype = dtype
        linear_op_ctx.input_requires_grad = input_requires_grad
        linear_op_ctx.weight_requires_grad = weight_requires_grad
        if bias_op is not None:
            bias_op_ctx.with_quantized_compute = with_quantized_compute
            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
        if linear_op_ctx.requires_grad:
            linear_op_ctx.save_for_backward(x_local, w)
            linear_op_ctx.with_quantized_compute = with_quantized_compute
            linear_op_ctx.input_quantizer = input_quantizer
            linear_op_ctx.weight_quantizer = weight_quantizer
            linear_op_ctx.grad_output_quantizer = grad_output_quantizer
            linear_op_ctx.grad_input_quantizer = grad_input_quantizer
            linear_op_ctx.dtype = dtype
            linear_op_ctx.input_requires_grad = input_requires_grad
            linear_op_ctx.weight_requires_grad = weight_requires_grad
        if bias_op is not None and bias_op_ctx.requires_grad:
            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()

        return output, [() for _ in range(len(self.basic_ops))]
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py

@@ -11,7 +11,7 @@ from typing import Any, Optional
import torch

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias
from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
    FusedOperation,
    FusibleOperation,
...
@@ -33,7 +33,7 @@ class ForwardLinearBiasAdd(FusedOperation):
        *,
        linear: BasicLinear,
        bias: Optional[Bias],
        add: AddInPlace,
        add: AddExtraInput,
    ) -> None:

        # Basic operations that comprise this fused operation
...
@@ -57,7 +57,7 @@ class ForwardLinearBiasAdd(FusedOperation):
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_input_quantizer: Optional[Quantizer],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...
@@ -83,17 +83,12 @@ class ForwardLinearBiasAdd(FusedOperation):
        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

        # FP8 metadata
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
        input_quantizer = None
        weight_quantizer = None
        input_quantizer = linear_op.get_quantizer("forward", 0)
        weight_quantizer = linear_op.get_quantizer("forward", 1)
        output_quantizer = None
        grad_output_quantizer = None
        grad_input_quantizer = None
        if with_quantized_compute:
            input_quantizer = linear_op.get_quantizer("forward", 0)
            weight_quantizer = linear_op.get_quantizer("forward", 1)
            grad_output_quantizer = linear_op.get_quantizer("backward", 0)
            grad_input_quantizer = prev_op_grad_input_quantizer
        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
        grad_input_quantizer = prev_op_grad_output_quantizer
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()

        # Get autocast dtype if needed
        if torch.is_autocast_enabled():
...
@@ -122,18 +117,18 @@ class ForwardLinearBiasAdd(FusedOperation):
        )

        # Save state for backward pass
        linear_op_ctx.save_for_backward(x_local, w)
        linear_op_ctx.with_quantized_compute = with_quantized_compute
        linear_op_ctx.input_quantizer = input_quantizer
        linear_op_ctx.weight_quantizer = weight_quantizer
        linear_op_ctx.grad_output_quantizer = grad_output_quantizer
        linear_op_ctx.grad_input_quantizer = grad_input_quantizer
        linear_op_ctx.dtype = dtype
        linear_op_ctx.input_requires_grad = input_requires_grad
        linear_op_ctx.weight_requires_grad = weight_requires_grad
        if bias_op is not None:
            bias_op_ctx.with_quantized_compute = with_quantized_compute
            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer()
        if linear_op_ctx.requires_grad:
            linear_op_ctx.save_for_backward(x_local, w)
            linear_op_ctx.with_quantized_compute = with_quantized_compute
            linear_op_ctx.input_quantizer = input_quantizer
            linear_op_ctx.weight_quantizer = weight_quantizer
            linear_op_ctx.grad_output_quantizer = grad_output_quantizer
            linear_op_ctx.grad_input_quantizer = grad_input_quantizer
            linear_op_ctx.dtype = dtype
            linear_op_ctx.input_requires_grad = input_requires_grad
            linear_op_ctx.weight_requires_grad = weight_requires_grad
        if bias_op is not None and bias_op_ctx.requires_grad:
            bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer()

        return output, [() for _ in range(len(self.basic_ops))]
...
@@ -184,8 +179,10 @@ def fuse_forward_linear_bias_add(
            continue
        op, _ = ops[0]

        # Check if next op is add in-place
        if not isinstance(op, AddInPlace):
        # Check if next op is in-place add extra input
        if not isinstance(op, AddExtraInput):
            continue
        if not op._in_place:
            continue
        add = op
        window.extend(ops[:1])
...
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py (new file)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fused operation for forward GEMM + scale + add."""

from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional

import torch

from ...fp8 import FP8GlobalStateManager
from ..basic import AddExtraInput, BasicLinear, ConstantScale
from ..op import (
    FusedOperation,
    FusibleOperation,
    OperationContext,
)
from ...tensor import Quantizer


class ForwardLinearScaleAdd(FusedOperation):
    """Fused forward GEMM + scale + add

    Row tensor parallelism is not supported since that requires
    communication immediately after the GEMM.

    """

    def __init__(
        self,
        *,
        linear: BasicLinear,
        scale: ConstantScale,
        add: AddExtraInput,
    ) -> None:
        super().__init__((linear, scale, add))

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        *,
        basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
        basic_op_kwargs: list[dict[str, Any]],
    ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:

        # Get basic operations
        linear_op = self.basic_ops[0]
        linear_op_ctx = basic_op_ctxs[0]
        scale_op = self.basic_ops[1]

        # Check which grads are required
        input_requires_grad = linear_op_ctx.requires_grad
        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

        # FP8 metadata
        input_quantizer = linear_op.get_quantizer("forward", 0)
        weight_quantizer = linear_op.get_quantizer("forward", 1)
        output_quantizer = None
        grad_output_quantizer = linear_op.get_quantizer("backward", 0)
        grad_input_quantizer = prev_op_grad_output_quantizer
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()

        # Get extra input tensor for add operation
        extra_input = basic_op_extra_inputs[2][0]

        # Get autocast dtype if needed
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_dtype("cuda")
        else:
            dtype = linear_op.weight.dtype

        # Linear forward
        output, x_local, w = BasicLinear._functional_forward(
            input=input_,
            weight=linear_op.weight,
            alpha=scale_op.scale,
            dtype=dtype,
            out=extra_input,
            accumulate_into_out=True,
            tensor_parallel_mode=linear_op.tensor_parallel_mode,
            tensor_parallel_group=linear_op.tensor_parallel_group,
            sequence_parallel=linear_op.sequence_parallel,
            with_quantized_compute=with_quantized_compute,
            input_quantizer=input_quantizer,
            weight_quantizer=weight_quantizer,
            output_quantizer=output_quantizer,
            input_requires_grad=input_requires_grad,
            weight_requires_grad=weight_requires_grad,
        )

        # Save state for backward pass
        if linear_op_ctx.requires_grad:
            linear_op_ctx.save_for_backward(x_local, w)
            linear_op_ctx.with_quantized_compute = with_quantized_compute
            linear_op_ctx.input_quantizer = input_quantizer
            linear_op_ctx.weight_quantizer = weight_quantizer
            linear_op_ctx.grad_output_quantizer = grad_output_quantizer
            linear_op_ctx.grad_input_quantizer = grad_input_quantizer
            linear_op_ctx.dtype = dtype
            linear_op_ctx.input_requires_grad = input_requires_grad
            linear_op_ctx.weight_requires_grad = weight_requires_grad

        return output, [() for _ in range(len(self.basic_ops))]


def fuse_forward_linear_scale_add(
    ops: list[tuple[FusibleOperation, list[int]]],
) -> list[tuple[FusibleOperation, list[int]]]:
    """Fuse forward GEMM + scale + add

    Parameters
    ----------
    ops: list of tuples
        Forward pass operations and the indices of the corresponding
        basic operations.

    Returns
    -------
    ops: list of tuples
        Updated forward pass operations

    """

    # Scan through ops, fusing if possible
    out = []
    window = []
    while len(ops) >= 3:
        out.extend(window)

        # Check if first op is linear
        window, ops = ops[:1], ops[1:]
        op, _ = window[0]
        if not isinstance(op, BasicLinear):
            continue
        if op.tensor_parallel_mode == "row":
            # Row tensor-parallelism requires communication after the
            # GEMM
            continue
        linear = op
        op, _ = ops[0]

        # Check if next op is constant scale
        if not isinstance(op, ConstantScale):
            continue
        scale = op
        window.extend(ops[:1])
        ops = ops[1:]
        op, _ = ops[0]

        # Check if next op is in-place add extra input
        if not isinstance(op, AddExtraInput):
            continue
        if not op._in_place:
            continue
        add = op
        window.extend(ops[:1])
        ops = ops[1:]

        # Replace window with fused op
        op = ForwardLinearScaleAdd(
            linear=linear,
            scale=scale,
            add=add,
        )
        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
        window = [(op, basic_op_idxs)]

    # Return list of ops
    out.extend(window)
    out.extend(ops)
    return out
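The forward pattern this fusion targets is a linear op whose output is scaled by a constant and then accumulated in place into an extra input tensor, i.e. extra += alpha * (x @ W.T); folding the scale and the add into a single accumulating GEMM is exactly what the alpha=scale_op.scale, out=extra_input, and accumulate_into_out=True arguments above express. An unfused reference of the same arithmetic in plain PyTorch (illustration only):

    import torch

    x = torch.randn(8, 16)
    w = torch.randn(32, 16)
    extra = torch.randn(8, 32)
    alpha = 0.5

    fused_style = extra + alpha * (x @ w.t())   # what the single accumulating GEMM produces
    unfused = (x @ w.t()) * alpha + extra       # BasicLinear -> ConstantScale -> AddExtraInput
    assert torch.allclose(fused_style, unfused)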