OpenDAS / TransformerEngine / Commits

Commit 44740c6c, authored Jul 22, 2025 by yuguo

    Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine

Parents: 8113d9e0, 7a9a0825
Changes: 162
Showing 20 changed files with 1007 additions and 382 deletions (+1007, -382)
transformer_engine/pytorch/module/layernorm_mlp.py                +107  -9
transformer_engine/pytorch/module/linear.py                       +155  -37
transformer_engine/pytorch/onnx_extensions.py                     +362  -0
transformer_engine/pytorch/ops/_common.py                         +17   -77
transformer_engine/pytorch/ops/basic/activation.py                +19   -63
transformer_engine/pytorch/ops/basic/add_in_place.py              +5    -3
transformer_engine/pytorch/ops/basic/all_gather.py                +5    -7
transformer_engine/pytorch/ops/basic/all_reduce.py                +5    -7
transformer_engine/pytorch/ops/basic/basic_linear.py              +61   -65
transformer_engine/pytorch/ops/basic/bias.py                      +28   -7
transformer_engine/pytorch/ops/basic/identity.py                  +3    -2
transformer_engine/pytorch/ops/basic/l2normalization.py           +7    -14
transformer_engine/pytorch/ops/basic/layer_norm.py                +32   -39
transformer_engine/pytorch/ops/basic/make_extra_output.py         +4    -3
transformer_engine/pytorch/ops/basic/quantize.py                  +6    -5
transformer_engine/pytorch/ops/basic/reduce_scatter.py            +5    -7
transformer_engine/pytorch/ops/basic/reshape.py                   +3    -2
transformer_engine/pytorch/ops/basic/rmsnorm.py                   +28   -35
transformer_engine/pytorch/ops/fused/__init__.py                  +4    -0
transformer_engine/pytorch/ops/fused/backward_bias_activation.py  +151  -0
transformer_engine/pytorch/module/layernorm_mlp.py

@@ -78,6 +78,7 @@ from ..tensor.quantized_tensor import (
 from ..cpp_extensions import (
     general_gemm,
 )
+from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ...debug.pytorch.utils import any_feature_enabled
 from ...debug.pytorch.debug_state import TEDebugState

@@ -86,16 +87,16 @@ __all__ = ["LayerNormMLP"]
 def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
     if recipe is None:
-        # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
+        # bf16 (recipe is None):
         return {
-            "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
-            "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
+            "gelu": (tex.gelu, tex.dgelu, None),
+            "relu": (tex.relu, tex.drelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
             "reglu": (tex.reglu, tex.dreglu, None),
             "swiglu": (tex.swiglu, tex.dswiglu, None),
-            "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
+            "qgelu": (tex.qgelu, tex.dqgelu, None),
             "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
-            "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
+            "srelu": (tex.srelu, tex.dsrelu, None),
         }
     if recipe.delayed() or recipe.mxfp8():
         # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]

@@ -553,8 +554,20 @@ class _LayerNormMLP(torch.autograd.Function):
         )
         if fuse_wgrad_accumulation:
-            ctx.fc1_main_grad = fc1_weight.main_grad if fc1_weight.requires_grad else None
-            ctx.fc2_main_grad = fc2_weight.main_grad if fc2_weight.requires_grad else None
+            # This check is needed to ensure that main_grad is not created
+            # during the forward pass when using MCore FSDP as it creates
+            # the main_grad buffer lazily before backprop
+            if hasattr(fc1_weight, "__fsdp_param__") and hasattr(fc2_weight, "__fsdp_param__"):
+                # MCore FSDP creates main_grad lazily before backward
+                ctx.fc1_main_grad_func = (
+                    fc1_weight.get_main_grad if fc1_weight.requires_grad else lambda: None
+                )
+                ctx.fc2_main_grad_func = (
+                    fc2_weight.get_main_grad if fc2_weight.requires_grad else lambda: None
+                )
+            else:
+                ctx.fc1_main_grad_func = lambda: fc1_weight.main_grad
+                ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad
         ctx.save_for_backward(*tensors_to_save)
         ctx.tensor_objects = tensor_objects

@@ -654,14 +667,14 @@ class _LayerNormMLP(torch.autograd.Function):
         # Since main_grad can be modified inplace, it should not be a part of saved_tensors
         fc1_weight_main_grad = (
-            ctx.fc1_main_grad
+            ctx.fc1_main_grad_func()
             if fc1_weight is not None
             and ctx.fuse_wgrad_accumulation
             and ctx.fc1_weight_requires_grad
             else None
         )
         fc2_weight_main_grad = (
-            ctx.fc2_main_grad
+            ctx.fc2_main_grad_func()
             if origin_fc2_weight is not None
             and ctx.fuse_wgrad_accumulation
             and ctx.fc2_weight_requires_grad

@@ -1727,6 +1740,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             first microbatch (since it is the first gradient being
             produced)
         """
+        if is_in_onnx_export_mode():
+            return self.onnx_forward(inp)
         debug = TEDebugState.debug_enabled
         if debug:
             self._validate_name()

@@ -1917,6 +1932,89 @@ class LayerNormMLP(TransformerEngineBaseModule):
             fc2_grad_output_quantizer,
         )

+    def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
+        """
+        ONNX-compatible version of the forward function that provides numerical equivalence
+        while only using operations that have defined ONNX symbolic translations.
+        This simplified implementation is designed specifically for inference scenarios.
+        """
+        from ..export import onnx_layernorm, onnx_gemm
+
+        assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
+        assert_warmed_up(self)
+        (
+            fc1_input_quantizer,
+            fc1_weight_quantizer,
+            fc2_input_quantizer,
+            fc2_weight_quantizer,
+            output_quantizer,
+            *_,
+        ) = self._get_quantizers(False)
+        inp_dtype = inp.dtype
+        fc1_weight, fc2_weight = self._get_weight_tensors()
+        fc1_bias = self.fc1_bias if self.use_bias else None
+        fc2_bias = self.fc2_bias if self.use_bias else None
+
+        # layernorm + fp8 cast
+        ln_out, ln_out_return = onnx_layernorm(
+            inp,
+            self.layer_norm_weight,
+            self.layer_norm_bias,
+            self.eps,
+            self.normalization,
+            self.zero_centered_gamma,
+            inp_dtype,
+            self.return_layernorm_output,
+            fc1_input_quantizer,
+        )
+
+        if fc1_weight_quantizer is not None:
+            fc1_weight_q = fc1_weight_quantizer.onnx_quantize(fc1_weight)
+            fc1_weight = fc1_weight_quantizer.onnx_dequantize(fc1_weight_q)
+        fc1_weight = fc1_weight.to(inp_dtype)
+
+        fc1_out = onnx_gemm(fc1_weight, ln_out, fc1_bias)
+        fc1_out = fc1_out.to(torch.float32)  # activation is computed in fp32
+
+        activation_map = {
+            "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
+            "relu": torch.nn.functional.relu,
+            "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
+            "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
+            "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
+            "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
+            * x.chunk(2, -1)[1],
+            "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
+            "srelu": torch.nn.functional.softplus,
+        }
+        if self.activation not in activation_map:
+            raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
+        act_out = activation_map[self.activation](fc1_out)
+
+        if fc2_weight_quantizer is not None:
+            fc2_weight_q = fc2_weight_quantizer.onnx_quantize(fc2_weight)
+            fc2_weight = fc2_weight_quantizer.onnx_dequantize(fc2_weight_q)
+        fc2_weight = fc2_weight.to(inp_dtype)
+
+        if fc2_input_quantizer is not None:
+            act_out_q = fc2_input_quantizer.onnx_quantize(act_out)
+            act_out = fc2_input_quantizer.onnx_dequantize(act_out_q)
+        act_out = act_out.to(inp_dtype)
+
+        fc2_out = onnx_gemm(fc2_weight, act_out, fc2_bias)
+
+        if output_quantizer is not None:
+            raise NotImplementedError("ONNX export of quantized output is not supported")
+
+        if self.return_layernorm_output:
+            if self.return_bias:
+                return fc2_out, fc2_bias.to(inp_dtype), ln_out_return
+            return fc2_out, ln_out_return
+        if self.return_bias:
+            return fc2_out, fc2_bias.to(inp_dtype)
+        return fc2_out
+
     def _get_debug_quantizers(self, fp8_output):
         from ...debug.pytorch.debug_quantization import DebugQuantizer
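The *glu entries in activation_map above all share one convention: the FC1 output packs two halves along the last dimension, and the ONNX-friendly path splits them with chunk(2, -1) before applying the gating nonlinearity. A minimal standalone sketch of that convention (tensor sizes and the gate/value names are illustrative, not taken from this commit):

    import torch
    # FC1 output for a gated activation carries 2 * ffn_hidden_size features.
    fc1_out = torch.randn(4, 8)
    gate, value = fc1_out.chunk(2, -1)                     # same split as x.chunk(2, -1)[0] / [1]
    swiglu_out = torch.nn.functional.silu(gate) * value    # mirrors the "swiglu" entry above
    print(swiglu_out.shape)                                # torch.Size([4, 4])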
transformer_engine/pytorch/module/linear.py

@@ -68,6 +68,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantize
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
+from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.utils import any_feature_enabled

@@ -117,6 +118,7 @@ class _Linear(torch.autograd.Function):
         module: torch.nn.Module,
         skip_fp8_weight_update: bool,
         symmetric_ar_type: str,
+        save_original_input: bool = False,
         debug: Optional[bool] = False,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring

@@ -157,6 +159,11 @@ class _Linear(torch.autograd.Function):
         own_quantized_input = False
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
+            if save_original_input:
+                assert not isinstance(
+                    input_quantizer, Float8Quantizer
+                ), "DelayedScaling recipe is not supported with save_original_input"

             if with_input_all_gather_nccl or ub_overlap_ag_fprop:
                 # All-gather input tensor
                 # Cast local input tensor if needed

@@ -164,7 +171,9 @@ class _Linear(torch.autograd.Function):
                 if input_quantizer is None:
                     raise ValueError("Missing quantizer for input tensor")
                 if not isinstance(inputmat, QuantizedTensorBase):
-                    input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
+                    input_quantizer.set_usage(
+                        rowwise=True, columnwise=backward_needs_input and not save_original_input
+                    )
                     if isinstance(input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):

@@ -201,7 +210,9 @@ class _Linear(torch.autograd.Function):
             else:
                 if input_quantizer is None:
                     raise ValueError("Missing quantizer for input tensor")
-                input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
+                input_quantizer.set_usage(
+                    rowwise=True, columnwise=backward_needs_input and not save_original_input
+                )
                 inputmat = input_quantizer(inputmat)
                 own_quantized_input = True
         else:

@@ -330,6 +341,9 @@ class _Linear(torch.autograd.Function):
         # ------------------------------------------------------

         if is_grad_enabled:
+            if save_original_input:
+                inputmat = inp
+                ctx.weight_quantizer = weight_quantizer
             saved_inputmat = None

@@ -338,6 +352,7 @@ class _Linear(torch.autograd.Function):
             )
             if backward_needs_input:
+                if not save_original_input:
                     if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
                         # For sequence parallel in vanilla FP8, rowwise data is
                         # to gather the input. For MXFP8, columnwise only data

@@ -398,7 +413,14 @@ class _Linear(torch.autograd.Function):
         ctx.grad_output_quantizer = grad_output_quantizer
         ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
         if fuse_wgrad_accumulation and weight.requires_grad:
-            ctx.main_grad = weight.main_grad
+            # This check is needed to ensure that main_grad is not created
+            # during the forward pass when using MCore FSDP as it creates
+            # the main_grad buffer lazily before backprop
+            if hasattr(weight, "__fsdp_param__"):
+                # MCore FSDP creates main_grad lazily before backward
+                ctx.main_grad_func = weight.get_main_grad
+            else:
+                ctx.main_grad_func = lambda: weight.main_grad
         ctx.debug = debug
         ctx.cpu_offloading = cpu_offloading

@@ -454,7 +476,7 @@ class _Linear(torch.autograd.Function):
         # Since main_grad can be modified inplace, it should not be a part of saved_tensors
         main_grad = (
-            ctx.main_grad
+            ctx.main_grad_func()
             if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
             else None
         )

@@ -550,6 +572,24 @@ class _Linear(torch.autograd.Function):
         # --------------------------------------------------

         inputmat_total = None
         inputmat_total_work = None
         if ctx.requires_wgrad:
+            input_is_quantized = isinstance(inputmat, QuantizedTensorBase)
+            if ctx.fp8 or ctx.debug:
+                if not input_is_quantized:
+                    quantizer = ctx.input_quantizer
+                    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                        quantizer.set_usage(
+                            rowwise=True,
+                            columnwise=not ctx.backward_input_needs_gather,
+                        )
+                    else:
+                        quantizer.set_usage(rowwise=False, columnwise=True)
+                    inputmat = quantizer(inputmat)
+            else:
+                if input_is_quantized:
+                    inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
+                else:
+                    inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
             if ctx.backward_input_needs_gather:
                 quantizer = None
                 if ctx.fp8 or ctx.debug:

@@ -894,6 +934,7 @@ class _Linear(torch.autograd.Function):
             None,  # module
             None,  # skip_fp8_weight_update
             None,  # symmetric_ar_type
+            None,  # save_original_input
             None,  # debug
         )

@@ -976,6 +1017,11 @@ class Linear(TransformerEngineBaseModule):
                        This can help in latency bound communication situations.
                        Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce
                        is used.
+    save_original_input : bool, default = `False`
+                         If set to `True`, always saves the original input tensor rather than the
+                         cast tensor. In some scenarios, the input tensor is used by multiple modules,
+                         and saving the original input tensor may reduce the memory usage.
+                         Cannot work with FP8 DelayedScaling recipe.
     """

     def __init__(

@@ -1003,6 +1049,7 @@ class Linear(TransformerEngineBaseModule):
         ub_name: Optional[str] = None,
         delay_wgrad_compute: bool = False,
         symmetric_ar_type: Optional[str] = None,
+        save_original_input: bool = False,
         name: Optional[str] = None,
     ) -> None:
         super().__init__()

@@ -1017,6 +1064,7 @@ class Linear(TransformerEngineBaseModule):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
         self.symmetric_ar_type = symmetric_ar_type
+        self.save_original_input = save_original_input
         self.name = name
         if TEDebugState.debug_enabled:

@@ -1275,6 +1323,9 @@ class Linear(TransformerEngineBaseModule):
             first microbatch (since it is the first gradient being
             produced)
         """
+        if is_in_onnx_export_mode():
+            return self.onnx_forward(inp, fp8_output)
         debug = TEDebugState.debug_enabled
         if debug:
             self._validate_name()

@@ -1298,13 +1349,7 @@ class Linear(TransformerEngineBaseModule):
             allow_non_contiguous=isinstance(inp, QuantizedTensor),
         ) as inp:
             # Get concatenated weight and bias tensors
-            unfused_weights = self._get_weight_tensors()
-            weight_tensor = noop_cat(unfused_weights)
-            if self.use_bias:
-                bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
-            else:
-                bias_tensor = None
+            weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

             quantizers = (
                 self._get_quantizers(fp8_output, fp8_grad)

@@ -1370,6 +1415,7 @@ class Linear(TransformerEngineBaseModule):
             self,
             skip_fp8_weight_update,
             self.symmetric_ar_type,
+            self.save_original_input,
             debug,
         )
         out = linear_fn(*args)

@@ -1417,6 +1463,95 @@ class Linear(TransformerEngineBaseModule):
             for name, q in zip(names, original_quantizers)
         )

+    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
+        """Get the weight tensors of the module."""
+        unfused_weights = [getattr(self, name) for name in self.weight_names]
+        if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
+            if self.fp8:
+                if len(unfused_weights) != 1:
+                    raise RuntimeError(
+                        "Splitting QuantizedTensor into multiple params is not supported"
+                    )
+            else:
+                warnings.warn(
+                    "You are using quantized weights without quantized compute. "
+                    "Please make sure this is intentional."
+                )
+                unfused_weights = [w.dequantize() for w in unfused_weights]
+        return unfused_weights
+
+    def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Get concatenated weight and bias tensors
+        unfused_weights = self._get_weight_tensors()
+        if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
+            if self.fp8:
+                if len(unfused_weights) != 1:
+                    raise RuntimeError(
+                        "Splitting QuantizedTensor into multiple params is not supported"
+                    )
+            else:
+                warnings.warn(
+                    "You are using quantized weights without quantized compute. "
+                    "Please make sure this is intentional."
+                )
+                unfused_weights = [w.dequantize() for w in unfused_weights]
+        weight_tensor = noop_cat(unfused_weights)
+        if self.use_bias:
+            bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
+        else:
+            bias_tensor = None
+        return weight_tensor, bias_tensor
+
+    def onnx_forward(
+        self,
+        inp: torch.Tensor,
+        fp8_output: bool,
+    ) -> torch.Tensor:
+        """
+        ONNX-compatible version of the forward function that provides numerical equivalence
+        while only using operations that have defined ONNX symbolic translations.
+        This simplified implementation is designed specifically for inference scenarios.
+        """
+        from ..export import onnx_gemm
+
+        assert_warmed_up(self)
+        assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export."
+        weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
+        (
+            input_quantizer,
+            weight_quantizer,
+            output_quantizer,
+            *_,
+        ) = self._get_quantizers(fp8_output, False)
+        inp_dtype = inp.dtype
+
+        if input_quantizer is not None:
+            inp_q = input_quantizer.onnx_quantize(inp)
+            inp = input_quantizer.onnx_dequantize(inp_q)
+        inp = inp.to(inp_dtype)
+
+        if weight_quantizer is not None:
+            weight_q = weight_quantizer.onnx_quantize(weight_tensor)
+            weight_tensor = weight_quantizer.onnx_dequantize(weight_q)
+        if bias_tensor is not None:
+            bias_tensor = bias_tensor.to(inp_dtype)
+        weight_tensor = weight_tensor.to(inp_dtype)
+
+        if self.apply_bias:
+            output = onnx_gemm(weight_tensor, inp, bias_tensor)
+        else:
+            output = onnx_gemm(weight_tensor, inp, None)
+
+        if output_quantizer is not None:
+            raise NotImplementedError("ONNX export of quantized output is not supported")
+
+        if self.return_bias:
+            return output, bias_tensor
+        return output
+
     def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
         """Customize quantizers based on current scaling recipe + linear."""
         assert (

@@ -1464,23 +1599,6 @@ class Linear(TransformerEngineBaseModule):
             tex.FP8BwdTensors.GRAD_OUTPUT1
         ].amax_reduction_group = self.tp_group

-    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
-        """Get the weight tensors of the module."""
-        unfused_weights = [getattr(self, name) for name in self.weight_names]
-        if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
-            if self.fp8:
-                if len(unfused_weights) != 1:
-                    raise RuntimeError(
-                        "Splitting QuantizedTensor into multiple params is not supported"
-                    )
-            else:
-                warnings.warn(
-                    "You are using quantized weights without quantized compute. "
-                    "Please make sure this is intentional."
-                )
-                unfused_weights = [w.dequantize() for w in unfused_weights]
-        return unfused_weights
-
     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
         if not self.fp8:
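The main_grad change above swaps a tensor captured at forward time for a callable resolved at backward time, so that MCore FSDP can keep allocating its gradient buffer lazily. A minimal standalone sketch of why the closure matters (plain Python, no TE dependencies; the class name is illustrative):

    # Minimal sketch of the lazy main_grad accessor pattern used above.
    class MCoreFSDPParam:
        def __init__(self):
            self.__fsdp_param__ = True
            self._main_grad = None
        def get_main_grad(self):
            if self._main_grad is None:      # allocated lazily, not during forward
                self._main_grad = [0.0]
            return self._main_grad

    weight = MCoreFSDPParam()

    # Forward pass: store an accessor, not the (possibly missing) buffer itself.
    if hasattr(weight, "__fsdp_param__"):
        main_grad_func = weight.get_main_grad      # resolved later
    else:
        main_grad_func = lambda: weight.main_grad

    # Backward pass: only now is the buffer materialized and fetched.
    main_grad = main_grad_func()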
transformer_engine/pytorch/onnx_extensions.py (new file, 0 → 100644)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
File containing torch.ops extensions and their corresponding ONNX symbolic functions.

Many Transformer Engine layers rely on custom calls from the transformer_engine_torch
module, making ONNX export challenging because:
1. They often accept Python objects (quantizers), which ONNX does not support.
2. They are complex, incorporating fusions and precomputing certain values for backward
   passes -- mechanisms unnecessary for ONNX export.

For these reasons, we introduce onnx_forward methods in each layer that are simpler and
primarily leverage torch operators with known ONNX symbolic functions.
These methods avoid fusions and backward pass precomputations.
The main consideration is quantization, which PyTorch does not natively support, so we
need to implement ONNX symbolic functions on our own.
Since ONNX does not yet support quantization, operators from TensorRT are employed.
The primary goal of ONNX export is to enable inference compatibility with TensorRT.
"""

from typing import Tuple
import math
import torch
import onnxscript
from onnxscript import opset18 as op
from onnx import defs

import transformer_engine_torch as tex
from .tensor.float8_tensor import Float8Quantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .constants import MXFP8_BLOCK_SCALING_SIZE
from .utils import round_up_to_nearest_multiple
from .export import is_in_onnx_export_mode

# opset from TensorRT which supports FP8 quantization
trt_opset = onnxscript.values.Opset("trt", version=1)


# ONNX GEMM for inference
def onnx_gemm(weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """ONNX GEMM used for inference."""
    reshaped_inp = inp.reshape(-1, inp.shape[-1])
    out = torch_onnx_gemm_inf_op(weight, reshaped_inp, bias)
    return out.reshape(inp.shape[:-1] + (-1,))


@torch.library.custom_op("tex::gemm_inf", mutates_args=[])
def torch_onnx_gemm_inf_op(
    weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """Gemm used for inference -- weight is transposed"""
    out = inp @ weight.T
    if bias is not None:
        out = out + bias
    return out


@torch_onnx_gemm_inf_op.register_fake
def _(weight, inp, bias):
    """Fake gemm used for inference."""
    out = inp @ weight.T
    if bias is not None:
        out = out + bias
    return out


def onnx_gemm_inf_symbolic(
    weight: onnxscript.onnx_types.TensorType,
    inp: onnxscript.onnx_types.TensorType,
    bias: onnxscript.onnx_types.TensorType,
) -> onnxscript.onnx_types.TensorType:
    """Symbolic gemm used for inference."""
    return op.Gemm(inp, weight, bias, transA=0, transB=1)


# ONNX FP8 Quantization
@torch.library.custom_op("tex::fp8_quantize", mutates_args=[])
def onnx_quantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
    """Quantize to Float8Tensor used for inference."""
    scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
    amax_tensor = torch.tensor([1], dtype=torch.float32, device=tensor.device)
    quantizer = Float8Quantizer(scale_tensor, amax_tensor, tex.DType.kFloat8E4M3)
    return quantizer.quantize(tensor)._data


@onnx_quantize_fp8_op.register_fake
def _(tensor, *_):
    """Fake quantize to Float8Tensor used for inference."""
    return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device)


def onnx_quantize_fp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType,
    scale: float,
) -> onnxscript.onnx_types.UINT8:
    """Symbolic quantize used for inference."""
    scale_inv = op.Constant(value_float=1 / scale)
    return TRT_FP8QuantizeLinear(tensor, scale_inv)


# Define the schema for the custom operator
schema = defs.OpSchema(
    name="TRT_FP8QuantizeLinear",
    domain="trt",
    since_version=1,
    doc="TRT FP8 Quantize Linear used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"),
    ],
    outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
)

TRT_FP8QuantizeLinear = onnxscript.values.Op(
    opset=trt_opset, name="TRT_FP8QuantizeLinear", op_schema=schema
)


# ONNX FP8 Dequantization
@torch.library.custom_op("tex::fp8_dequantize", mutates_args=[])
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
    """Dequantize from Float8Tensor used for inference."""
    scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
    quantizer = Float8Quantizer(
        scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
    )
    quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
    return quantizer_tensor.dequantize()


@onnx_dequantize_fp8_op.register_fake
def _(tensor: torch.Tensor, _) -> torch.Tensor:
    """Fake dequantize from Float8Tensor used for inference."""
    return torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device)


def onnx_dequantize_fp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType, scale: float
) -> onnxscript.onnx_types.TensorType:
    """Symbolic dequantize from Float8Tensor used for inference."""
    scale_inv = op.Constant(value_float=1 / scale)
    return TRT_FP8DequantizeLinear(tensor, scale_inv)


schema = defs.OpSchema(
    name="TRT_FP8DequantizeLinear",
    domain="trt",
    since_version=1,
    doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"),
    ],
    outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)

TRT_FP8DequantizeLinear = onnxscript.values.Op(
    opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
)


# ONNX MXFP8 Quantization
@torch.library.custom_op("tex::mxfp8_quantize", mutates_args=[])
def onnx_quantize_mxfp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize to MXFP8Tensor used for inference."""
    quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3)
    quantized_tensor = quantizer(tensor)
    return quantized_tensor._rowwise_data, quantized_tensor._rowwise_scale_inv


@onnx_quantize_mxfp8_op.register_fake
def _(tensor: torch.Tensor):
    """Fake quantize to MXFP8Tensor used for inference."""
    mxfp8_scale_shape = [
        round_up_to_nearest_multiple(math.prod(tensor.shape[:-1]), 128),
        round_up_to_nearest_multiple(tensor.shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
    ]
    return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.empty(
        mxfp8_scale_shape, dtype=torch.uint8, device=tensor.device
    )


def onnx_quantize_mxfp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType,
) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]:
    """Symbolic quantize to MXFP8Tensor used for inference."""
    tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor)
    return tensor_out, scale_inv_out


schema = defs.OpSchema(
    name="TRT_MXFP8QuantizeLinear",
    domain="trt",
    since_version=1,
    doc="TRT MXFP8 Quantize Linear used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
    ],
    outputs=[
        defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor"),
        defs.OpSchema.FormalParameter("scale_inv", "tensor(uint8)", "Scale factor for quantization"),
    ],
)

TRT_MXFP8QuantizeLinear = onnxscript.values.Op(
    opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema
)


# ONNX MXFP8 Dequantization
@torch.library.custom_op("tex::mxfp8_dequantize", mutates_args=[])
def onnx_dequantize_mxfp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    """Dequantize from MXFP8Tensor used for inference."""
    quantizer = MXFP8Quantizer(tex.DType.kFloat8E4M3)
    quantizer_tensor = quantizer.create_tensor_from_data(
        tensor, scale_inv, fake_dtype=torch.float32
    )
    return quantizer_tensor.dequantize()


@onnx_dequantize_mxfp8_op.register_fake
def _(tensor: torch.Tensor, _):
    """Fake dequantize from MXFP8Tensor used for inference."""
    return torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device)


def onnx_dequantize_mxfp8_symbolic(
    tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
) -> onnxscript.onnx_types.TensorType:
    """Symbolic dequantize from MXFP8Tensor used for inference."""
    return TRT_MXFP8DequantizeLinear(tensor, scale_inv)


schema = defs.OpSchema(
    name="TRT_MXFP8DequantizeLinear",
    domain="trt",
    since_version=1,
    doc="TRT MXFP8 Dequantize Linear from MXFP8Tensor used for inference.",
    inputs=[
        defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
        defs.OpSchema.FormalParameter("scale_inv", "tensor(uint8)", "Scale factor for dequantization"),
    ],
    outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)

TRT_MXFP8DequantizeLinear = onnxscript.values.Op(
    opset=trt_opset, name="TRT_MXFP8DequantizeLinear", op_schema=schema
)


# ONNX LayerNorm
@torch.library.custom_op("tex::layernorm", mutates_args=[])
def onnx_layernorm_op(
    inp: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
) -> torch.Tensor:
    """ONNX LayerNorm used for inference."""
    model = tex.LayerNorm(inp.shape[1], eps=eps)
    model.weight.data = weight
    model.bias.data = bias
    return model(inp)


@onnx_layernorm_op.register_fake
def _(inp, *_):
    """Fake ONNX LayerNorm used for inference."""
    return inp


def onnx_layernorm_symbolic(
    inp: onnxscript.onnx_types.TensorType,
    weight: onnxscript.onnx_types.TensorType,
    bias: onnxscript.onnx_types.TensorType,
    eps: float,
) -> onnxscript.onnx_types.TensorType:
    """Symbolic ONNX LayerNorm used for inference."""
    return op.LayerNormalization(inp, weight, bias, epsilon=eps)


# onnx layernorm helper function - handles layernorm with quantization
def onnx_layernorm(
    inp: torch.Tensor,
    layer_norm_weight: torch.Tensor,
    layer_norm_bias: torch.Tensor,
    eps: float,
    normalization: str,
    zero_centered_gamma: bool,
    output_dtype: torch.dtype,
    return_layernorm_output: bool,
    input_quantizer,
) -> torch.Tensor:
    """ONNX LayerNorm used for inference."""
    ln_weight = layer_norm_weight if not zero_centered_gamma else layer_norm_weight + 1
    ln_weight = ln_weight.to(inp.dtype).to(torch.float32)
    inp = inp.to(torch.float32)
    layer_norm_bias = (
        layer_norm_bias.to(output_dtype).to(torch.float32) if layer_norm_bias is not None else None
    )
    if normalization == "RMSNorm":
        ln_out = torch.nn.functional.rms_norm(inp, inp.shape[-1:], ln_weight, eps)
    else:
        ln_out = torch.nn.functional.layer_norm(
            inp, inp.shape[-1:], ln_weight, layer_norm_bias, eps
        )
    ln_out_return = ln_out
    if input_quantizer is not None:
        if return_layernorm_output:
            # In case of return_layernorm_output, layernorm is not fused with fp8 cast,
            # so we cast to input_dtype and then perform cast to fp8 if needed
            ln_out = ln_out.to(output_dtype).to(torch.float32)
            ln_out_return = ln_out
        elif isinstance(input_quantizer, MXFP8Quantizer):
            # layernorm + mxfp8 quantizer behaves differently
            ln_out = ln_out.to(output_dtype).to(torch.float32)
        ln_out_quantized = input_quantizer.onnx_quantize(ln_out)
        ln_out = input_quantizer.onnx_dequantize(ln_out_quantized)
    ln_out = ln_out.to(output_dtype)
    return ln_out, ln_out_return


# utility functions
def onnx_attention_mask_func(
    attention_scores: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Get attention mask without inp"""
    assert is_in_onnx_export_mode()
    return attention_scores.masked_fill(attention_mask, -10000.0)


# This translation table should be passed to the torch.onnx.export function
# using the custom_translation_table=te_translation_table option.
te_translation_table = {
    torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
    torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
    torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
    torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
    torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
    torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
}
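As the closing comment states, te_translation_table is meant to be handed to torch.onnx.export through custom_translation_table so the tex::* custom ops above are lowered to the TRT symbolic functions. A hedged usage sketch follows; the te.Linear layer, the te.onnx_export context manager, the dynamo flag, and the warm-up call are illustrative assumptions about the surrounding API rather than content of this commit:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.onnx_extensions import te_translation_table

    model = te.Linear(1024, 1024).eval().cuda()
    x = torch.randn(16, 1024, device="cuda")
    model(x)  # run once first; onnx_forward asserts the module has been warmed up

    with te.onnx_export(True):  # assumed context manager that flips is_in_onnx_export_mode()
        torch.onnx.export(
            model,
            (x,),
            "te_linear.onnx",
            dynamo=True,  # the dynamo-based exporter accepts custom_translation_table
            custom_translation_table=te_translation_table,
        )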
transformer_engine/pytorch/ops/_common.py

@@ -5,7 +5,7 @@
 """Helper functions used in fusible operations."""

 from __future__ import annotations
-from typing import Any, Iterable, Optional
+from typing import Optional

 import torch

@@ -13,84 +13,24 @@ from transformer_engine_torch import FP8TensorMeta
 from .. import torch_version
 from ..fp8 import FP8GlobalStateManager
-from ..tensor.float8_tensor import Float8Tensor
-from ..utils import (
-    canonicalize_device,
-    canonicalize_dtype,
-    devices_match,
-)
-
-
-def is_float8_tensor(tensor: Any) -> bool:
-    """Check if object is a `Float8Tensor`"""
-    return isinstance(tensor, Float8Tensor)
-
-
-def convert_tensor(
-    tensor: torch.Tensor | Float8Tensor,
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-    memory_format: torch.memory_format = torch.preserve_format,
-) -> torch.Tensor | Float8Tensor:
-    """Convert tensor attributes, keeping same data if possible"""
-
-    # Default kwargs
-    if device is None:
-        device = tensor.device
-    device = canonicalize_device(device)
-    if dtype is None:
-        dtype = tensor.dtype
-    dtype = canonicalize_dtype(dtype)
-
-    # Make sure output is detached from autograd graph
-    tensor = tensor.detach()
-
-    # Return immediately if tensor already has desired attributes
-    if devices_match(device, tensor.device) and dtype == tensor.dtype:
-        if memory_format == torch.preserve_format or tensor.is_contiguous(
-            memory_format=memory_format
-        ):
-            return tensor
-
-    # Convert FP8 tensor
-    if is_float8_tensor(tensor):
-        data = tensor._data
-        if not devices_match(device, data.device):
-            data = data.to(device=device)
-        if memory_format != torch.preserve_format and not data.is_contiguous(
-            memory_format=memory_format
-        ):
-            # Note: torch.Tensor.to ignores memory_format kwarg (see
-            # https://github.com/pytorch/pytorch/issues/132020).
-            data = data.contiguous(memory_format=memory_format)
-        out = Float8Tensor.make_like(tensor, dtype=dtype)
-        out.data = data
-        return out
-
-    # Convert standard PyTorch tensor
-    tensor = tensor.to(device=device, dtype=dtype)
-    if memory_format != torch.preserve_format and not tensor.is_contiguous(
-        memory_format=memory_format
-    ):
-        # Note: torch.Tensor.to ignores memory_format kwarg (see
-        # https://github.com/pytorch/pytorch/issues/132020).
-        tensor = tensor.contiguous(memory_format=memory_format)
-    return tensor
-
-
-def reshape(
-    tensor: torch.Tensor | Float8Tensor,
-    shape: Iterable[int],
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-) -> torch.Tensor | Float8Tensor:
-    """Reshape tensor, keeping same data if possible"""
-    tensor = convert_tensor(
-        tensor,
-        device=device,
-        dtype=dtype,
-        memory_format=torch.contiguous_format,
-    )
-    return tensor.reshape(*shape)
+from ..tensor.quantized_tensor import QuantizedTensorBase
+from ..utils import canonicalize_dtype
+
+
+def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool:
+    """Check if tensor is a quantized tensor"""
+    return isinstance(tensor, QuantizedTensorBase)
+
+
+def maybe_dequantize(
+    tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None
+) -> torch.Tensor:
+    """Dequantize tensor to given dtype or just convert if not a quantized tensor"""
+    if is_quantized_tensor(tensor):
+        return tensor.dequantize(dtype=dtype)
+    if dtype is not None and tensor.dtype != dtype:
+        return tensor.to(dtype)
+    return tensor


 def maybe_autocast_dtype(
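maybe_dequantize collapses the repeated isinstance/dequantize/dtype-cast checks that the ops files below delete. A small equivalence sketch of its non-quantized branch (standalone Python; quantized inputs instead go through the tensor's own dequantize(dtype=dtype)):

    import torch

    def maybe_dequantize_plain(tensor: torch.Tensor, dtype: torch.dtype | None = None) -> torch.Tensor:
        # Mirrors the new helper for ordinary tensors only.
        if dtype is not None and tensor.dtype != dtype:
            return tensor.to(dtype)
        return tensor

    x = torch.randn(4, dtype=torch.float32)
    assert maybe_dequantize_plain(x) is x                            # no-op without a target dtype
    assert maybe_dequantize_plain(x, torch.bfloat16).dtype == torch.bfloat16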
transformer_engine/pytorch/ops/basic/activation.py

@@ -12,11 +12,10 @@ import torch
 import transformer_engine_torch as tex

 from ...fp8 import FP8GlobalStateManager
-from ...tensor import QuantizedTensor
-from ...tensor.float8_tensor import Float8CurrentScalingQuantizer
-from ...utils import clear_tensor_data, devices_match
+from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
+from ...utils import clear_tensor_data
 from ..op import BasicOperation, OperationContext
-from .._common import reshape
+from .._common import maybe_dequantize


 class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):

@@ -72,8 +71,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:

         # Compute dtype

@@ -86,35 +85,16 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
             raise RuntimeError(f"Unsupported dtype ({dtype})")

         # Check input tensor
-        x = input_
-        if isinstance(x, QuantizedTensor):
-            x = x.dequantize()
-        if x.device.type != "cuda":
-            x = x.cuda()
-        if x.dtype != dtype:
-            x = x.to(dtype=dtype)
-        if not x.is_contiguous():
-            x = x.contiguous()
+        x = maybe_dequantize(input_.contiguous(), dtype)

         # Check if quantized compute is enabled
-        quantized_compute_enabled = FP8GlobalStateManager.is_fp8_enabled()
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         quantizer = None
-        if (
-            quantized_compute_enabled
-            and next_op is not None
-            and next_op.num_quantizers("forward") > 0
-        ):
-            quantizer = next_op.get_quantizer("forward", 0)
+        if with_quantized_compute:
+            quantizer = next_op_input_quantizer

         # Launch kernel
-        y = self._activation_forward_impl(
-            reshape(x, (-1, x.size(-1))),
-            quantizer,
-        )
-
-        # Check output tensor
-        if y.dim() != x.dim():
-            y = y.reshape(list(x.shape[:-1]) + [-1])
+        y = self._activation_forward_impl(x, quantizer)

         # Quantize input to FP8 before caching if needed
         if self.cache_quantized_input:

@@ -123,10 +103,10 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
             x = input_quantizer(x)

         # Save state for backward pass
-        ctx.save_for_backward(x.detach())
-        ctx.quantized_compute_enabled = quantized_compute_enabled
+        ctx.save_for_backward(x)
+        ctx.with_quantized_compute = with_quantized_compute
         ctx.dtype = dtype
-        ctx.prev_op = prev_op
+        ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer

         return y

@@ -140,44 +120,20 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         (x,) = ctx.saved_tensors

         # Check input tensor
-        if isinstance(x, QuantizedTensor):
-            x = x.dequantize(dtype=ctx.dtype)
-        elif x.dtype != ctx.dtype:
-            x = x.to(dtype=ctx.dtype)
-        if not x.is_contiguous():
-            x = x.contiguous()
+        x = maybe_dequantize(x.contiguous(), ctx.dtype)

         # Check grad output tensor
-        dy = grad_output
-        if isinstance(dy, QuantizedTensor):
-            dy = dy.dequantize(dtype=ctx.dtype)
-        if not devices_match(dy.device, x.device) or dy.dtype != x.dtype:
-            dy = dy.to(device=x.device, dtype=x.dtype)
-        if not dy.is_contiguous():
-            dy = dy.contiguous()
+        dy = maybe_dequantize(grad_output.contiguous(), x.dtype)

         # Check if quantized compute is enabled
         quantizer = None
-        if (
-            ctx.quantized_compute_enabled
-            and ctx.prev_op is not None
-            and ctx.prev_op.num_quantizers("backward") > 0
-        ):
-            quantizer = ctx.prev_op.get_quantizer("backward", 0)
+        if ctx.with_quantized_compute:
+            quantizer = ctx.prev_op_grad_input_quantizer

         # Launch kernel
-        dx = self._activation_backward_impl(
-            reshape(dy, (-1, dy.size(-1))),
-            reshape(x, (-1, x.size(-1))),
-            quantizer,
-        )
-
-        # Check grad input tensor
-        if dx.size() != x.size():
-            dx = dx.reshape(x.size())
+        dx = self._activation_backward_impl(dy, x, quantizer)

         # Clear input tensor if possible
         if ctx.prev_op is not None:
             clear_tensor_data(x)

         return dx, ()
transformer_engine/pytorch/ops/basic/add_in_place.py

@@ -15,6 +15,8 @@ from transformer_engine.pytorch.ops.op import (
     OperationContext,
 )
+from transformer_engine.pytorch.tensor import Quantizer


 class AddInPlace(BasicOperation):
     """Add in-place

@@ -57,8 +59,8 @@ class AddInPlace(BasicOperation):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        basic_op_prev_ops: list[Optional[BasicOperation]],
-        basic_op_next_ops: list[Optional[BasicOperation]],
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
         output = basic_op_extra_inputs[0][0].detach()

@@ -76,4 +78,4 @@ class AddInPlace(BasicOperation):
         Iterable[Iterable[Optional[torch.Tensor]]],
         Iterable[Iterable[Optional[torch.Tensor]]],
     ]:
-        return grad_output, [], [(grad_output,)]
+        return grad_output, [()], [(grad_output,)]
transformer_engine/pytorch/ops/basic/all_gather.py

@@ -10,8 +10,9 @@ from typing import Optional
 import torch

 from ...distributed import gather_along_first_dim
-from ...tensor import QuantizedTensor
+from .._common import maybe_dequantize
 from ..op import BasicOperation, OperationContext
+from ...tensor import Quantizer


 class AllGather(BasicOperation):

@@ -39,8 +40,8 @@ class AllGather(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
         out: torch.Tensor
         if self.process_group_size == 1:

@@ -71,10 +72,7 @@ class AllGather(BasicOperation):
         input_dims[0] //= self.process_group_size

         # Check output gradient tensor
-        dy = grad_output
-        if isinstance(dy, QuantizedTensor):
-            dy = dy.dequantize()
-        dy = dy.contiguous()
+        dy = maybe_dequantize(grad_output.contiguous())

         # Perform reduce-scatter
         dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device)
transformer_engine/pytorch/ops/basic/all_reduce.py

@@ -9,8 +9,9 @@ from typing import Optional
 import torch

-from ...tensor import QuantizedTensor
+from .._common import maybe_dequantize
 from ..op import BasicOperation, OperationContext
+from ...tensor import Quantizer


 class AllReduce(BasicOperation):

@@ -41,8 +42,8 @@ class AllReduce(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:

         # Trivial case

@@ -50,10 +51,7 @@ class AllReduce(BasicOperation):
             return input_

         # Perform all-reduce
-        x = input_
-        if isinstance(x, QuantizedTensor):
-            x = x.dequantize()
-        x = x.contiguous()
+        x = maybe_dequantize(input_.contiguous())
         torch.distributed.all_reduce(x, group=self.process_group)
         return x
transformer_engine/pytorch/ops/basic/basic_linear.py

@@ -19,20 +19,21 @@ from ...distributed import (
     gather_along_first_dim,
     reduce_scatter_along_first_dim,
 )
-from ...fp8 import FP8GlobalStateManager
+from ...fp8 import FP8GlobalStateManager, Recipe
 from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
-from ...tensor import Quantizer, QuantizedTensor
+from ...tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
 from ..op import BasicOperation, OperationContext
-from .._common import (
+from .._common import maybe_dequantize, is_quantized_tensor
+from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
     clear_tensor_data,
     devices_match,
 )
-from ...utils import clear_tensor_data


 def _wait_async(handle: Optional[Any]) -> None:

@@ -271,7 +272,7 @@ class BasicLinear(BasicOperation):
             device = canonicalize_device(None)

         # Allocate buffer if needed
-        if isinstance(weight, QuantizedTensor):
+        if is_quantized_tensor(weight):
             weight = torch.empty(
                 weight.size(),
                 dtype=weight.dtype,

@@ -302,8 +303,12 @@ class BasicLinear(BasicOperation):
             weight = torch.nn.Parameter(weight)
         self.weight = weight

-    def pre_forward(self, *args, **kwargs) -> None:
-        super().pre_forward(*args, **kwargs)
+    def pre_first_forward(
+        self,
+        *,
+        recipe: Optional[Recipe],
+    ) -> None:
+        super().pre_first_forward(recipe=recipe)

         # Initialize weights if needed
         weight = self.weight

@@ -312,20 +317,17 @@ class BasicLinear(BasicOperation):
             weight = self.weight

         # Configure quantizers
-        if FP8GlobalStateManager.is_fp8_enabled():
+        if recipe is not None:
             input_quantizer = self.get_quantizer("forward", 0)
             weight_quantizer = self.get_quantizer("forward", 1)
             grad_output_quantizer = self.get_quantizer("backward", 0)

             # Specify required tensor formats
             is_grad_enabled = torch.is_grad_enabled()
             weight_requires_grad = is_grad_enabled and weight.requires_grad
             input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
             weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
             grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
             input_quantizer.internal = True
             weight_quantizer.internal = True
             grad_output_quantizer.internal = True

             # Recipe-specific configuration
-            recipe = FP8GlobalStateManager.get_fp8_recipe()
             if recipe.float8_current_scaling():
                 if any(
                     not isinstance(q, Float8CurrentScalingQuantizer)

@@ -390,7 +392,7 @@ class BasicLinear(BasicOperation):
             Bias tensor
         device: torch.device, default = default CUDA device
             Tensor device
-        dtype: torch.dtype, default = default dtype
+        dtype: torch.dtype, default = infer from out or weight
             Tensor datatype
         out: torch.Tensor, optional
             Output tensor

@@ -437,8 +439,14 @@ class BasicLinear(BasicOperation):
         # Check datatype
         if dtype is None:
-            dtype = weight.dtype if out is None else out.dtype
+            if out is not None and isinstance(out, torch.Tensor):
+                dtype = out.dtype
+            elif weight is not None and isinstance(out, torch.Tensor):
+                dtype = weight.dtype
+            else:
+                raise ValueError(
+                    "Could not infer dtype from weight nor out and dtype was not provided"
+                )
         dtype = canonicalize_dtype(dtype)
         if dtype not in (torch.float32, torch.float16, torch.bfloat16):
             raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
         if out is not None and out.dtype != dtype:

@@ -462,14 +470,12 @@ class BasicLinear(BasicOperation):
                     quantizer=input_quantizer,
                 )
             else:
-                if not isinstance(x_local, QuantizedTensor):
+                if not is_quantized_tensor(x_local):
                     x_local = input_quantizer(x_local)
                 x = x_local
         else:
-            if isinstance(x_local, QuantizedTensor):
-                x_local = x_local.dequantize()
-            if x_local.dtype != dtype:
-                x_local = x_local.to(dtype=dtype)
+            x_local = maybe_dequantize(x_local, dtype)
             if with_x_all_gather:
                 x, x_async = gather_along_first_dim(
                     x_local,

@@ -481,16 +487,13 @@ class BasicLinear(BasicOperation):
         # Check weight tensor
         w = weight
-        w_is_quantized = isinstance(w, QuantizedTensor)
-        if with_quantized_compute and not w_is_quantized:
+        if not with_quantized_compute:
+            w = maybe_dequantize(w, dtype)
+        elif with_quantized_compute and not is_quantized_tensor(w):
             if weight_quantizer is None:
                 raise ValueError("Missing quantizer for weight tensor")
             weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
             w = weight_quantizer(w)
-        elif not with_quantized_compute and w_is_quantized:
-            w = w.dequantize()
-        if not with_quantized_compute and w.dtype != dtype:
-            w = w.to(dtype=dtype)

         # Check output tensor
         y = out

@@ -499,7 +502,7 @@ class BasicLinear(BasicOperation):
         output_quantizer = None
         if tensor_parallel_mode == "row":
             output_quantizer = None
-        elif isinstance(y, QuantizedTensor):
+        elif is_quantized_tensor(y):
             if not with_quantized_compute:
                 raise ValueError("Output tensor is quantized, but quantized compute is not enabled")
             if tensor_parallel_mode == "row":

@@ -564,18 +567,14 @@ class BasicLinear(BasicOperation):
         # Prepare weight tensor for backward pass
         if input_requires_grad:
-            if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensor):
+            if w is not weight and with_quantized_compute and is_quantized_tensor(w):
                 w.update_usage(rowwise_usage=False, columnwise_usage=True)
         else:
             w = None

         # Prepare input tensor for backward pass
         if weight_requires_grad:
             if x_local is input:
                 # PyTorch autograd produces esoteric errors if we
                 # cache input tensor directly.
                 x_local = x_local.detach()
-            if with_quantized_compute and isinstance(x_local, QuantizedTensor):
+            if with_quantized_compute and is_quantized_tensor(x_local):
                 if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
                     # FP8 does not support all-gather of transpose data
                     x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

@@ -668,9 +667,9 @@ class BasicLinear(BasicOperation):
         # Check datatype
         if dtype is None:
-            if weight is not None:
+            if isinstance(weight, torch.Tensor):
                 dtype = weight.dtype
-            else:
+            elif isinstance(grad_output, torch.Tensor):
                 dtype = grad_output.dtype
         dtype = canonicalize_dtype(dtype)
         if dtype not in (torch.float32, torch.float16, torch.bfloat16):

@@ -696,14 +695,17 @@ class BasicLinear(BasicOperation):
                     quantizer=grad_output_quantizer,
                 )
             else:
-                if not isinstance(dy_local, QuantizedTensor):
+                if not is_quantized_tensor(dy_local):
                     dy_local = grad_output_quantizer(dy_local)
+                else:
+                    dy_local.update_usage(
+                        rowwise_usage=input_requires_grad,
+                        columnwise_usage=weight_requires_grad,
+                    )
                 dy = dy_local
         else:
-            if isinstance(dy_local, QuantizedTensor):
-                dy_local = dy_local.dequantize()
-            if dy_local.dtype != dtype:
-                dy_local = dy_local.to(dtype=dtype)
+            dy_local = maybe_dequantize(dy_local, dtype)
             if with_dy_all_gather:
                 dy, dy_async = gather_along_first_dim(
                     dy_local,

@@ -733,16 +735,14 @@ class BasicLinear(BasicOperation):
                     quantizer=input_quantizer,
                 )
             else:
-                if isinstance(x_local, QuantizedTensor):
+                if is_quantized_tensor(x_local):
                     x_local.update_usage(columnwise_usage=True)
                 else:
                     x_local = input_quantizer(x_local)
                 x = x_local
         else:
-            if isinstance(x_local, QuantizedTensor):
-                x_local = x_local.dequantize()
-            if x_local.dtype != dtype:
-                x_local = x_local.to(dtype=dtype)
+            x_local = maybe_dequantize(x_local, dtype)
             if with_x_all_gather:
                 x, x_async = gather_along_first_dim(
                     x_local,

@@ -761,9 +761,8 @@ class BasicLinear(BasicOperation):
         if weight is None:
             raise ValueError("Weight tensor is required to compute input grad")
         w = weight
-        w_is_quantized = isinstance(w, QuantizedTensor)
         if with_quantized_compute:
-            if w_is_quantized:
+            if is_quantized_tensor(w):
                 w.update_usage(columnwise_usage=True)
             else:
                 if weight_quantizer is None:

@@ -771,10 +770,7 @@ class BasicLinear(BasicOperation):
                 weight_quantizer.set_usage(columnwise=True)
                 w = weight_quantizer(w)
         else:
-            if w_is_quantized:
-                w = w.dequantize(dtype=dtype)
-            elif w.dtype != dtype:
-                w = w.to(dtype=dtype)
+            w = maybe_dequantize(w, dtype)

         # Synchronize tensor-parallel communication
         _wait_async(dy_async)

@@ -787,7 +783,7 @@ class BasicLinear(BasicOperation):
         grad_input_quantizer = None
         if tensor_parallel_mode == "column":
             grad_input_quantizer = None
-        elif isinstance(dx, QuantizedTensor):
+        elif is_quantized_tensor(dx):
             if not with_quantized_compute:
                 raise ValueError(
                     "Grad input tensor is quantized, but quantized compute is not enabled"

@@ -898,12 +894,12 @@ class BasicLinear(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:

         # Check which grads are required
-        input_requires_grad = ctx.requires_grad and input_.requires_grad
+        input_requires_grad = ctx.requires_grad
         weight_requires_grad = ctx.requires_grad and self.weight.requires_grad

         # FP8 metadata

@@ -918,11 +914,9 @@ class BasicLinear(BasicOperation):
         # Get quantizers
         input_quantizer = self.get_quantizer("forward", 0)
         weight_quantizer = self.get_quantizer("forward", 1)
-        if next_op is not None and next_op.num_quantizers("forward") > 0:
-            output_quantizer = next_op.get_quantizer("forward", 0)
+        output_quantizer = next_op_input_quantizer
         grad_output_quantizer = self.get_quantizer("backward", 0)
-        if prev_op is not None and prev_op.num_quantizers("backward") > 0:
-            grad_input_quantizer = prev_op.get_quantizer("backward", 0)
+        grad_input_quantizer = prev_op_grad_input_quantizer

         # Configure quantizers
         # Note: We cache the quantized input for backward pass,

@@ -931,9 +925,10 @@ class BasicLinear(BasicOperation):
             weight_quantizer.set_usage(rowwise=True, columnwise=False)

         # Get autocast dtype if needed
         dtype = None
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_dtype("cuda")
+        else:
+            dtype = self.weight.dtype

         # Linear forward
         output, x_local, w = BasicLinear._functional_forward(

@@ -961,7 +956,6 @@ class BasicLinear(BasicOperation):
         ctx.dtype = dtype
         ctx.input_requires_grad = input_requires_grad
         ctx.weight_requires_grad = weight_requires_grad
-        ctx.has_prev_op = prev_op is not None

         return output

@@ -978,6 +972,9 @@ class BasicLinear(BasicOperation):
         accumulate_into_main_grad = self._accumulate_into_main_grad
         grad_weight = None
         if ctx.weight_requires_grad and accumulate_into_main_grad:
+            if hasattr(self.weight, "__fsdp_param__"):
+                self.weight.main_grad = self.weight.get_main_grad()
+
             if not hasattr(self.weight, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "

@@ -1009,7 +1006,6 @@ class BasicLinear(BasicOperation):
             )

         # Clear input tensor if possible
-        if ctx.has_prev_op:
-            clear_tensor_data(x_local)
+        clear_tensor_data(x_local)

         if accumulate_into_main_grad:
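The set_usage calls in pre_first_forward encode which GEMMs each quantized tensor must serve: rowwise data feeds the forward-direction GEMM, while columnwise (transposed) data is only produced when the corresponding backward GEMM will actually run. A compact restatement of that mapping under my reading of the calls above (function and key names are illustrative, not TE API):

    # Restatement of the quantizer-usage configuration in pre_first_forward.
    def configure_usage(is_grad_enabled: bool, weight_requires_grad: bool) -> dict:
        return {
            # input: rowwise for forward; columnwise only if wgrad (dW = dY^T X) is needed
            "input":       dict(rowwise=True, columnwise=weight_requires_grad),
            # weight: rowwise for forward; columnwise only if dgrad (dX = dY W) is needed
            "weight":      dict(rowwise=True, columnwise=is_grad_enabled),
            # grad_output: rowwise for dgrad; columnwise only if wgrad is needed
            "grad_output": dict(rowwise=True, columnwise=weight_requires_grad),
        }

    print(configure_usage(is_grad_enabled=True, weight_requires_grad=False))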
transformer_engine/pytorch/ops/basic/bias.py
View file @
44740c6c
...
...
@@ -9,14 +9,17 @@ from typing import Optional

 import torch

+import transformer_engine_torch as tex
 from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     OperationContext,
 )
-from .._common import (
+from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
 )
+from ...fp8 import FP8GlobalStateManager
+from ...tensor import Quantizer


 class Bias(BasicOperation):
...
...
@@ -111,8 +114,8 @@ class Bias(BasicOperation):
             bias = torch.nn.Parameter(bias)
         self.bias = bias

-    def pre_forward(self, *args, **kwargs) -> None:
-        super().pre_forward(*args, **kwargs)
+    def pre_first_forward(self, *args, **kwargs) -> None:
+        super().pre_first_forward(*args, **kwargs)
         if self.bias.device.type == "meta":
             self.reset_parameters()
...
...
@@ -120,11 +123,25 @@ class Bias(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
         x = input_
-        b = self.bias.reshape([1] * (x.dim() - 1) + [self.local_size])
+        b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])
+
+        # Check if backward pass is needed
+        requires_grad = ctx.requires_grad
+
+        # Check if previous op quantizes its output's gradient
+        grad_input_quantizer = None
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        if with_quantized_compute:
+            grad_input_quantizer = prev_op_grad_input_quantizer
+
+        if requires_grad:
+            ctx.with_quantized_compute = with_quantized_compute
+            ctx.grad_input_quantizer = grad_input_quantizer
+
         return x + b

     def op_backward(
...
...
@@ -134,6 +151,10 @@ class Bias(BasicOperation):
     ) -> tuple[torch.Tensor, tuple[()]]:
         dy = grad_output
         if dy.dim() > 1:
-            db = dy.sum(tuple(range(dy.dim() - 1)))
+            quantizer = ctx.grad_input_quantizer
+            if ctx.with_quantized_compute and quantizer is not None:
+                db, dy = tex.bgrad_quantize(dy, quantizer)
+            else:
+                db = dy.sum(tuple(range(dy.dim() - 1)))
         else:
             db = dy
...
...
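Reviewer note: tex.bgrad_quantize fuses the bias-gradient reduction with quantization of the incoming gradient, so Bias.op_backward now hands the previous op a ready-quantized dy while computing db in the same pass. For reference, an unfused sketch of what the single call replaces (quantizers are callable on tensors, as used elsewhere in this commit):

    # Sketch only: unfused equivalent of `db, dy = tex.bgrad_quantize(dy, quantizer)`.
    import torch


    def bgrad_quantize_reference(dy: torch.Tensor, quantizer):
        # Bias gradient: reduce over all leading (non-feature) dimensions
        db = dy.sum(tuple(range(dy.dim() - 1)))
        # Quantize the gradient passed to the previous op (e.g. to FP8)
        dy_q = quantizer(dy)
        return db, dy_q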
transformer_engine/pytorch/ops/basic/identity.py
...
...
@@ -13,6 +13,7 @@ from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     OperationContext,
 )
+from ...tensor import Quantizer


 class Identity(BasicOperation):
...
...
@@ -22,8 +23,8 @@ class Identity(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
         return input_
...
...
transformer_engine/pytorch/ops/basic/l2normalization.py
...
...
@@ -9,8 +9,8 @@ from typing import Optional

 import torch

-from ...tensor import QuantizedTensor
 from ...utils import clear_tensor_data
+from .._common import maybe_dequantize
 from ..op import BasicOperation, OperationContext
 from ...jit import (
     l2normalization_fused,
...
...
@@ -19,6 +19,7 @@ from ...jit import (
     set_jit_fusion_options,
     warmup_jit_l2normalization_all_dtypes,
 )
+from ...tensor import Quantizer


 class L2Normalization(BasicOperation):
...
...
@@ -73,14 +74,11 @@ class L2Normalization(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:

         # Use input directly - torch.compile can handle multi-dimensional tensors
-        x = input_
-        if isinstance(x, QuantizedTensor):
-            x = x.dequantize()
+        x = maybe_dequantize(input_)

         # Check if backward pass is needed
         requires_grad = ctx.requires_grad
...
...
@@ -98,7 +96,6 @@ class L2Normalization(BasicOperation):
         # Save state for backward pass
         if requires_grad:
             ctx.save_for_backward(x, rsqrt_norm)
-            ctx.has_prev_op = prev_op is not None

         return y
...
...
@@ -111,16 +108,12 @@ class L2Normalization(BasicOperation):
         # Saved tensors from forward pass
         x, rsqrt_norm = ctx.saved_tensors

-        dy = grad_output
-        if isinstance(dy, QuantizedTensor):
-            dy = dy.dequantize()
+        dy = maybe_dequantize(grad_output)

         # Compute L2 norm backward pass using fused implementation
         dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)

         # Clear saved tensors if possible
-        if ctx.has_prev_op:
-            clear_tensor_data(x)
+        clear_tensor_data(x)
         clear_tensor_data(rsqrt_norm)
...
...
transformer_engine/pytorch/ops/basic/layer_norm.py
...
...
@@ -15,7 +15,6 @@ import torch
 from transformer_engine_torch import layernorm_bwd, layernorm_fwd
 from ...fp8 import FP8GlobalStateManager
 from ...constants import TE_DType
-from ...tensor import QuantizedTensor
 from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
...
...
@@ -23,7 +22,9 @@ from ...utils import (
     devices_match,
 )
 from ..op import BasicOperation, OperationContext
-from .._common import maybe_autocast_dtype, reshape
+from .._common import maybe_autocast_dtype, maybe_dequantize
+from ...export import is_in_onnx_export_mode
+from ...tensor import Quantizer


 class LayerNorm(BasicOperation):
...
...
@@ -167,8 +168,8 @@ class LayerNorm(BasicOperation):
         self.weight = weight
         self.bias = bias

-    def pre_forward(self, *args, **kwargs) -> None:
-        super().pre_forward(*args, **kwargs)
+    def pre_first_forward(self, *args, **kwargs) -> None:
+        super().pre_first_forward(*args, **kwargs)
         if self.weight.device.type == "meta" or self.bias.device.type == "meta":
             self.reset_parameters()
...
...
@@ -176,9 +177,11 @@ class LayerNorm(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
+        if is_in_onnx_export_mode():
+            return self.op_onnx_forward(input_)

         # Check tensor dims
         weight = self.weight
...
...
@@ -192,31 +195,19 @@ class LayerNorm(BasicOperation):
         # Check input tensors
         inner_dim = math.prod(weight_dims)
-        device = weight.device
-        if device.type != "cuda":
-            device = canonicalize_device(None)
         dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
-        x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
-        w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
-        b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype)
-        if isinstance(x, QuantizedTensor):
-            x = x.dequantize()
-        if isinstance(w, QuantizedTensor):
-            w = w.dequantize()
-        if isinstance(b, QuantizedTensor):
-            b = b.dequantize()
+        x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim))
+        w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
+        b = maybe_dequantize(self.bias, dtype).view((inner_dim,))

         # Check if backward pass is needed
         requires_grad = ctx.requires_grad

         # Check if output is quantized
         output_quantizer = None
-        if (
-            FP8GlobalStateManager.is_fp8_enabled()
-            and next_op is not None
-            and next_op.num_quantizers("forward") > 0
-        ):
-            output_quantizer = next_op.get_quantizer("forward", 0)
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        if with_quantized_compute:
+            output_quantizer = next_op_input_quantizer

         # Compute layer norm
         sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
...
...
@@ -235,12 +226,10 @@ class LayerNorm(BasicOperation):
         # Save state for backward pass
         if requires_grad:
             ctx.save_for_backward(x, means, rstdevs)
-            ctx.device = device
             ctx.dtype = dtype
-            ctx.has_prev_op = prev_op is not None

         # Reshape output tensor
-        out = reshape(y, input_dims)
+        out = y.view(input_dims)
         return out

     def op_backward(
...
...
@@ -257,14 +246,9 @@ class LayerNorm(BasicOperation):
         inner_dim = math.prod(weight_dims)

         # Check input tensors
-        device = ctx.device
         dtype = ctx.dtype
-        dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
-        w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
-        if isinstance(w, QuantizedTensor):
-            w = w.dequantize()
-        if isinstance(dy, QuantizedTensor):
-            dy = dy.dequantize()
+        dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
+        w = maybe_dequantize(self.weight, dtype).view((inner_dim,))

         # Compute layer norm backward pass
         dx, dw, db = layernorm_bwd(
...
...
@@ -278,13 +262,22 @@ class LayerNorm(BasicOperation):
         )

         # Clear saved tensors if possible
-        if ctx.has_prev_op:
-            clear_tensor_data(x)
+        clear_tensor_data(x)
         clear_tensor_data(means)
         clear_tensor_data(rstdevs)

         # Reshape results
-        grad_input = reshape(dx, grad_output.size())
-        grad_weight = reshape(dw, weight_dims)
-        grad_bias = reshape(db, weight_dims)
+        grad_input = dx.view(grad_output.size())
+        grad_weight = dw.view(weight_dims)
+        grad_bias = db.view(weight_dims)
         return grad_input, (grad_weight, grad_bias)
+
+    def op_onnx_forward(
+        self,
+        input_: torch.Tensor,
+    ) -> torch.Tensor:
+        """Every operand in this function has a defined ONNX translation."""
+        weight = self.weight + 1 if self.zero_centered_gamma else self.weight
+        return torch.nn.functional.layer_norm(
+            input_, input_.shape[-1:], weight, self.bias, self.eps
+        )
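Reviewer note: op_onnx_forward gives LayerNorm (and RMSNorm further down) an export path built only from operators with defined ONNX translations; the is_in_onnx_export_mode() check at the top of op_forward routes to it. With zero_centered_gamma the stored weight is gamma - 1, so the export path shifts it back before calling the standard functional op. A small, self-contained illustration of that convention (the values below are arbitrary, not from the commit):

    # Sketch only: what op_onnx_forward computes when zero_centered_gamma=True.
    import torch

    hidden = 1024
    x = torch.randn(4, hidden)
    gamma = torch.zeros(hidden)   # zero-centered parameterization stores gamma - 1
    beta = torch.zeros(hidden)
    eps = 1e-5

    # Equivalent to the exported LayerNorm: the weight is shifted by +1 before use
    y = torch.nn.functional.layer_norm(x, x.shape[-1:], gamma + 1, beta, eps)

The RMSNorm export path below is the same idea with torch.nn.functional.rms_norm.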
transformer_engine/pytorch/ops/basic/make_extra_output.py
...
...
@@ -14,6 +14,7 @@ from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     OperationContext,
 )
+from ...tensor import Quantizer


 class MakeExtraOutput(BasicOperation):
...
...
@@ -58,8 +59,8 @@ class MakeExtraOutput(BasicOperation):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        basic_op_prev_ops: list[Optional[BasicOperation]],
-        basic_op_next_ops: list[Optional[BasicOperation]],
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
         return input_, [(input_,)]
...
...
@@ -77,4 +78,4 @@ class MakeExtraOutput(BasicOperation):
     ]:
         grad_input = basic_op_grad_extra_outputs[0][0]
         grad_input += grad_output
-        return grad_input, [], [()]
+        return grad_input, [()], [()]
transformer_engine/pytorch/ops/basic/quantize.py
...
...
@@ -10,8 +10,9 @@ from typing import Optional

 import torch

 from ...fp8 import FP8GlobalStateManager
-from ...tensor import QuantizedTensor
+from .._common import is_quantized_tensor
 from ..op import BasicOperation, OperationContext
+from ...tensor import Quantizer


 class Quantize(BasicOperation):
...
...
@@ -49,8 +50,8 @@ class Quantize(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:

         # Check if FP8 is enabled
...
...
@@ -60,7 +61,7 @@ class Quantize(BasicOperation):

         # Quantize if needed
         out = input_
-        if quantize_forward and not isinstance(out, QuantizedTensor):
+        if quantize_forward and not is_quantized_tensor(out):
             out = self.get_quantizer("forward", 0)(out)
         ctx.quantize_backward = quantize_backward
...
...
@@ -72,6 +73,6 @@ class Quantize(BasicOperation):
         grad_output: torch.Tensor,
     ) -> tuple[torch.Tensor, tuple[()]]:
         grad_input = grad_output
-        if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor):
+        if ctx.quantize_backward and not is_quantized_tensor(grad_input):
             grad_input = self.get_quantizer("backward", 0)(grad_input)
         return grad_input, ()
transformer_engine/pytorch/ops/basic/reduce_scatter.py
...
...
@@ -10,8 +10,9 @@ from typing import Optional

 import torch

 from ...distributed import gather_along_first_dim
-from ...tensor import QuantizedTensor
+from .._common import maybe_dequantize
 from ..op import BasicOperation, OperationContext
+from ...tensor import Quantizer


 class ReduceScatter(BasicOperation):
...
...
@@ -39,8 +40,8 @@ class ReduceScatter(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:

         # Trivial case
...
...
@@ -59,10 +60,7 @@ class ReduceScatter(BasicOperation):
         output_dims[0] //= self.process_group_size

         # Check input tensor
-        x = input_
-        if isinstance(x, QuantizedTensor):
-            x = x.dequantize()
-        x = x.contiguous()
+        x = maybe_dequantize(input_.contiguous())

         # Perform reduce-scatter
         y = torch.empty(output_dims, dtype=x.dtype, device=x.device)
...
...
transformer_engine/pytorch/ops/basic/reshape.py
...
...
@@ -14,6 +14,7 @@ from transformer_engine.pytorch.ops.op import (
     BasicOperation,
     OperationContext,
 )
+from ...tensor import Quantizer


 class Reshape(BasicOperation):
...
...
@@ -37,8 +38,8 @@ class Reshape(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
         ctx.input_shape = input_.size()
         return input_.reshape(*self._shape)
...
...
transformer_engine/pytorch/ops/basic/rmsnorm.py
...
...
@@ -14,7 +14,6 @@ import torch
 from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
 from ...fp8 import FP8GlobalStateManager
-from ...tensor import QuantizedTensor
 from ...constants import TE_DType
 from ...utils import (
     canonicalize_device,
...
...
@@ -23,7 +22,9 @@ from ...utils import (
     devices_match,
 )
 from ..op import BasicOperation, OperationContext
-from .._common import maybe_autocast_dtype, reshape
+from .._common import maybe_autocast_dtype, maybe_dequantize
+from ...export import is_in_onnx_export_mode
+from ...tensor import Quantizer


 class RMSNorm(BasicOperation):
...
...
@@ -150,8 +151,8 @@ class RMSNorm(BasicOperation):
             weight = torch.nn.Parameter(weight)
         self.weight = weight

-    def pre_forward(self, *args, **kwargs) -> None:
-        super().pre_forward(*args, **kwargs)
+    def pre_first_forward(self, *args, **kwargs) -> None:
+        super().pre_first_forward(*args, **kwargs)
         if self.weight.device.type == "meta":
             self.reset_parameters()
...
...
@@ -159,9 +160,11 @@ class RMSNorm(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op: Optional[BasicOperation] = None,
-        next_op: Optional[BasicOperation] = None,
+        prev_op_grad_input_quantizer: Optional[Quantizer],
+        next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
+        if is_in_onnx_export_mode():
+            return self.op_onnx_forward(input_)

         # Check tensor dims
         weight = self.weight
...
...
@@ -175,28 +178,18 @@ class RMSNorm(BasicOperation):
         # Check input tensors
         inner_dim = math.prod(weight_dims)
-        device = weight.device
-        if device.type != "cuda":
-            device = canonicalize_device(None)
         dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
-        x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
-        w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
-        if isinstance(x, QuantizedTensor):
-            x = x.dequantize()
-        if isinstance(w, QuantizedTensor):
-            w = w.dequantize()
+        x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim))
+        w = maybe_dequantize(self.weight, dtype).view((inner_dim,))

         # Check if backward pass is needed
         requires_grad = ctx.requires_grad

         # Check if output is quantized
         output_quantizer = None
-        if (
-            FP8GlobalStateManager.is_fp8_enabled()
-            and next_op is not None
-            and next_op.num_quantizers("forward") > 0
-        ):
-            output_quantizer = next_op.get_quantizer("forward", 0)
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        if with_quantized_compute:
+            output_quantizer = next_op_input_quantizer

         # Compute RMSNorm
         sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
...
...
@@ -214,12 +207,10 @@ class RMSNorm(BasicOperation):
         # Save state for backward pass
         if requires_grad:
             ctx.save_for_backward(x, rstdevs)
-            ctx.device = device
             ctx.dtype = dtype
-            ctx.has_prev_op = prev_op is not None

         # Reshape output tensor
-        out = reshape(y, input_dims)
+        out = y.view(input_dims)
         return out

     def op_backward(
...
...
@@ -236,14 +227,9 @@ class RMSNorm(BasicOperation):
         inner_dim = math.prod(weight_dims)

         # Check input tensors
-        device = ctx.device
         dtype = ctx.dtype
-        dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
-        w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
-        if isinstance(w, QuantizedTensor):
-            w = w.dequantize()
-        if isinstance(dy, QuantizedTensor):
-            dy = dy.dequantize()
+        dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
+        w = maybe_dequantize(self.weight, dtype).view((inner_dim,))

         # Compute RMSNorm backward pass
         dx, dw = rmsnorm_bwd(
...
...
@@ -256,11 +242,18 @@ class RMSNorm(BasicOperation):
         )

         # Clear saved tensors if possible
-        if ctx.has_prev_op:
-            clear_tensor_data(x)
+        clear_tensor_data(x)
         clear_tensor_data(rstdevs)

         # Reshape results
-        grad_input = reshape(dx, grad_output.size())
-        grad_weight = reshape(dw, weight_dims)
+        grad_input = dx.view(grad_output.size())
+        grad_weight = dw.view(weight_dims)
         return grad_input, (grad_weight,)
+
+    def op_onnx_forward(
+        self,
+        input_: torch.Tensor,
+    ) -> torch.Tensor:
+        """Every operand in this function has a defined ONNX translation."""
+        weight = self.weight + 1 if self.zero_centered_gamma else self.weight
+        return torch.nn.functional.rms_norm(input_, input_.shape[-1:], weight, self.eps)
transformer_engine/pytorch/ops/fused/__init__.py
...
...
@@ -4,6 +4,10 @@

 """Compound tensor operation supported by the operation fuser."""

+from .backward_bias_activation import (
+    BackwardBiasActivation,
+    fuse_backward_bias_activation,
+)
 from .backward_linear_add import (
     BackwardLinearAdd,
     fuse_backward_linear_add,
...
...
transformer_engine/pytorch/ops/fused/backward_bias_activation.py
new file (mode 100644)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fused backward dbias + dact + quantize."""

from __future__ import annotations
from typing import Optional

import torch

import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import Recipe
from transformer_engine.pytorch.ops.basic import Bias
from transformer_engine.pytorch.ops.basic.activation import (
    _ActivationOperation,
    GELU,
    ReLU,
)
from transformer_engine.pytorch.ops.op import (
    FusedOperation,
    FusibleOperation,
    OperationContext,
)
from ...utils import clear_tensor_data
from .._common import maybe_dequantize


_fused_activations = {GELU: tex.dbias_dgelu, ReLU: tex.dbias_drelu}
_fusible_activations = tuple(_fused_activations.keys())


class BackwardBiasActivation(FusedOperation):
    """Fused backward dbias + dact + quantize

    Uses the next operation's input quantizer.

    """

    def __init__(self, *, bias: Bias, activation: _ActivationOperation):
        super().__init__((bias, activation))
        self._fused_function = _fused_activations[type(activation)]

    def fuser_backward(
        self,
        basic_op_ctxs: list[OperationContext],
        grad_output: torch.Tensor,
        *,
        basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
    ) -> tuple[
        torch.Tensor,
        list[tuple[Optional[torch.Tensor], ...]],
        list[tuple[()]],
    ]:

        # Get basic operation contexts
        activation_op_ctx = basic_op_ctxs[0]
        bias_op_ctx = basic_op_ctxs[1]

        # Saved tensors from forward pass
        (act_input,) = activation_op_ctx.saved_tensors

        # Check activation input tensor
        act_input = maybe_dequantize(act_input.contiguous(), activation_op_ctx.dtype)

        # Check grad output tensor
        dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype)

        # Get previous op quantizer
        if not bias_op_ctx.with_quantized_compute:
            raise RuntimeError(
                "BackwardBiasActivation requires quantized compute, "
                "but Bias context has it disabled"
            )
        quantizer = bias_op_ctx.grad_input_quantizer
        if quantizer is None:
            raise RuntimeError(
                "BackwardBiasActivation requires previous op's grad output quantizer, "
                "but Bias context has no quantizer"
            )

        # Launch kernel
        db, dx = self._fused_function(dy, act_input, quantizer)

        # Clear activation input tensor
        clear_tensor_data(act_input)

        return dx, [(), (db,)], [(), ()]


def fuse_backward_bias_activation(
    ops: list[tuple[FusibleOperation, list[int]]],
    recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]:
    """Fused backward dbias + dact + quantize

    Parameters
    ----------
    ops: list of tuples
        Backward pass operations and the indices of the corresponding
        basic operations.
    recipe: Recipe, optional
        Used quantization recipe

    Returns
    -------
    ops: list of tuples
        Updated backward pass operations

    """

    # Check if recipe supports bias activation fusion
    if recipe is None or not (recipe.delayed() or recipe.mxfp8()):
        return ops

    # Scan through ops, fusing if possible
    out = []
    window = []
    while len(ops) >= 3:
        out.extend(window)

        # Check if first op is a supported activation
        window, ops = ops[:1], ops[1:]
        op, _ = window[0]
        if not isinstance(op, _fusible_activations):
            continue

        # Check if second op is bias
        op, _ = ops[0]
        if not isinstance(op, Bias):
            continue

        # Check if third op has a grad input quantizer
        op, _ = ops[1]
        if not op.num_quantizers("backward") > 0:
            continue
        window.extend(ops[:1])
        ops = ops[1:]

        # Replace window with fused op
        op = BackwardBiasActivation(
            activation=window[0][0],
            bias=window[1][0],
        )
        basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window]
        window = [(op, basic_op_idxs)]

    # Return list of ops
    out.extend(window)
    out.extend(ops)
    return out
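Reviewer note: fuse_backward_bias_activation scans the backward-pass op list for an (activation, bias, op-with-a-backward-quantizer) pattern and replaces the first two entries with BackwardBiasActivation, so dbias, dact, and quantization of the resulting gradient run as a single kernel. A hedged sketch of exercising the pass directly; the BasicLinear/Bias constructor arguments are assumptions, and in normal use the operation fuser builds this list and calls the pass itself rather than user code:

    # Sketch only: hand-built backward op list (constructor args assumed, see note above).
    from transformer_engine.common.recipe import DelayedScaling
    from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
    from transformer_engine.pytorch.ops.basic.activation import GELU
    from transformer_engine.pytorch.ops.fused import fuse_backward_bias_activation

    # Backward order is the reverse of forward: activation grad, then bias grad, then linear grad.
    ops = [
        (GELU(), [2]),
        (Bias(4096), [1]),               # assumed constructor signature
        (BasicLinear(1024, 4096), [0]),  # assumed constructor signature
    ]

    fused_ops = fuse_backward_bias_activation(ops, recipe=DelayedScaling())
    # With a delayed-scaling or MXFP8 recipe (and a linear op that reports a backward
    # quantizer), the GELU and Bias entries are replaced by one BackwardBiasActivation;
    # with recipe=None the list is returned unchanged.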