Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
063ef88d
Commit
063ef88d
authored
Dec 03, 2025
by
wenjh
Browse files
Merge nv main up to v2.10.0.dev0
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parents
91670b05
5624dbb4
Changes
298
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
387 additions
and
157 deletions
+387
-157
transformer_engine/pytorch/module/grouped_linear.py
transformer_engine/pytorch/module/grouped_linear.py
+38
-41
transformer_engine/pytorch/module/layernorm_linear.py
transformer_engine/pytorch/module/layernorm_linear.py
+55
-18
transformer_engine/pytorch/module/layernorm_mlp.py
transformer_engine/pytorch/module/layernorm_mlp.py
+88
-25
transformer_engine/pytorch/module/linear.py
transformer_engine/pytorch/module/linear.py
+58
-19
transformer_engine/pytorch/ops/_common.py
transformer_engine/pytorch/ops/_common.py
+5
-5
transformer_engine/pytorch/ops/basic/__init__.py
transformer_engine/pytorch/ops/basic/__init__.py
+13
-1
transformer_engine/pytorch/ops/basic/activation.py
transformer_engine/pytorch/ops/basic/activation.py
+36
-0
transformer_engine/pytorch/ops/basic/basic_linear.py
transformer_engine/pytorch/ops/basic/basic_linear.py
+54
-29
transformer_engine/pytorch/ops/basic/dropout.py
transformer_engine/pytorch/ops/basic/dropout.py
+2
-2
transformer_engine/pytorch/ops/basic/quantize.py
transformer_engine/pytorch/ops/basic/quantize.py
+3
-3
transformer_engine/pytorch/ops/fused/backward_activation_bias.py
...rmer_engine/pytorch/ops/fused/backward_activation_bias.py
+1
-1
transformer_engine/pytorch/ops/fused/backward_linear_add.py
transformer_engine/pytorch/ops/fused/backward_linear_add.py
+1
-0
transformer_engine/pytorch/ops/fused/backward_linear_scale.py
...sformer_engine/pytorch/ops/fused/backward_linear_scale.py
+1
-0
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
...ngine/pytorch/ops/fused/forward_linear_bias_activation.py
+2
-2
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
...ormer_engine/pytorch/ops/fused/forward_linear_bias_add.py
+2
-2
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
...rmer_engine/pytorch/ops/fused/forward_linear_scale_add.py
+2
-2
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
...r_engine/pytorch/ops/fused/userbuffers_backward_linear.py
+1
-0
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
...er_engine/pytorch/ops/fused/userbuffers_forward_linear.py
+3
-3
transformer_engine/pytorch/ops/fuser.py
transformer_engine/pytorch/ops/fuser.py
+5
-1
transformer_engine/pytorch/ops/op.py
transformer_engine/pytorch/ops/op.py
+17
-3
No files found.
transformer_engine/pytorch/module/grouped_linear.py
View file @
063ef88d
...
...
@@ -22,7 +22,7 @@ from .base import (
_2X_ACC_WGRAD
,
)
from
._common
import
WeightGradStore
from
..
fp8
import
FP8GlobalStateManager
from
..
quantization
import
FP8GlobalStateManager
from
..utils
import
(
divide
,
cast_if_needed
,
...
...
@@ -46,7 +46,7 @@ from ..cpu_offload import is_cpu_offload_enabled
from
..tensor.float8_tensor
import
Float8CurrentScalingQuantizer
,
Float8Quantizer
from
..tensor.quantized_tensor
import
(
QuantizedTensor
Bas
e
,
QuantizedTensor
Storag
e
,
Quantizer
,
prepare_for_saving
,
restore_from_saved
,
...
...
@@ -204,39 +204,27 @@ class _GroupedLinear(torch.autograd.Function):
inputmats
[
0
]
=
inp
else
:
for
inputmat
in
inputmats
:
if
isinstance
(
inputmat
,
QuantizedTensor
Bas
e
):
if
isinstance
(
inputmat
,
QuantizedTensor
Storag
e
):
inputmat
.
update_usage
(
rowwise_usage
=
False
,
columnwise_usage
=
True
)
else
:
inputmats
=
[
None
]
*
num_gemms
if
inp
.
requires_grad
:
for
weight
in
weights_fp8
:
if
isinstance
(
weight
,
QuantizedTensor
Bas
e
):
if
isinstance
(
weight
,
QuantizedTensor
Storag
e
):
weight
.
update_usage
(
columnwise_usage
=
True
)
for
i
in
range
(
num_gemms
):
weights
[
i
].
offloading_activation
=
False
weights_fp8
[
i
].
offloading_activation
=
False
biases
[
i
].
offloading_activation
=
False
ctx
.
fine_grained_activation_offloading
=
fine_grained_activation_offloading
if
fine_grained_activation_offloading
and
cpu_offloading
:
raise
ValueError
(
f
"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
)
if
cpu_offloading
:
ctx
.
grad_added_to_main_grad
=
hasattr
(
weights
[
0
],
"grad_added_to_main_grad"
)
if
(
fine_grained_activation_offloading
and
weights
[
0
].
requires_grad
and
fuse_wgrad_accumulation
):
grad_added_to_main_grad_list
=
[]
if
ctx
.
grad_added_to_main_grad
:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to have all the attributes user
# sets for the weights. Because of this, it is not recommended to offload
# weights if weights are externally touched outside this module
ctx
.
weight_objects
=
[]
for
weight
in
weights
:
if
weight
.
requires_grad
and
hasattr
(
weight
,
"grad_added_to_main_grad"
):
grad_added_to_main_grad_list
.
append
(
weight
.
grad_added_to_main_grad
)
weight
.
grad_added_to_main_grad
=
True
else
:
grad_added_to_main_grad_list
.
append
(
None
)
ctx
.
grad_added_to_main_grad_list
=
grad_added_to_main_grad_list
ctx
.
weight_objects
.
append
(
weight
)
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
*
inputmats
,
...
...
@@ -300,13 +288,15 @@ class _GroupedLinear(torch.autograd.Function):
biases
=
saved_tensors
[
3
*
N
:
4
*
N
]
main_grads
=
[
main_grad_func
()
for
main_grad_func
in
ctx
.
main_grad_funcs
]
if
(
ctx
.
cpu_offloading
or
ctx
.
fine_grained_activation_offloading
)
and
ctx
.
fuse_wgrad_accumulation
:
for
i
in
range
(
ctx
.
num_gemms
):
w
=
torch
.
nn
.
Parameter
(
weights
[
i
],
weights
[
i
].
requires_grad
)
w
.
main_grad
=
main_grads
[
i
]
weights
[
i
]
=
w
if
ctx
.
fine_grained_activation_offloading
and
weights
[
0
].
requires_grad
:
weights
[
i
].
grad_added_to_main_grad
=
ctx
.
grad_added_to_main_grad_list
[
i
]
if
ctx
.
cpu_offloading
:
if
ctx
.
grad_added_to_main_grad
:
for
i
,
weight
in
enumerate
(
ctx
.
weight_objects
):
origin_weights
[
i
]
=
ctx
.
weight_objects
[
i
]
ctx
.
weight_objects
[
i
]
=
None
if
ctx
.
fuse_wgrad_accumulation
:
for
i
in
range
(
N
):
origin_weights
[
i
].
main_grad
=
main_grads
[
i
]
# Preprocess grad output
grad_output_view
=
grad_output
.
contiguous
().
view
(
-
1
,
grad_output
.
shape
[
-
1
])
...
...
@@ -369,7 +359,7 @@ class _GroupedLinear(torch.autograd.Function):
)
for
weight
,
quantizer
in
zip
(
weights
,
ctx
.
weight_quantizers
):
if
quantizer
is
not
None
and
isinstance
(
weight
,
QuantizedTensor
Bas
e
):
if
quantizer
is
not
None
and
isinstance
(
weight
,
QuantizedTensor
Storag
e
):
weight
.
update_usage
(
rowwise_usage
=
quantizer
.
rowwise_usage
,
columnwise_usage
=
quantizer
.
columnwise_usage
,
...
...
@@ -433,7 +423,11 @@ class _GroupedLinear(torch.autograd.Function):
use_bias
=
ctx
.
use_bias
if
grad_biases
[
0
]
is
None
else
None
,
bias
=
biases
,
use_split_accumulator
=
wgrad_gemm_use_split_accumulator
,
accumulate
=
accumulate_wgrad_into_param_main_grad
,
accumulate
=
(
accumulate_wgrad_into_param_main_grad
if
not
getattr
(
weights
[
0
],
"overwrite_main_grad"
,
False
)
else
False
),
)
# WGRAD
if
ctx
.
wgrad_store
is
not
None
and
ctx
.
wgrad_store
.
delay_wgrad_compute
():
...
...
@@ -555,7 +549,9 @@ class GroupedLinear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
...
...
@@ -772,7 +768,7 @@ class GroupedLinear(TransformerEngineBaseModule):
produced)
"""
assert
not
isinstance
(
inp
,
QuantizedTensor
Bas
e
inp
,
QuantizedTensor
Storag
e
),
"GroupedLinear doesn't support input tensor in FP8."
assert
len
(
m_splits
)
==
self
.
num_gemms
,
"Number of splits should match number of GEMMs."
...
...
@@ -907,16 +903,17 @@ class GroupedLinear(TransformerEngineBaseModule):
self
.
_offsets
[
"input"
]
+
i
*
self
.
_num_fp8_tensors_per_gemm
[
"bwd"
]
].
amax_epsilon
=
recipe
.
fp8_quant_bwd_grad
.
amax_epsilon
def
_get_weight_tensors
(
self
)
->
List
[
Union
[
torch
.
Tensor
,
QuantizedTensor
Bas
e
]]:
def
_get_weight_tensors
(
self
)
->
List
[
Union
[
torch
.
Tensor
,
QuantizedTensor
Storag
e
]]:
"""Get the weight tensors of the module."""
weight_tensors
=
[
getattr
(
self
,
f
"weight
{
i
}
"
)
for
i
in
range
(
self
.
num_gemms
)]
if
not
self
.
fp8
and
any
(
isinstance
(
w
,
QuantizedTensor
Bas
e
)
for
w
in
weight_tensors
):
if
not
self
.
fp8
and
any
(
isinstance
(
w
,
QuantizedTensor
Storag
e
)
for
w
in
weight_tensors
):
warnings
.
warn
(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
weight_tensors
=
[
w
.
dequantize
()
if
isinstance
(
w
,
QuantizedTensorBase
)
else
w
for
w
in
weight_tensors
w
.
dequantize
()
if
isinstance
(
w
,
QuantizedTensorStorage
)
else
w
for
w
in
weight_tensors
]
return
weight_tensors
...
...
transformer_engine/pytorch/module/layernorm_linear.py
View file @
063ef88d
...
...
@@ -16,6 +16,7 @@ import transformer_engine_torch as tex
from
transformer_engine.common.recipe
import
Recipe
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch.tensor.utils
import
is_experimental
from
.base
import
(
fill_userbuffers_buffer_for_all_gather
,
get_workspace
,
...
...
@@ -26,9 +27,10 @@ from .base import (
_2X_ACC_DGRAD
,
_2X_ACC_WGRAD
,
)
from
..
fp8
import
FP8GlobalStateManager
from
..
quantization
import
FP8GlobalStateManager
from
..utils
import
(
assert_dim_for_fp8_exec
,
assert_dim_for_all_gather
,
cast_if_needed
,
clear_tensor_data
,
divide
,
...
...
@@ -57,7 +59,7 @@ from ..graph import is_graph_capturing
from
._common
import
apply_normalization
,
noop_cat
,
WeightGradStore
from
..tensor.quantized_tensor
import
(
QuantizedTensor
,
QuantizedTensor
Bas
e
,
QuantizedTensor
Storag
e
,
Quantizer
,
prepare_for_saving
,
restore_from_saved
,
...
...
@@ -65,8 +67,8 @@ from ..tensor.quantized_tensor import (
from
...debug.pytorch.debug_state
import
TEDebugState
from
..tensor.float8_blockwise_tensor
import
Float8BlockQuantizer
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
..tensor.
_internal.mxfp8
_tensor_
bas
e
import
MXFP8TensorBas
e
from
..tensor.
_internal.float8_blockwise
_tensor_
bas
e
import
Float8BlockwiseQTensorBas
e
from
..tensor.
storage.float8_blockwise
_tensor_
storag
e
import
Float8BlockwiseQTensorStorag
e
from
..tensor.
storage.mxfp8
_tensor_
storag
e
import
MXFP8TensorStorag
e
from
..export
import
is_in_onnx_export_mode
,
assert_warmed_up
from
..cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
...
...
@@ -144,6 +146,8 @@ class _LayerNormLinear(torch.autograd.Function):
if
ub_name
is
not
None
:
nvtx_label
=
f
"
{
nvtx_label
}
.
{
ub_name
}
"
with_input_all_gather
=
parallel_mode
==
"column"
and
sequence_parallel
# Make sure input dimensions are compatible
out_features
,
in_features
=
weight
.
shape
inp_shape
=
inp
.
shape
...
...
@@ -153,6 +157,7 @@ class _LayerNormLinear(torch.autograd.Function):
inputmat
=
inp
if
fp8
:
assert_dim_for_fp8_exec
(
inputmat
,
weight
)
assert_dim_for_all_gather
(
inputmat
,
with_input_all_gather
,
input_quantizer
)
# Cast for native AMP
nvtx_range_push
(
f
"
{
nvtx_label
}
.norm_input_cast"
)
...
...
@@ -166,7 +171,6 @@ class _LayerNormLinear(torch.autograd.Function):
weight_requires_grad
=
weight
.
requires_grad
backward_needs_input
=
is_grad_enabled
and
weight_requires_grad
with_input_all_gather
=
parallel_mode
==
"column"
and
sequence_parallel
# Configure Userbuffers communication (comm+GEMM overlap)
if
debug
:
# turn off userbuffers in debug mode
...
...
@@ -199,11 +203,13 @@ class _LayerNormLinear(torch.autograd.Function):
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
experimental
=
is_experimental
(
input_quantizer
)
with_quantized_norm
=
(
fp8
and
not
debug
and
not
return_layernorm_output
and
not
return_layernorm_output_gathered
and
not
experimental
# TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
)
# Apply normalization
...
...
@@ -249,7 +255,8 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer
=
None
if
fp8
or
debug
:
quantizer
=
input_quantizer
if
not
with_quantized_norm
:
# experimental recipe doesn't need to support quantized AG
if
not
with_quantized_norm
and
not
experimental
:
ln_out
=
quantizer
(
ln_out
)
quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
if
ub_overlap_ag_fprop
:
# Initialize Userbuffers all-gather
...
...
@@ -280,7 +287,7 @@ class _LayerNormLinear(torch.autograd.Function):
weightmat
=
weight
quantized_weight
=
False
if
fp8
or
debug
:
quantized_weight
=
not
isinstance
(
weight
,
QuantizedTensor
Bas
e
)
quantized_weight
=
not
isinstance
(
weight
,
QuantizedTensor
Storag
e
)
# Configure quantizer
if
weight_quantizer
is
not
None
:
...
...
@@ -405,18 +412,18 @@ class _LayerNormLinear(torch.autograd.Function):
# Input with column-wise usage is needed for wgrad GEMM.
if
backward_needs_input
:
if
isinstance
(
ln_out
,
QuantizedTensor
Bas
e
):
if
isinstance
(
ln_out
,
QuantizedTensor
Storag
e
):
# For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data
# can be allgathered.
if
(
isinstance
(
ln_out
,
(
MXFP8Tensor
Bas
e
,
Float8BlockwiseQTensor
Bas
e
))
isinstance
(
ln_out
,
(
MXFP8Tensor
Storag
e
,
Float8BlockwiseQTensor
Storag
e
))
or
not
ctx
.
ln_out_needs_gather
):
ln_out
.
update_usage
(
rowwise_usage
=
False
)
# Weight with column-wise usage is needed for dgrad GEMM.
if
isinstance
(
weightmat
,
QuantizedTensor
Bas
e
):
if
isinstance
(
weightmat
,
QuantizedTensor
Storag
e
):
weightmat
.
update_usage
(
columnwise_usage
=
True
)
if
cpu_offloading
:
...
...
@@ -716,9 +723,9 @@ class _LayerNormLinear(torch.autograd.Function):
# --------------------------------------------------
# Make sure required data is available
if
isinstance
(
grad_output
,
QuantizedTensor
Bas
e
):
if
isinstance
(
grad_output
,
QuantizedTensor
Storag
e
):
grad_output
.
update_usage
(
rowwise_usage
=
True
)
if
ctx
.
weight_quantizer
is
not
None
and
isinstance
(
weight
,
QuantizedTensor
Bas
e
):
if
ctx
.
weight_quantizer
is
not
None
and
isinstance
(
weight
,
QuantizedTensor
Storag
e
):
weight
.
update_usage
(
columnwise_usage
=
True
)
# Choose whether to use GEMM kernel with split accumulator
...
...
@@ -837,14 +844,14 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_work
.
wait
()
ln_out_total_work
=
None
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
ln_out_total
,
QuantizedTensor
Bas
e
):
if
isinstance
(
ln_out_total
,
QuantizedTensor
Storag
e
):
ln_out_total
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
ln_out_total
=
ctx
.
input_quantizer
(
ln_out_total
)
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
grad_output
,
QuantizedTensor
Bas
e
):
if
isinstance
(
grad_output
,
QuantizedTensor
Storag
e
):
grad_output
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
...
...
@@ -880,7 +887,11 @@ class _LayerNormLinear(torch.autograd.Function):
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
),
"quantization_params"
:
ctx
.
grad_weight_quantizer
,
"accumulate"
:
accumulate_wgrad_into_param_main_grad
,
"accumulate"
:
(
accumulate_wgrad_into_param_main_grad
if
not
getattr
(
weight
,
"overwrite_main_grad"
,
False
)
else
False
),
"layout"
:
"NT"
,
"out"
:
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
"bias"
:
(
bias
if
(
grad_bias
is
None
and
not
ctx
.
fp8
)
else
None
),
...
...
@@ -1037,7 +1048,7 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop
(
f
"
{
nvtx_label
}
.reduce_and_update_fp8_tensors"
)
# Scatter fp8 weight buffers
# if ctx.fp8 and not isinstance(weight, QuantizedTensor
Bas
e):
# if ctx.fp8 and not isinstance(weight, QuantizedTensor
Storag
e):
# _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return
(
...
...
@@ -1164,7 +1175,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
...
...
@@ -1470,6 +1483,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
self
.
_customize_quantizers_float8_current_scaling
(
fwd
,
recipe
)
elif
recipe
.
float8_block_scaling
():
self
.
_customize_quantizers_float8_blockwise_scaling
(
fwd
,
recipe
)
elif
recipe
.
nvfp4
():
self
.
_customize_quantizers_nvfp4
(
fwd
,
recipe
)
# elif other recipes (mxfp8, etc)
def
reset_layer_norm_parameters
(
self
)
->
None
:
...
...
@@ -1812,7 +1827,29 @@ class LayerNormLinear(TransformerEngineBaseModule):
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
amax_reduction_group
=
self
.
tp_group
def
_get_weight_tensors
(
self
)
->
List
[
Union
[
torch
.
Tensor
,
QuantizedTensorBase
]]:
def
_customize_quantizers_nvfp4
(
self
,
fwd
:
bool
,
recipe
:
Recipe
)
->
None
:
"""Customize quantizers based on current scaling recipe + layernorm_linear."""
assert
recipe
.
nvfp4
(),
"Incorrect recipe."
if
fwd
:
if
self
.
sequence_parallel
and
self
.
parallel_mode
==
"column"
:
# set input_quantizer with amax reduction TP group
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
with_amax_reduction
=
True
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
amax_reduction_group
=
self
.
tp_group
else
:
if
self
.
sequence_parallel
and
self
.
parallel_mode
==
"row"
:
# customize grad_output_quantizer with amax reduction TP group
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
with_amax_reduction
=
True
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
amax_reduction_group
=
self
.
tp_group
def
_get_weight_tensors
(
self
)
->
List
[
Union
[
torch
.
Tensor
,
QuantizedTensorStorage
]]:
"""Get the weight tensors of the module."""
unfused_weights
=
[
getattr
(
self
,
name
)
for
name
in
self
.
weight_names
]
if
any
(
isinstance
(
w
,
QuantizedTensor
)
for
w
in
unfused_weights
):
...
...
transformer_engine/pytorch/module/layernorm_mlp.py
View file @
063ef88d
...
...
@@ -18,6 +18,7 @@ import transformer_engine_torch as tex
from
transformer_engine.common.recipe
import
Recipe
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch.tensor.utils
import
is_experimental
from
.base
import
(
fill_userbuffers_buffer_for_all_gather
,
get_workspace
,
...
...
@@ -28,7 +29,7 @@ from .base import (
_2X_ACC_DGRAD
,
_2X_ACC_WGRAD
,
)
from
..
fp8
import
FP8GlobalStateManager
from
..
quantization
import
FP8GlobalStateManager
from
..jit
import
(
bias_gelu_fused
,
bgrad_dgelu_fused
,
...
...
@@ -41,6 +42,7 @@ from ..utils import (
init_method_constant
,
cast_if_needed
,
assert_dim_for_fp8_exec
,
assert_dim_for_all_gather
,
clear_tensor_data
,
requires_grad
,
needs_quantized_gemm
,
...
...
@@ -65,11 +67,12 @@ from ..tensor.float8_tensor import (
Float8Tensor
,
)
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
..tensor.nvfp4_tensor
import
NVFP4Quantizer
from
..tensor.float8_blockwise_tensor
import
Float8BlockQuantizer
from
._common
import
apply_normalization
,
WeightGradStore
from
..cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
..tensor.quantized_tensor
import
(
QuantizedTensor
Bas
e
,
QuantizedTensor
Storag
e
,
Quantizer
,
prepare_for_saving
,
restore_from_saved
,
...
...
@@ -120,8 +123,14 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"swiglu"
:
(
tex
.
swiglu
,
tex
.
dswiglu
,
None
),
}
# no activation fusion written yet
# Per-tensor current scaling or fp8 blockwise scaling: []
if
recipe
.
float8_current_scaling
()
or
recipe
.
float8_block_scaling
():
# Per-tensor current scaling or fp8 blockwise scaling or custom quantization: []
# TODO(ksivaman): Fuse nvfp4 act once kernel is available.
if
(
recipe
.
float8_current_scaling
()
or
recipe
.
float8_block_scaling
()
or
recipe
.
nvfp4
()
or
recipe
.
custom
()
):
return
{
"gelu"
:
(
tex
.
gelu
,
tex
.
dgelu
,
None
),
"geglu"
:
(
tex
.
geglu
,
tex
.
dgeglu
,
None
),
...
...
@@ -218,6 +227,7 @@ class _LayerNormMLP(torch.autograd.Function):
inputmat
=
inp
.
view
((
-
1
,
in_features
))
if
fp8
:
assert_dim_for_fp8_exec
(
inputmat
,
fc1_weight
,
fc2_weight
)
assert_dim_for_all_gather
(
inputmat
,
sequence_parallel
,
fc1_input_quantizer
)
activation_func
=
_act_func
(
activation
,
FP8GlobalStateManager
.
get_fp8_recipe
()
if
fp8
else
None
...
...
@@ -265,11 +275,13 @@ class _LayerNormMLP(torch.autograd.Function):
# high precision layernorm output and output of the linear are returned
# for debug: : layernorm output = High precision to enable processing of this norm
experimental
=
is_experimental
(
fc1_input_quantizer
)
with_quantized_norm
=
(
fp8
and
not
debug
and
not
return_layernorm_output
and
not
return_layernorm_output_gathered
and
not
experimental
)
# Apply normalization
...
...
@@ -309,7 +321,8 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer
=
None
if
fp8
or
debug
:
quantizer
=
fc1_input_quantizer
if
not
with_quantized_norm
:
# experimental recipe doesn't need to support quantized AG
if
not
with_quantized_norm
and
not
experimental
:
ln_out
=
fc1_input_quantizer
(
ln_out
)
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
if
ub_overlap_ag
:
...
...
@@ -447,10 +460,18 @@ class _LayerNormMLP(torch.autograd.Function):
act_out
=
fc2_input_quantizer
(
act_out
)
else
:
fc1_out
,
*
_
=
fc1_outputs
if
fp8
and
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_block_scaling
():
# tex.quantize does not support GELU fusion for blockwise.
if
fp8
:
recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
if
recipe
.
float8_block_scaling
():
# tex.quantize does not support GELU fusion for blockwise
act_out
=
activation_func
(
fc1_out
,
None
)
act_out
=
tex
.
quantize
(
act_out
,
fc2_input_quantizer
)
elif
recipe
.
custom
():
# tex.quantize does not support custom quantizers
act_out
=
activation_func
(
fc1_out
,
None
)
act_out
=
fc2_input_quantizer
(
act_out
)
else
:
act_out
=
activation_func
(
fc1_out
,
fc2_input_quantizer
)
else
:
if
fp8_calibration
:
act_out
=
activation_func
(
fc1_out
,
None
)
...
...
@@ -521,9 +542,9 @@ class _LayerNormMLP(torch.autograd.Function):
if
is_grad_enabled
:
# Weight with column-wise usage is needed for dgrad GEMM.
if
isinstance
(
fc1_weight_final
,
QuantizedTensor
Bas
e
):
if
isinstance
(
fc1_weight_final
,
QuantizedTensor
Storag
e
):
fc1_weight_final
.
update_usage
(
columnwise_usage
=
True
)
if
isinstance
(
fc2_weight_final
,
QuantizedTensor
Bas
e
):
if
isinstance
(
fc2_weight_final
,
QuantizedTensor
Storag
e
):
fc2_weight_final
.
update_usage
(
columnwise_usage
=
True
)
if
cpu_offloading
:
...
...
@@ -555,6 +576,7 @@ class _LayerNormMLP(torch.autograd.Function):
if
not
fc2_weight
.
requires_grad
:
clear_tensor_data
(
act_out
)
act_out
=
None
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
inputmat
,
ln_weight
,
...
...
@@ -680,6 +702,7 @@ class _LayerNormMLP(torch.autograd.Function):
mu
,
rsigma
,
)
=
restore_from_saved
(
ctx
.
tensor_objects
,
saved_tensors
)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx
.
tensor_objects
=
None
...
...
@@ -820,10 +843,10 @@ class _LayerNormMLP(torch.autograd.Function):
)
# Make sure required data is available
if
isinstance
(
grad_output
,
QuantizedTensor
Bas
e
):
if
isinstance
(
grad_output
,
QuantizedTensor
Storag
e
):
grad_output
.
update_usage
(
rowwise_usage
=
True
)
if
ctx
.
fc2_weight_quantizer
is
not
None
and
isinstance
(
ctx
.
fc2_weight
,
QuantizedTensor
Bas
e
ctx
.
fc2_weight
,
QuantizedTensor
Storag
e
):
ctx
.
fc2_weight
.
update_usage
(
columnwise_usage
=
True
)
...
...
@@ -905,14 +928,14 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
act_out
,
QuantizedTensor
Bas
e
):
if
isinstance
(
act_out
,
QuantizedTensor
Storag
e
):
act_out
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
fc2_input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
act_out
=
ctx
.
fc2_input_quantizer
(
act_out
)
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
grad_output
,
QuantizedTensor
Bas
e
):
if
isinstance
(
grad_output
,
QuantizedTensor
Storag
e
):
grad_output
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
fc2_grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
...
...
@@ -932,7 +955,11 @@ class _LayerNormMLP(torch.autograd.Function):
else
ctx
.
activation_dtype
),
"quantization_params"
:
ctx
.
fc2_grad_weight_quantizer
,
# wgrad in high precision
"accumulate"
:
accumulate_wgrad_into_param_main_grad
,
"accumulate"
:
(
accumulate_wgrad_into_param_main_grad
if
not
getattr
(
fc1_weight
,
"overwrite_main_grad"
,
False
)
else
False
),
"layout"
:
"NT"
,
"out"
:
origin_fc2_weight
.
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
"bias"
:
fc2_bias
if
fc2_bias
is
not
None
and
fc2_bias_grad
is
None
else
None
,
...
...
@@ -1028,8 +1055,11 @@ class _LayerNormMLP(torch.autograd.Function):
)
# activation in high precision
if
ctx
.
fp8
:
# TODO float8 blockwise current scaling has no bgrad fusion for now
if
isinstance
(
ctx
.
fc1_grad_output_quantizer
,
Float8BlockQuantizer
):
# TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now
if
(
isinstance
(
ctx
.
fc1_grad_output_quantizer
,
Float8BlockQuantizer
)
or
ctx
.
fp8_recipe
.
custom
()
):
fc1_bias_grad
=
dact
.
view
(
-
1
,
dact
.
shape
[
-
1
]).
sum
(
dim
=
0
)
dact
=
ctx
.
fc1_grad_output_quantizer
(
dact
)
else
:
...
...
@@ -1074,7 +1104,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Make sure required data is available
if
ctx
.
fc1_weight_quantizer
is
not
None
and
isinstance
(
ctx
.
fc1_weight_quantizer
,
QuantizedTensor
Bas
e
ctx
.
fc1_weight_quantizer
,
QuantizedTensor
Storag
e
):
ctx
.
fc1_weight
.
update_usage
(
columnwise_usage
=
True
)
...
...
@@ -1145,7 +1175,7 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total_work
.
wait
()
ln_out_total_work
=
None
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
ln_out_total
,
QuantizedTensor
Bas
e
):
if
isinstance
(
ln_out_total
,
QuantizedTensor
Storag
e
):
ln_out_total
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
fc1_input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
...
...
@@ -1155,7 +1185,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
dact
,
QuantizedTensor
Bas
e
):
if
isinstance
(
dact
,
QuantizedTensor
Storag
e
):
dact
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
fc1_grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
...
...
@@ -1178,7 +1208,11 @@ class _LayerNormMLP(torch.autograd.Function):
else
ctx
.
activation_dtype
),
"quantization_params"
:
ctx
.
fc1_grad_weight_quantizer
,
"accumulate"
:
accumulate_wgrad_into_param_main_grad
,
"accumulate"
:
(
accumulate_wgrad_into_param_main_grad
if
not
getattr
(
fc2_weight
,
"overwrite_main_grad"
,
False
)
else
False
),
"layout"
:
"NT"
,
"out"
:
origin_fc1_weight
.
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
"bias"
:
fc1_bias
if
fuse_gemm_and_bias_fc1_wgrad
else
None
,
...
...
@@ -1486,7 +1520,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias for FC2, but
instead return the bias value during the forward pass together with the
...
...
@@ -1718,6 +1754,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self
.
_customize_quantizers_float8_current_scaling
(
fwd
,
recipe
)
elif
recipe
.
float8_block_scaling
():
self
.
_customize_quantizers_float8_blockwise_scaling
(
fwd
,
recipe
)
elif
recipe
.
nvfp4
():
self
.
_customize_quantizers_nvfp4
(
fwd
,
recipe
)
# elif for other recipes (mxfp8, etc.)
def
reset_layer_norm_parameters
(
self
)
->
None
:
...
...
@@ -1937,7 +1975,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_input_quantizer
=
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM2_INPUT
]
fc2_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
isinstance
(
fc2_input_quantizer
,
(
MXFP8Quantizer
,
Float8BlockQuantizer
)),
columnwise
=
isinstance
(
fc2_input_quantizer
,
(
MXFP8Quantizer
,
Float8BlockQuantizer
,
NVFP4Quantizer
),
),
)
fc1_input_quantizer
.
internal
=
True
if
fp8_output
:
...
...
@@ -2142,7 +2183,29 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex
.
FP8BwdTensors
.
GRAD_OUTPUT2
].
amax_reduction_group
=
self
.
tp_group
def
_get_weight_tensors
(
self
)
->
List
[
Union
[
torch
.
Tensor
,
QuantizedTensorBase
]]:
def
_customize_quantizers_nvfp4
(
self
,
fwd
:
bool
,
recipe
:
Recipe
)
->
None
:
"""Customize quantizers based on current scaling recipe + layernorm_mlp."""
assert
recipe
.
nvfp4
(),
"Incorrect recipe."
if
fwd
:
if
self
.
sequence_parallel
and
self
.
set_parallel_mode
:
# fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
with_amax_reduction
=
True
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
amax_reduction_group
=
self
.
tp_group
else
:
if
self
.
sequence_parallel
and
self
.
set_parallel_mode
:
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT2
].
with_amax_reduction
=
True
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT2
].
amax_reduction_group
=
self
.
tp_group
def
_get_weight_tensors
(
self
)
->
List
[
Union
[
torch
.
Tensor
,
QuantizedTensorStorage
]]:
"""Get the weight tensors of the module."""
return
[
self
.
fc1_weight
,
self
.
fc2_weight
]
...
...
transformer_engine/pytorch/module/linear.py
View file @
063ef88d
...
...
@@ -27,7 +27,7 @@ from .base import (
_2X_ACC_WGRAD
,
)
from
._common
import
noop_cat
,
WeightGradStore
from
..
fp8
import
FP8GlobalStateManager
from
..
quantization
import
FP8GlobalStateManager
from
..utils
import
(
cast_if_needed
,
clear_tensor_data
,
...
...
@@ -36,6 +36,7 @@ from ..utils import (
requires_grad
,
needs_quantized_gemm
,
assert_dim_for_fp8_exec
,
assert_dim_for_all_gather
,
nvtx_range_pop
,
nvtx_range_push
,
get_activation_offloading
,
...
...
@@ -60,13 +61,14 @@ from ..jit import no_torch_dynamo
from
..graph
import
is_graph_capturing
from
..tensor.quantized_tensor
import
(
QuantizedTensor
,
QuantizedTensor
Bas
e
,
QuantizedTensor
Storag
e
,
Quantizer
,
prepare_for_saving
,
restore_from_saved
,
)
from
..tensor.float8_tensor
import
Float8CurrentScalingQuantizer
,
Float8Quantizer
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
..tensor.utils
import
is_experimental
from
..export
import
is_in_onnx_export_mode
,
assert_warmed_up
from
..cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
...debug.pytorch.debug_state
import
TEDebugState
...
...
@@ -154,6 +156,9 @@ class _Linear(torch.autograd.Function):
ub_obj
=
get_ub
(
ub_name
+
"_fprop"
,
fp8
)
ub_type
=
tex
.
CommOverlapType
.
AG
# experimental recipe check
experimental
=
is_experimental
(
input_quantizer
)
or
is_experimental
(
weight_quantizer
)
# ------------------------------------------------------
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
...
...
@@ -164,6 +169,7 @@ class _Linear(torch.autograd.Function):
own_quantized_input
=
False
if
fp8
:
assert_dim_for_fp8_exec
(
inputmat
,
weight
)
assert_dim_for_all_gather
(
inputmat
,
with_input_all_gather_nccl
,
input_quantizer
)
if
save_original_input
:
assert
not
isinstance
(
input_quantizer
,
Float8Quantizer
...
...
@@ -175,7 +181,7 @@ class _Linear(torch.autograd.Function):
if
fp8
or
debug
:
if
input_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
if
not
isinstance
(
inputmat
,
QuantizedTensor
Base
)
:
if
not
isinstance
(
inputmat
,
QuantizedTensor
Storage
)
and
not
experimental
:
own_quantized_input
=
True
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
backward_needs_input
)
if
isinstance
(
...
...
@@ -213,7 +219,7 @@ class _Linear(torch.autograd.Function):
else
:
# Do not all-gather input tensor
if
fp8
or
debug
:
if
isinstance
(
inputmat
,
QuantizedTensor
Bas
e
):
if
isinstance
(
inputmat
,
QuantizedTensor
Storag
e
):
inputmat
.
update_usage
(
rowwise_usage
=
True
)
else
:
if
input_quantizer
is
None
:
...
...
@@ -369,7 +375,7 @@ class _Linear(torch.autograd.Function):
if
(
backward_needs_input
and
own_quantized_input
and
isinstance
(
inputmat
,
QuantizedTensor
Bas
e
)
and
isinstance
(
inputmat
,
QuantizedTensor
Storag
e
)
):
if
(
ctx
.
backward_input_needs_gather
...
...
@@ -388,7 +394,7 @@ class _Linear(torch.autograd.Function):
# Weight with column-wise usage is needed for dgrad GEMM.
if
inp
.
requires_grad
:
if
isinstance
(
weightmat
,
QuantizedTensor
Bas
e
):
if
isinstance
(
weightmat
,
QuantizedTensor
Storag
e
):
weightmat
.
update_usage
(
columnwise_usage
=
True
)
if
cpu_offloading
and
saved_inputmat
is
not
None
:
...
...
@@ -401,7 +407,7 @@ class _Linear(torch.autograd.Function):
ctx
.
fsdp_shapes
=
_fsdp_scatter_tensors
(
fsdp_group
,
saved_inputmat
,
weightmat
if
fp8
and
not
isinstance
(
weight
,
QuantizedTensor
Bas
e
)
else
None
,
weightmat
if
fp8
and
not
isinstance
(
weight
,
QuantizedTensor
Storag
e
)
else
None
,
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_scatter"
)
...
...
@@ -471,6 +477,7 @@ class _Linear(torch.autograd.Function):
ctx
.
main_grad_func
=
lambda
:
weight
.
main_grad
ctx
.
debug
=
debug
ctx
.
experimental
=
experimental
ctx
.
cpu_offloading
=
cpu_offloading
ctx
.
is_first_microbatch
=
is_first_microbatch
ctx
.
use_bias
=
bias
is
not
None
...
...
@@ -637,10 +644,10 @@ class _Linear(torch.autograd.Function):
inputmat_total_work
=
None
if
ctx
.
requires_wgrad
:
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
inputmat
,
QuantizedTensor
Bas
e
):
if
isinstance
(
inputmat
,
QuantizedTensor
Storag
e
):
# Input tensor is already quantized
pass
elif
ctx
.
debug
:
elif
ctx
.
debug
or
ctx
.
experimental
:
# Debug quantizer will be applied immediately before wgrad GEMM
pass
else
:
...
...
@@ -656,7 +663,7 @@ class _Linear(torch.autograd.Function):
quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
inputmat
=
quantizer
(
inputmat
)
else
:
if
isinstance
(
inputmat
,
QuantizedTensor
Bas
e
):
if
isinstance
(
inputmat
,
QuantizedTensor
Storag
e
):
inputmat
=
inputmat
.
dequantize
(
dtype
=
ctx
.
activation_dtype
)
else
:
inputmat
=
cast_if_needed
(
inputmat
,
ctx
.
activation_dtype
)
...
...
@@ -701,9 +708,11 @@ class _Linear(torch.autograd.Function):
if
ctx
.
requires_dgrad
:
# Make sure required data is available
if
isinstance
(
grad_output
,
QuantizedTensor
Bas
e
):
if
isinstance
(
grad_output
,
QuantizedTensor
Storag
e
):
grad_output
.
update_usage
(
rowwise_usage
=
True
)
if
ctx
.
weight_quantizer
is
not
None
and
isinstance
(
weight_fp8
,
QuantizedTensorBase
):
if
ctx
.
weight_quantizer
is
not
None
and
isinstance
(
weight_fp8
,
QuantizedTensorStorage
):
weight_fp8
.
update_usage
(
columnwise_usage
=
True
)
# Choose whether to use GEMM kernel with split accumulator
...
...
@@ -729,6 +738,7 @@ class _Linear(torch.autograd.Function):
# dgrad GEMM
# Note: dx = dy * w
nvtx_range_push
(
f
"
{
nvtx_label
}
.dgrad_gemm"
)
gemm_out
,
*
_
,
reduce_scatter_out
=
general_gemm
(
weight_fp8
,
...
...
@@ -786,7 +796,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work
.
wait
()
inputmat_total_work
=
None
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
inputmat_total
,
QuantizedTensor
Bas
e
):
if
isinstance
(
inputmat_total
,
QuantizedTensor
Storag
e
):
inputmat_total
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
...
...
@@ -828,7 +838,7 @@ class _Linear(torch.autograd.Function):
)
if
ctx
.
fp8
or
ctx
.
debug
:
if
isinstance
(
grad_output
,
QuantizedTensor
Bas
e
):
if
isinstance
(
grad_output
,
QuantizedTensor
Storag
e
):
grad_output
.
update_usage
(
columnwise_usage
=
True
)
else
:
ctx
.
grad_output_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
...
...
@@ -864,7 +874,11 @@ class _Linear(torch.autograd.Function):
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
),
"quantization_params"
:
ctx
.
grad_weight_quantizer
,
"accumulate"
:
accumulate_wgrad_into_param_main_grad
,
"accumulate"
:
(
accumulate_wgrad_into_param_main_grad
if
not
getattr
(
weight
,
"overwrite_main_grad"
,
False
)
else
False
),
"layout"
:
"NT"
,
"out"
:
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
"bias"
:
(
bias
if
(
grad_bias
is
None
and
not
ctx
.
fp8
)
else
None
),
...
...
@@ -984,7 +998,7 @@ class _Linear(torch.autograd.Function):
nvtx_range_pop
(
f
"
{
nvtx_label
}
.reduce_and_update_fp8_tensors"
)
# Scatter fp8 weight buffers
if
ctx
.
fp8
and
not
isinstance
(
weight
,
QuantizedTensor
Bas
e
):
if
ctx
.
fp8
and
not
isinstance
(
weight
,
QuantizedTensor
Storag
e
):
_fsdp_scatter_tensors
(
ctx
.
fsdp_group
,
weight_fp8
)
return
(
wgrad
,
...
...
@@ -1086,7 +1100,9 @@ class Linear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
size to accumulate gradients in. This argument along with
weight tensor having attribute 'overwrite_main_grad' set to True
will overwrite `main_grad` instead of accumulating.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
...
...
@@ -1363,6 +1379,8 @@ class Linear(TransformerEngineBaseModule):
self
.
_customize_quantizers_float8_current_scaling
(
fwd
,
recipe
)
elif
recipe
.
float8_block_scaling
():
self
.
_customize_quantizers_float8_blockwise_scaling
(
fwd
,
recipe
)
elif
recipe
.
nvfp4
():
self
.
_customize_quantizers_nvfp4
(
fwd
,
recipe
)
# elif for other recipes (mxfp8, etc.)
def
reset_parameters
(
self
,
defer_init
=
False
):
...
...
@@ -1452,7 +1470,6 @@ class Linear(TransformerEngineBaseModule):
if
not
debug
else
self
.
_get_debug_quantizers
(
fp8_output
,
fp8_grad
)
)
if
debug
:
if
self
.
no_debug_features_active
(
quantizers
):
debug
=
False
...
...
@@ -1557,7 +1574,7 @@ class Linear(TransformerEngineBaseModule):
for
name
,
q
in
zip
(
names
,
original_quantizers
)
)
def
_get_weight_tensors
(
self
)
->
List
[
Union
[
torch
.
Tensor
,
QuantizedTensor
Bas
e
]]:
def
_get_weight_tensors
(
self
)
->
List
[
Union
[
torch
.
Tensor
,
QuantizedTensor
Storag
e
]]:
"""Get the weight tensors of the module."""
unfused_weights
=
[
getattr
(
self
,
name
)
for
name
in
self
.
weight_names
]
if
any
(
isinstance
(
w
,
QuantizedTensor
)
for
w
in
unfused_weights
):
...
...
@@ -1693,6 +1710,28 @@ class Linear(TransformerEngineBaseModule):
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
amax_reduction_group
=
self
.
tp_group
def
_customize_quantizers_nvfp4
(
self
,
fwd
:
bool
,
recipe
:
Recipe
)
->
None
:
"""Customize quantizers based on current scaling recipe + linear."""
assert
recipe
.
nvfp4
(),
"Incorrect recipe."
if
fwd
:
if
self
.
sequence_parallel
and
self
.
parallel_mode
==
"column"
:
# customize input_quantizer with amax reduction TP group
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
with_amax_reduction
=
True
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
amax_reduction_group
=
self
.
tp_group
else
:
if
self
.
sequence_parallel
and
self
.
parallel_mode
==
"row"
:
# customize grad_output_quantizer with amax reduction TP group
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
with_amax_reduction
=
True
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
amax_reduction_group
=
self
.
tp_group
def
_get_weight_quantizers
(
self
)
->
List
[
Quantizer
]:
"""Get the weight quantizers of the module."""
if
not
self
.
fp8
and
not
self
.
fp8_calibration
:
...
...
transformer_engine/pytorch/ops/_common.py
View file @
063ef88d
...
...
@@ -11,19 +11,19 @@ import torch
from
transformer_engine_torch
import
FP8TensorMeta
from
..
import
torch_version
from
..
fp8
import
FP8GlobalStateManager
from
..
quantization
import
FP8GlobalStateManager
from
..tensor.float8_tensor
import
Float8Tensor
from
..tensor.quantized_tensor
import
QuantizedTensor
Bas
e
from
..tensor.quantized_tensor
import
QuantizedTensor
Storag
e
from
..utils
import
canonicalize_dtype
def
is_quantized_tensor
(
tensor
:
torch
.
Tensor
|
QuantizedTensor
Bas
e
)
->
bool
:
def
is_quantized_tensor
(
tensor
:
torch
.
Tensor
|
QuantizedTensor
Storag
e
)
->
bool
:
"""Check if tensor is a quantized tensor"""
return
isinstance
(
tensor
,
QuantizedTensor
Bas
e
)
return
isinstance
(
tensor
,
QuantizedTensor
Storag
e
)
def
maybe_dequantize
(
tensor
:
torch
.
Tensor
|
QuantizedTensor
Bas
e
,
dtype
:
torch
.
dtype
|
None
=
None
tensor
:
torch
.
Tensor
|
QuantizedTensor
Storag
e
,
dtype
:
torch
.
dtype
|
None
=
None
)
->
torch
.
Tensor
:
"""Dequantize tensor to given dtype or just convert if not a quantized tensor"""
if
is_quantized_tensor
(
tensor
):
...
...
transformer_engine/pytorch/ops/basic/__init__.py
View file @
063ef88d
...
...
@@ -4,7 +4,19 @@
"""Single tensor operations supported by the operation fuser."""
from
.activation
import
GELU
,
GEGLU
,
QGELU
,
QGEGLU
,
ReLU
,
ReGLU
,
SReLU
,
SReGLU
,
SiLU
,
SwiGLU
from
.activation
import
(
GELU
,
GEGLU
,
QGELU
,
QGEGLU
,
ReLU
,
ReGLU
,
SReLU
,
SReGLU
,
SiLU
,
SwiGLU
,
ClampedSwiGLU
,
)
from
.add_extra_input
import
AddExtraInput
from
.all_gather
import
AllGather
from
.all_reduce
import
AllReduce
...
...
transformer_engine/pytorch/ops/basic/activation.py
View file @
063ef88d
...
...
@@ -28,6 +28,7 @@ __all__ = [
"SReGLU"
,
"SiLU"
,
"SwiGLU"
,
"ClampedSwiGLU"
,
]
...
...
@@ -392,3 +393,38 @@ class SwiGLU(_ActivationOperation):
def
_activation_backward_impl
(
self
,
*
args
,
**
kwargs
)
->
torch
.
Tensor
:
return
tex
.
dswiglu
(
*
args
,
**
kwargs
)
class
ClampedSwiGLU
(
_ActivationOperation
):
r
"""GPT-OSS
Implementation based on `GPT-OSS<https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>`__.
This activation has two differences compared to the original SwiGLU
1. Both gate and pre-activations are clipped based on parameter limit.
2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation.
.. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt
from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor.
Parameters
----------
limit: float
The clamp limit.
alpha: float
The scaling factor for the sigmoid function used in the activation.
cache_quantized_input: bool, default = False
Quantize input tensor when caching for use in the backward pass.
"""
def
__init__
(
self
,
*
,
limit
:
float
=
7.0
,
alpha
:
float
=
1.702
,
cache_quantized_input
:
bool
=
False
):
super
().
__init__
(
cache_quantized_input
=
cache_quantized_input
)
self
.
limit
=
limit
self
.
alpha
=
alpha
def
_activation_forward_impl
(
self
,
*
args
,
**
kwargs
)
->
torch
.
Tensor
:
return
tex
.
clamped_swiglu
(
*
args
,
limit
=
self
.
limit
,
alpha
=
self
.
alpha
,
**
kwargs
)
def
_activation_backward_impl
(
self
,
*
args
,
**
kwargs
)
->
torch
.
Tensor
:
return
tex
.
clamped_dswiglu
(
*
args
,
limit
=
self
.
limit
,
alpha
=
self
.
alpha
,
**
kwargs
)
transformer_engine/pytorch/ops/basic/basic_linear.py
View file @
063ef88d
...
...
@@ -19,7 +19,7 @@ from ...distributed import (
gather_along_first_dim
,
reduce_scatter_along_first_dim
,
)
from
...
fp8
import
FP8GlobalStateManager
,
Recipe
from
...
quantization
import
FP8GlobalStateManager
,
Recipe
from
...module.base
import
(
_2X_ACC_FPROP
,
_2X_ACC_DGRAD
,
...
...
@@ -29,7 +29,7 @@ from ...module.base import (
)
from
...tensor
import
Quantizer
from
...tensor.float8_tensor
import
Float8Quantizer
from
...tensor.
_internal
.float8_tensor_
bas
e
import
Float8Tensor
Bas
e
from
...tensor.
storage
.float8_tensor_
storag
e
import
Float8Tensor
Storag
e
from
...utils
import
(
canonicalize_device
,
canonicalize_dtype
,
...
...
@@ -80,7 +80,9 @@ class BasicLinear(BasicOperation):
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
meaningful. This is primarily intented to integrate with
Megatron-LM.
Megatron-LM. This argument along with weight tensor having
attribute 'overwrite_main_grad' set to True will overwrite
`main_grad` instead of accumulating.
userbuffers_options, dict, optional
Options for overlapping tensor-parallel communication with
compute using Userbuffers. This feature is highly
...
...
@@ -301,8 +303,8 @@ class BasicLinear(BasicOperation):
"Tried to quantize weight with deferred initialization "
"due to meta device, but no quantizer was available. "
"This is most likely because the weight was initialized "
"within
fp8
_model_init, but the forward pass was not "
"performed within
fp8_
autocast."
"within
quantized
_model_init, but the forward pass was not "
"performed within autocast."
)
quantizer
.
set_usage
(
rowwise
=
True
,
...
...
@@ -322,6 +324,20 @@ class BasicLinear(BasicOperation):
if
self
.
weight
.
device
.
type
==
"meta"
:
self
.
reset_parameters
()
def
pre_fuser_forward
(
self
,
*
,
requires_grad
:
bool
)
->
None
:
super
().
pre_fuser_forward
(
requires_grad
=
requires_grad
)
if
FP8GlobalStateManager
.
is_fp8_enabled
():
# Configure quantizer usages
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
weight_requires_grad
=
requires_grad
and
self
.
weight
.
requires_grad
input_quantizer
=
self
.
get_quantizer
(
"forward"
,
0
)
weight_quantizer
=
self
.
get_quantizer
(
"forward"
,
1
)
grad_output_quantizer
=
self
.
get_quantizer
(
"backward"
,
0
)
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
weight_requires_grad
)
weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
grad_output_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
weight_requires_grad
)
def
reset_recipe_state
(
self
,
*
,
recipe
:
Optional
[
Recipe
])
->
None
:
super
().
reset_recipe_state
(
recipe
=
recipe
)
...
...
@@ -352,6 +368,35 @@ class BasicLinear(BasicOperation):
and
not
getattr
(
self
,
"_with_quantized_weight"
,
False
)
)
# Recipe-specific configuration
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
if
recipe
is
not
None
:
if
recipe
.
float8_current_scaling
():
input_quantizer
.
force_pow_2_scales
=
recipe
.
fp8_quant_fwd_inp
.
power_2_scale
input_quantizer
.
amax_epsilon_scales
=
recipe
.
fp8_quant_fwd_inp
.
amax_epsilon
weight_quantizer
.
force_pow_2_scales
=
recipe
.
fp8_quant_fwd_weight
.
power_2_scale
weight_quantizer
.
amax_epsilon_scales
=
recipe
.
fp8_quant_fwd_weight
.
amax_epsilon
grad_output_quantizer
.
force_pow_2_scales
=
recipe
.
fp8_quant_bwd_grad
.
power_2_scale
grad_output_quantizer
.
amax_epsilon_scales
=
recipe
.
fp8_quant_bwd_grad
.
amax_epsilon
if
getattr
(
self
,
"sequence_parallel"
,
False
):
tensor_parallel_mode
=
getattr
(
self
,
"tensor_parallel_mode"
,
None
)
if
tensor_parallel_mode
==
"column"
:
input_quantizer
.
with_amax_reduction
=
True
input_quantizer
.
amax_reduction_group
=
self
.
tensor_parallel_group
elif
tensor_parallel_mode
==
"row"
:
grad_output_quantizer
.
with_amax_reduction
=
True
grad_output_quantizer
.
amax_reduction_group
=
self
.
tensor_parallel_group
if
recipe
.
nvfp4
():
if
getattr
(
self
,
"sequence_parallel"
,
False
):
tensor_parallel_mode
=
getattr
(
self
,
"tensor_parallel_mode"
,
None
)
if
tensor_parallel_mode
==
"column"
:
input_quantizer
.
with_amax_reduction
=
True
input_quantizer
.
amax_reduction_group
=
self
.
tensor_parallel_group
elif
tensor_parallel_mode
==
"row"
:
grad_output_quantizer
.
with_amax_reduction
=
True
grad_output_quantizer
.
amax_reduction_group
=
self
.
tensor_parallel_group
@
staticmethod
def
_functional_forward
(
input
:
torch
.
Tensor
,
# pylint: disable=redefined-builtin
...
...
@@ -568,7 +613,7 @@ class BasicLinear(BasicOperation):
# Prepare input tensor for backward pass
if
weight_requires_grad
:
if
with_quantized_compute
and
is_quantized_tensor
(
x_local
):
if
not
(
isinstance
(
x_local
,
Float8Tensor
Bas
e
)
and
with_x_all_gather
):
if
not
(
isinstance
(
x_local
,
Float8Tensor
Storag
e
)
and
with_x_all_gather
):
# FP8 does not support all-gather of transpose data
x_local
.
update_usage
(
rowwise_usage
=
False
,
columnwise_usage
=
True
)
else
:
...
...
@@ -731,7 +776,7 @@ class BasicLinear(BasicOperation):
if
with_quantized_compute
:
if
input_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
input_quantizer
.
set_usage
(
columnwise
=
True
)
input_quantizer
.
set_usage
(
rowwise
=
False
,
columnwise
=
True
)
if
with_x_all_gather
:
x
,
x_async
=
gather_along_first_dim
(
x_local
,
...
...
@@ -912,34 +957,13 @@ class BasicLinear(BasicOperation):
input_requires_grad
=
ctx
.
requires_grad
weight_requires_grad
=
ctx
.
requires_grad
and
self
.
weight
.
requires_grad
#
FP8 metadata
#
Quantizers
input_quantizer
=
self
.
get_quantizer
(
"forward"
,
0
)
weight_quantizer
=
self
.
get_quantizer
(
"forward"
,
1
)
output_quantizer
=
next_op_input_quantizer
grad_output_quantizer
=
self
.
get_quantizer
(
"backward"
,
0
)
grad_input_quantizer
=
prev_op_grad_output_quantizer
with_quantized_compute
=
FP8GlobalStateManager
.
is_fp8_enabled
()
if
with_quantized_compute
:
# Configure quantizers
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
weight_requires_grad
)
weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
if
recipe
.
float8_current_scaling
():
input_quantizer
.
force_pow_2_scales
=
recipe
.
fp8_quant_fwd_inp
.
power_2_scale
input_quantizer
.
amax_epsilon_scales
=
recipe
.
fp8_quant_fwd_inp
.
amax_epsilon
weight_quantizer
.
force_pow_2_scales
=
recipe
.
fp8_quant_fwd_inp
.
power_2_scale
weight_quantizer
.
amax_epsilon_scales
=
recipe
.
fp8_quant_fwd_inp
.
amax_epsilon
grad_output_quantizer
.
force_pow_2_scales
=
recipe
.
fp8_quant_fwd_inp
.
power_2_scale
grad_output_quantizer
.
amax_epsilon_scales
=
recipe
.
fp8_quant_fwd_inp
.
amax_epsilon
if
self
.
sequence_parallel
and
self
.
tensor_parallel_mode
==
"column"
:
input_quantizer
.
with_amax_reduction
=
True
input_quantizer
.
amax_reduction_group
=
self
.
tensor_parallel_group
if
self
.
sequence_parallel
and
self
.
tensor_parallel_mode
==
"row"
:
grad_output_quantizer
.
with_amax_reduction
=
True
grad_output_quantizer
.
amax_reduction_group
=
self
.
tensor_parallel_group
# Get autocast dtype if needed
if
torch
.
is_autocast_enabled
():
...
...
@@ -997,6 +1021,7 @@ class BasicLinear(BasicOperation):
weight_param
=
self
.
weight
if
hasattr
(
weight_param
,
"__fsdp_param__"
):
weight_param
.
main_grad
=
weight_param
.
get_main_grad
()
accumulate_into_main_grad
=
not
getattr
(
weight_param
,
"overwrite_main_grad"
,
False
)
if
not
hasattr
(
weight_param
,
"main_grad"
):
raise
RuntimeError
(
"BasicLinear op is configured with "
...
...
transformer_engine/pytorch/ops/basic/dropout.py
View file @
063ef88d
...
...
@@ -11,7 +11,7 @@ import torch
import
transformer_engine_torch
as
tex
from
...cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
...tensor
import
Quantizer
from
...tensor.
_internal
.float8_tensor_
bas
e
import
Float8Tensor
Bas
e
from
...tensor.
storage
.float8_tensor_
storag
e
import
Float8Tensor
Storag
e
from
.._common
import
maybe_autocast_dtype
,
maybe_dequantize
from
..op
import
BasicOperation
,
OperationContext
...
...
@@ -56,7 +56,7 @@ class Dropout(BasicOperation):
out
=
input_
elif
impl
==
"fused"
:
x
=
input_
if
not
isinstance
(
x
,
Float8Tensor
Bas
e
):
if
not
isinstance
(
x
,
Float8Tensor
Storag
e
):
x
=
maybe_dequantize
(
x
,
dtype
=
dtype
)
out
,
mask
=
tex
.
dropout_fwd
(
x
,
self
.
dropout_probability
)
elif
impl
==
"unfused"
:
...
...
transformer_engine/pytorch/ops/basic/quantize.py
View file @
063ef88d
...
...
@@ -9,7 +9,7 @@ from typing import Optional
import
torch
from
...
fp8
import
FP8GlobalStateManager
from
...
quantization
import
FP8GlobalStateManager
from
.._common
import
is_quantized_tensor
from
..op
import
BasicOperation
,
OperationContext
from
...tensor
import
Quantizer
...
...
@@ -18,8 +18,8 @@ from ...tensor import Quantizer
class
Quantize
(
BasicOperation
):
"""Quantize tensor data
Uses
FP8
recipe from `
fp8_
autocast` context. When called outside
of an `
fp8_
autocast` context, this is an identity operation.
Uses recipe from `autocast` context. When called outside
of an `autocast` context, this is an identity operation.
Parameters
----------
...
...
transformer_engine/pytorch/ops/fused/backward_activation_bias.py
View file @
063ef88d
...
...
@@ -10,7 +10,7 @@ from typing import Optional
import
torch
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.
fp8
import
Recipe
from
transformer_engine.pytorch.
quantization
import
Recipe
from
transformer_engine.pytorch.ops.basic
import
Bias
from
transformer_engine.pytorch.ops.basic.activation
import
(
_ActivationOperation
,
...
...
transformer_engine/pytorch/ops/fused/backward_linear_add.py
View file @
063ef88d
...
...
@@ -59,6 +59,7 @@ class BackwardLinearAdd(FusedOperation):
weight_param
=
linear_op
.
weight
if
hasattr
(
weight_param
,
"__fsdp_param__"
):
weight_param
.
main_grad
=
weight_param
.
get_main_grad
()
accumulate_into_main_grad
=
not
getattr
(
weight_param
,
"overwrite_main_grad"
,
False
)
if
not
hasattr
(
weight_param
,
"main_grad"
):
raise
RuntimeError
(
"BasicLinear op is configured with "
...
...
transformer_engine/pytorch/ops/fused/backward_linear_scale.py
View file @
063ef88d
...
...
@@ -60,6 +60,7 @@ class BackwardLinearScale(FusedOperation):
weight_param
=
linear_op
.
weight
if
hasattr
(
weight_param
,
"__fsdp_param__"
):
weight_param
.
main_grad
=
weight_param
.
get_main_grad
()
accumulate_into_main_grad
=
not
getattr
(
weight_param
,
"overwrite_main_grad"
,
False
)
if
not
hasattr
(
weight_param
,
"main_grad"
):
raise
RuntimeError
(
"BasicLinear op is configured with "
...
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
View file @
063ef88d
...
...
@@ -11,7 +11,7 @@ from typing import Any, Optional
import
torch
from
...cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
...
fp8
import
FP8GlobalStateManager
from
...
quantization
import
FP8GlobalStateManager
from
...tensor
import
Quantizer
from
..basic
import
BasicLinear
,
Bias
from
..op
import
FusedOperation
,
FusibleOperation
,
OperationContext
...
...
@@ -85,7 +85,7 @@ class ForwardLinearBiasActivation(FusedOperation):
input_requires_grad
=
linear_op_ctx
.
requires_grad
weight_requires_grad
=
linear_op_ctx
.
requires_grad
and
linear_op
.
weight
.
requires_grad
#
FP8 metadata
#
Quantizers
input_quantizer
=
linear_op
.
get_quantizer
(
"forward"
,
0
)
weight_quantizer
=
linear_op
.
get_quantizer
(
"forward"
,
1
)
output_quantizer
=
next_op_input_quantizer
...
...
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
View file @
063ef88d
...
...
@@ -11,7 +11,7 @@ from typing import Any, Optional
import
torch
from
...cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
...
fp8
import
FP8GlobalStateManager
from
...
quantization
import
FP8GlobalStateManager
from
...tensor
import
Quantizer
from
..basic
import
AddExtraInput
,
BasicLinear
,
Bias
from
..op
import
FusedOperation
,
FusibleOperation
,
OperationContext
...
...
@@ -79,7 +79,7 @@ class ForwardLinearBiasAdd(FusedOperation):
input_requires_grad
=
linear_op_ctx
.
requires_grad
weight_requires_grad
=
linear_op_ctx
.
requires_grad
and
linear_op
.
weight
.
requires_grad
#
FP8 metadata
#
Quantizers
input_quantizer
=
linear_op
.
get_quantizer
(
"forward"
,
0
)
weight_quantizer
=
linear_op
.
get_quantizer
(
"forward"
,
1
)
output_quantizer
=
None
...
...
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
View file @
063ef88d
...
...
@@ -11,7 +11,7 @@ from typing import Any, Optional
import
torch
from
...cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
...
fp8
import
FP8GlobalStateManager
from
...
quantization
import
FP8GlobalStateManager
from
...tensor
import
Quantizer
from
..basic
import
AddExtraInput
,
BasicLinear
,
ConstantScale
from
..op
import
(
...
...
@@ -58,7 +58,7 @@ class ForwardLinearScaleAdd(FusedOperation):
input_requires_grad
=
linear_op_ctx
.
requires_grad
weight_requires_grad
=
linear_op_ctx
.
requires_grad
and
linear_op
.
weight
.
requires_grad
#
FP8 metadata
#
Quantizers
input_quantizer
=
linear_op
.
get_quantizer
(
"forward"
,
0
)
weight_quantizer
=
linear_op
.
get_quantizer
(
"forward"
,
1
)
output_quantizer
=
None
...
...
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
View file @
063ef88d
...
...
@@ -523,6 +523,7 @@ class UserbuffersBackwardLinear(FusedOperation):
weight_param
=
linear_op
.
weight
if
hasattr
(
weight_param
,
"__fsdp_param__"
):
weight_param
.
main_grad
=
weight_param
.
get_main_grad
()
accumulate_into_main_grad
=
not
getattr
(
weight_param
,
"overwrite_main_grad"
,
False
)
if
not
hasattr
(
weight_param
,
"main_grad"
):
raise
RuntimeError
(
"BasicLinear op is configured with "
...
...
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
View file @
063ef88d
...
...
@@ -14,7 +14,7 @@ from transformer_engine_torch import CommOverlapType
from
...cpp_extensions
import
general_gemm
from
...cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
...distributed
import
get_distributed_world_size
from
...
fp8
import
FP8GlobalStateManager
from
...
quantization
import
FP8GlobalStateManager
from
...module.base
import
(
fill_userbuffers_buffer_for_all_gather
,
get_ub
,
...
...
@@ -23,7 +23,7 @@ from ...module.base import (
)
from
...tensor.quantized_tensor
import
Quantizer
from
...tensor.float8_tensor
import
Float8Quantizer
,
Float8CurrentScalingQuantizer
from
...tensor.
_internal
.float8_tensor_
bas
e
import
Float8Tensor
Bas
e
from
...tensor.
storage
.float8_tensor_
storag
e
import
Float8Tensor
Storag
e
from
.._common
import
maybe_dequantize
,
is_quantized_tensor
from
..basic
import
BasicLinear
,
Bias
,
ReduceScatter
from
..op
import
(
...
...
@@ -267,7 +267,7 @@ class UserbuffersForwardLinear(FusedOperation):
# Prepare input tensor for backward pass
if
weight_requires_grad
:
if
with_quantized_compute
and
is_quantized_tensor
(
x_local
):
if
not
(
isinstance
(
x_local
,
Float8Tensor
Bas
e
)
and
with_ub_all_gather
):
if
not
(
isinstance
(
x_local
,
Float8Tensor
Storag
e
)
and
with_ub_all_gather
):
# FP8 does not support all-gather of transpose data
x_local
.
update_usage
(
rowwise_usage
=
False
,
columnwise_usage
=
True
)
else
:
...
...
transformer_engine/pytorch/ops/fuser.py
View file @
063ef88d
...
...
@@ -11,7 +11,7 @@ import itertools
import
torch
from
transformer_engine.pytorch.
fp8
import
FP8GlobalStateManager
,
Recipe
,
DelayedScaling
from
transformer_engine.pytorch.
quantization
import
FP8GlobalStateManager
,
Recipe
,
DelayedScaling
from
transformer_engine.pytorch.ops.op
import
(
BasicOperation
,
FusibleOperation
,
...
...
@@ -472,6 +472,10 @@ class OperationFuser:
# Attempt to fuse operations if neccesary
self
.
maybe_fuse_ops
(
is_grad_enabled
,
recipe
,
input
,
basic_op_extra_inputs
)
# Initialization before forward
for
idx
,
op
in
enumerate
(
self
.
_basic_ops
):
op
.
pre_fuser_forward
(
requires_grad
=
idx
>=
self
.
first_op_requiring_backward
)
# Fuser forward pass
if
is_grad_enabled
:
forward_func
=
_OperationFuserAutogradFunction
.
apply
...
...
transformer_engine/pytorch/ops/op.py
View file @
063ef88d
...
...
@@ -14,10 +14,10 @@ from typing import Any, Optional
import
torch
from
transformer_engine.common.recipe
import
Recipe
from
..
fp8
import
(
from
..
quantization
import
(
FP8GlobalStateManager
,
RecipeState
,
fp8_
autocast
,
autocast
,
)
from
..tensor
import
Quantizer
...
...
@@ -65,6 +65,13 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
def
pre_first_fuser_forward
(
self
)
->
None
:
"""Preprocessing before first fuser forward pass"""
def
pre_fuser_forward
(
self
,
*
,
requires_grad
:
bool
,
# pylint: disable=unused-argument
)
->
None
:
"""Preprocessing before fuser forward pass"""
def
get_input_quantizer
(
self
)
->
Optional
[
Quantizer
]:
"""Get builder class for quantized input tensor"""
...
...
@@ -588,6 +595,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
extra
[
key
]
=
val
state
[
mode
][
"extra_fp8_variables"
]
=
extra
if
not
state
:
return
torch
.
empty
(
0
,
dtype
=
torch
.
uint8
)
# Serialize state into byte tensor
torch
.
cuda
.
synchronize
()
state_serialized
=
bytearray
(
pickle
.
dumps
(
state
))
...
...
@@ -624,7 +634,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Get op's quantizer state, initializing if needed
if
self
.
_fp8_metas
is
None
or
self
.
_fp8_metas
[
mode
]
is
None
:
with
fp8_
autocast
(
fp8_
recipe
=
state
[
mode
][
"recipe"
]):
with
autocast
(
recipe
=
state
[
mode
][
"recipe"
]):
self
.
reset_recipe_state
(
recipe
=
state
[
mode
][
"recipe"
])
fp8_meta
=
self
.
_fp8_metas
[
mode
]
...
...
@@ -710,6 +720,10 @@ class FusedOperation(FusibleOperation):
for
op
in
self
.
basic_ops
:
op
.
pre_first_fuser_forward
()
def
pre_fuser_forward
(
self
,
*
,
requires_grad
:
bool
)
->
None
:
for
op
in
self
.
basic_ops
:
op
.
pre_fuser_forward
(
requires_grad
=
requires_grad
)
def
forward
(
self
,
input
:
torch
.
Tensor
,
# pylint: disable=redefined-builtin
...
...
Prev
1
…
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment