OpenDAS / TransformerEngine, commit 53fa872c
Authored Oct 11, 2025 by wenjh
Parents: 27ddce40, 40c69e75

Merge branch 'nv_release_v2.8' into release_v2.8

Showing 19 changed files with 1579 additions and 52 deletions (+1579, -52)
transformer_engine/pytorch/module/layernorm_linear.py                    +35   -8
transformer_engine/pytorch/module/layernorm_mlp.py                       +44   -4
transformer_engine/pytorch/module/linear.py                              +36   -9
transformer_engine/pytorch/ops/basic/basic_linear.py                     +45   -23
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py   +1    -1
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py          +1    -1
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py         +1    -1
transformer_engine/pytorch/ops/fuser.py                                  +4    -0
transformer_engine/pytorch/ops/op.py                                     +11   -0
transformer_engine/pytorch/tensor/__init__.py                            +3    -0
transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py         +348  -0
transformer_engine/pytorch/tensor/float8_tensor.py                       +4    -0
transformer_engine/pytorch/tensor/mxfp8_tensor.py                        +2    -3
transformer_engine/pytorch/tensor/nvfp4_tensor.py                        +898  -0
transformer_engine/pytorch/tensor/quantized_tensor.py                    +4    -0
transformer_engine/pytorch/tensor/utils.py                               +20   -1
transformer_engine/pytorch/transformer.py                                +14   -0
transformer_engine/pytorch/triton/pad.py                                 +94   -0
transformer_engine/pytorch/utils.py                                      +14   -1
transformer_engine/pytorch/module/layernorm_linear.py

@@ -16,6 +16,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
+from transformer_engine.pytorch.tensor.utils import is_experimental
 from .base import (
     fill_userbuffers_buffer_for_all_gather,
     get_workspace,
@@ -29,6 +30,7 @@ from .base import (
 from ..fp8 import FP8GlobalStateManager
 from ..utils import (
     assert_dim_for_fp8_exec,
+    assert_dim_for_all_gather,
     cast_if_needed,
     clear_tensor_data,
     divide,
@@ -54,7 +56,7 @@ from ..distributed import (
 from ..constants import GemmParallelModes, dist_group_type
 from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
-from ._common import apply_normalization, noop_cat, WeightGradStore
+from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
     QuantizedTensorBase,
@@ -143,6 +145,8 @@ class _LayerNormLinear(torch.autograd.Function):
         if ub_name is not None:
             nvtx_label = f"{nvtx_label}.{ub_name}"

+        with_input_all_gather = parallel_mode == "column" and sequence_parallel
+
         # Make sure input dimensions are compatible
         out_features, in_features = weight.shape
         inp_shape = inp.shape
@@ -152,6 +156,7 @@ class _LayerNormLinear(torch.autograd.Function):
         inputmat = inp
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
+        assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer)

         # Cast for native AMP
         nvtx_range_push(f"{nvtx_label}.norm_input_cast")
@@ -165,7 +170,6 @@ class _LayerNormLinear(torch.autograd.Function):
         weight_requires_grad = weight.requires_grad
         backward_needs_input = is_grad_enabled and weight_requires_grad
-        with_input_all_gather = parallel_mode == "column" and sequence_parallel

         # Configure Userbuffers communication (comm+GEMM overlap)
         if debug:
             # turn off userbuffers in debug mode
@@ -198,11 +202,13 @@ class _LayerNormLinear(torch.autograd.Function):
         # Avoid quantized norm kernel if norm output will be returned
         # or if a gather of ln_out must be in high precision.
+        experimental = is_experimental(input_quantizer)
         with_quantized_norm = (
             fp8
             and not debug
             and not return_layernorm_output
             and not return_layernorm_output_gathered
+            and not experimental
         )

         # Apply normalization
@@ -248,7 +254,8 @@ class _LayerNormLinear(torch.autograd.Function):
             quantizer = None
             if fp8 or debug:
                 quantizer = input_quantizer
-                if not with_quantized_norm:
+                # experimental recipe doesn't need to support quantized AG
+                if not with_quantized_norm and not experimental:
                     ln_out = quantizer(ln_out)
                     quantizer.set_usage(rowwise=True, columnwise=False)
             if ub_overlap_ag_fprop:
                 # Initialize Userbuffers all-gather
@@ -1462,6 +1469,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self._customize_quantizers_float8_current_scaling(fwd, recipe)
         elif recipe.float8_block_scaling():
             self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
+        elif recipe.nvfp4():
+            self._customize_quantizers_nvfp4(fwd, recipe)
         # elif other recipes (mxfp8, etc)

     def reset_layer_norm_parameters(self) -> None:
@@ -1566,11 +1575,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         # Get concatenated weight and bias tensors
         weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

-        quantizers = (
-            self._get_quantizers(fp8_output, fp8_grad)
-            if not debug
-            else self._get_debug_quantizers(fp8_output, fp8_grad)
-        )
+        quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug)
         if debug:
             if self.no_debug_features_active(quantizers):
                 debug = False
@@ -1803,6 +1808,28 @@ class LayerNormLinear(TransformerEngineBaseModule):
             tex.FP8BwdTensors.GRAD_OUTPUT1
         ].amax_reduction_group = self.tp_group

+    def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on current scaling recipe + layernorm_linear."""
+        assert recipe.nvfp4(), "Incorrect recipe."
+        if fwd:
+            if self.sequence_parallel and self.parallel_mode == "column":
+                # set input_quantizer with amax reduction TP group
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].with_amax_reduction = True
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_group = self.tp_group
+        else:
+            if self.sequence_parallel and self.parallel_mode == "row":
+                # customize grad_output_quantizer with amax reduction TP group
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].with_amax_reduction = True
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].amax_reduction_group = self.tp_group
+
     def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
         """Get the weight tensors of the module."""
         unfused_weights = [getattr(self, name) for name in self.weight_names]
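The new _customize_quantizers_nvfp4 mirrors the existing float8 current-scaling customization: under sequence parallelism, the quantized activation (column-parallel forward) or gradient (row-parallel backward) is all-gathered across the tensor-parallel group, so every rank must quantize with the same scale. That is why the quantizer's amax is reduced over tp_group before scales are derived. A minimal sketch of the underlying idea in plain torch.distributed (reduced_amax is an illustrative helper, not TE API, and assumes an initialized process group when run distributed):

import torch
import torch.distributed as dist

def reduced_amax(x_local: torch.Tensor, group=None) -> torch.Tensor:
    """Scale-synchronizing amax: max(|x|) over every rank's shard.

    All ranks then quantize with the same scale, so the quantized shards
    can be gathered and consumed as one consistently scaled tensor.
    """
    amax = x_local.abs().amax().float()
    if dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax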
transformer_engine/pytorch/module/layernorm_mlp.py

@@ -18,6 +18,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
+from transformer_engine.pytorch.tensor.utils import is_experimental
 from .base import (
     fill_userbuffers_buffer_for_all_gather,
     get_workspace,
@@ -41,6 +42,7 @@ from ..utils import (
     init_method_constant,
     cast_if_needed,
     assert_dim_for_fp8_exec,
+    assert_dim_for_all_gather,
     clear_tensor_data,
     requires_grad,
     needs_quantized_gemm,
@@ -65,6 +67,7 @@ from ..tensor.float8_tensor import (
     Float8Tensor,
 )
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
+from ..tensor.nvfp4_tensor import NVFP4Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ._common import apply_normalization, WeightGradStore
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
@@ -121,7 +124,8 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
         }
     # no activation fusion written yet
     # Per-tensor current scaling or fp8 blockwise scaling: []
-    if recipe.float8_current_scaling() or recipe.float8_block_scaling():
+    # TODO(ksivaman): Fuse nvfp4 act once kernel is available.
+    if recipe.float8_current_scaling() or recipe.float8_block_scaling() or recipe.nvfp4():
         return {
             "gelu": (tex.gelu, tex.dgelu, None),
             "geglu": (tex.geglu, tex.dgeglu, None),
@@ -218,6 +222,7 @@ class _LayerNormMLP(torch.autograd.Function):
         inputmat = inp.view((-1, in_features))
         if fp8:
             assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
+        assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer)

         activation_func = _act_func(
             activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
@@ -265,11 +270,13 @@ class _LayerNormMLP(torch.autograd.Function):
         # high precision layernorm output and output of the linear are returned
         # for debug: layernorm output = high precision to enable processing of this norm
+        experimental = is_experimental(fc1_input_quantizer)
         with_quantized_norm = (
             fp8
             and not debug
             and not return_layernorm_output
             and not return_layernorm_output_gathered
+            and not experimental
         )

         # Apply normalization
@@ -309,7 +316,8 @@ class _LayerNormMLP(torch.autograd.Function):
             quantizer = None
             if fp8 or debug:
                 quantizer = fc1_input_quantizer
-                if not with_quantized_norm:
+                # experimental recipe doesn't need to support quantized AG
+                if not with_quantized_norm and not experimental:
                     ln_out = fc1_input_quantizer(ln_out)
                     fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
             if ub_overlap_ag:
@@ -555,6 +563,7 @@ class _LayerNormMLP(torch.autograd.Function):
         if not fc2_weight.requires_grad:
             clear_tensor_data(act_out)
             act_out = None

         tensors_to_save, tensor_objects = prepare_for_saving(
             inputmat,
             ln_weight,
@@ -680,6 +689,7 @@ class _LayerNormMLP(torch.autograd.Function):
             mu,
             rsigma,
         ) = restore_from_saved(ctx.tensor_objects, saved_tensors)
         # Delete the references to tensor objects once they've been consumed
         # by the `restore_from_saved` method to construct back the actual tensors.
         ctx.tensor_objects = None
@@ -1029,7 +1039,10 @@ class _LayerNormMLP(torch.autograd.Function):
         if ctx.fp8:
-            # TODO float8 blockwise current scaling has no bgrad fusion for now
-            if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer):
+            # TODO(ksivaman): Re-add fusion once kernel is available.
+            if isinstance(
+                ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer)
+            ):
                 fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
                 dact = ctx.fc1_grad_output_quantizer(dact)
             else:
@@ -1718,6 +1731,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self._customize_quantizers_float8_current_scaling(fwd, recipe)
         elif recipe.float8_block_scaling():
             self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
+        elif recipe.nvfp4():
+            self._customize_quantizers_nvfp4(fwd, recipe)
         # elif for other recipes (mxfp8, etc.)

     def reset_layer_norm_parameters(self) -> None:
@@ -1937,7 +1952,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
         fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
         fc2_input_quantizer.set_usage(
             rowwise=True,
-            columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
+            columnwise=isinstance(
+                fc2_input_quantizer,
+                (MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer),
+            ),
         )
         fc1_input_quantizer.internal = True
         if fp8_output:
@@ -2142,6 +2160,28 @@ class LayerNormMLP(TransformerEngineBaseModule):
             tex.FP8BwdTensors.GRAD_OUTPUT2
         ].amax_reduction_group = self.tp_group

+    def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on current scaling recipe + layernorm_mlp."""
+        assert recipe.nvfp4(), "Incorrect recipe."
+        if fwd:
+            if self.sequence_parallel and self.set_parallel_mode:
+                # fc1_input_quantizer: customize input_quantizer with amax reduction
+                # TP group, column parallel + sequence parallel here
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].with_amax_reduction = True
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_group = self.tp_group
+        else:
+            if self.sequence_parallel and self.set_parallel_mode:
+                # fc2_grad_output_quantizer: customize grad_output_quantizer with amax
+                # reduction TP group, row parallel + sequence parallel here
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
+                ].with_amax_reduction = True
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT2
+                ].amax_reduction_group = self.tp_group
+
     def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
         """Get the weight tensors of the module."""
         return [self.fc1_weight, self.fc2_weight]
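Because no fused bgrad+quantize kernel exists for NVFP4 yet, the backward fallback above computes the FC1 bias gradient explicitly before quantizing dact: for a linear layer, the bias gradient is just the column-wise sum of the output gradient. A quick check of that identity in plain PyTorch (independent of TE):

import torch

x = torch.randn(32, 64)
w = torch.randn(128, 64)
b = torch.zeros(128, requires_grad=True)
y = x @ w.t() + b
dy = torch.randn_like(y)
y.backward(dy)

# The unfused fallback in the diff: bias grad is the sum of dy over all rows.
assert torch.allclose(b.grad, dy.view(-1, dy.shape[-1]).sum(dim=0))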
transformer_engine/pytorch/module/linear.py

@@ -26,7 +26,7 @@ from .base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
-from ._common import noop_cat, WeightGradStore
+from ._common import noop_cat, WeightGradStore, get_module_quantizers
 from ..fp8 import FP8GlobalStateManager
 from ..utils import (
     cast_if_needed,
@@ -36,6 +36,7 @@ from ..utils import (
     requires_grad,
     needs_quantized_gemm,
     assert_dim_for_fp8_exec,
+    assert_dim_for_all_gather,
     nvtx_range_pop,
     nvtx_range_push,
     get_activation_offloading,
@@ -67,6 +68,7 @@ from ..tensor.quantized_tensor import (
 )
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
+from ..tensor.utils import is_experimental
 from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
@@ -153,6 +155,9 @@ class _Linear(torch.autograd.Function):
             ub_obj = get_ub(ub_name + "_fprop", fp8)
             ub_type = tex.CommOverlapType.AG

+        # experimental recipe check
+        experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer)
+
         # ------------------------------------------------------
         # Prepare input tensor
         # Note: Cast to expected dtype and perform tensor-parallel communication
@@ -163,6 +168,7 @@ class _Linear(torch.autograd.Function):
         own_quantized_input = False
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
+        assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
         if save_original_input:
             assert not isinstance(
                 input_quantizer, Float8Quantizer
@@ -174,7 +180,7 @@ class _Linear(torch.autograd.Function):
         if fp8 or debug:
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
-            if not isinstance(inputmat, QuantizedTensorBase):
+            if not isinstance(inputmat, QuantizedTensorBase) and not experimental:
                 own_quantized_input = True
                 input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
                 if isinstance(
@@ -464,6 +470,7 @@ class _Linear(torch.autograd.Function):
                 ctx.main_grad_func = lambda: weight.main_grad
             ctx.debug = debug
+            ctx.experimental = experimental
             ctx.cpu_offloading = cpu_offloading
             ctx.is_first_microbatch = is_first_microbatch
             ctx.use_bias = bias is not None
@@ -633,7 +640,7 @@ class _Linear(torch.autograd.Function):
             if isinstance(inputmat, QuantizedTensorBase):
                 # Input tensor is already quantized
                 pass
-            elif ctx.debug:
+            elif ctx.debug or ctx.experimental:
                 # Debug quantizer will be applied immediately before wgrad GEMM
                 pass
             else:
@@ -722,6 +729,7 @@ class _Linear(torch.autograd.Function):
             # dgrad GEMM
             # Note: dx = dy * w
             nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
             gemm_out, *_, reduce_scatter_out = general_gemm(
                 weight_fp8,
@@ -1353,6 +1361,8 @@ class Linear(TransformerEngineBaseModule):
             self._customize_quantizers_float8_current_scaling(fwd, recipe)
         elif recipe.float8_block_scaling():
             self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
+        elif recipe.nvfp4():
+            self._customize_quantizers_nvfp4(fwd, recipe)
         # elif for other recipes (mxfp8, etc.)

     def reset_parameters(self, defer_init=False):
@@ -1437,12 +1447,7 @@ class Linear(TransformerEngineBaseModule):
         weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

-        quantizers = (
-            self._get_quantizers(fp8_output, fp8_grad)
-            if not debug
-            else self._get_debug_quantizers(fp8_output, fp8_grad)
-        )
+        quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug)
         if debug:
             if self.no_debug_features_active(quantizers):
                 debug = False
@@ -1682,6 +1687,28 @@ class Linear(TransformerEngineBaseModule):
             tex.FP8BwdTensors.GRAD_OUTPUT1
         ].amax_reduction_group = self.tp_group

+    def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on current scaling recipe + linear."""
+        assert recipe.nvfp4(), "Incorrect recipe."
+        if fwd:
+            if self.sequence_parallel and self.parallel_mode == "column":
+                # customize input_quantizer with amax reduction TP group
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].with_amax_reduction = True
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_group = self.tp_group
+        else:
+            if self.sequence_parallel and self.parallel_mode == "row":
+                # customize grad_output_quantizer with amax reduction TP group
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].with_amax_reduction = True
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].amax_reduction_group = self.tp_group
+
     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
         if not self.fp8 and not self.fp8_calibration:
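The experimental flag added above extends the existing debug behavior: when the input or weight quantizer belongs to an experimental recipe, _Linear skips quantizing the input eagerly and leaves quantization to happen immediately before the GEMM that consumes it. A minimal sketch of that gating, with plain booleans standing in for TE's quantizer checks:

def quantize_input_eagerly(already_quantized: bool, debug: bool, experimental: bool) -> bool:
    """Mirror of the forward/backward branches above: quantize up front only when
    the input is not already quantized and neither the debug nor the experimental
    path wants to defer quantization until just before the GEMM."""
    return not (already_quantized or debug or experimental)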
transformer_engine/pytorch/ops/basic/basic_linear.py

@@ -322,6 +322,20 @@ class BasicLinear(BasicOperation):
         if self.weight.device.type == "meta":
             self.reset_parameters()

+    def pre_fuser_forward(self, *, requires_grad: bool) -> None:
+        super().pre_fuser_forward(requires_grad=requires_grad)
+        if FP8GlobalStateManager.is_fp8_enabled():
+            # Configure quantizer usages
+            # Note: We cache the quantized input for backward pass,
+            # but discard the quantized weights.
+            weight_requires_grad = requires_grad and self.weight.requires_grad
+            input_quantizer = self.get_quantizer("forward", 0)
+            weight_quantizer = self.get_quantizer("forward", 1)
+            grad_output_quantizer = self.get_quantizer("backward", 0)
+            input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
+            weight_quantizer.set_usage(rowwise=True, columnwise=False)
+            grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
+
     def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
         super().reset_recipe_state(recipe=recipe)
@@ -352,6 +366,35 @@ class BasicLinear(BasicOperation):
             and not getattr(self, "_with_quantized_weight", False)
         )

+        # Recipe-specific configuration
+        # Note: This function may be called in base class constructor,
+        # before any basic linear attrs have been set.
+        if recipe is not None:
+            if recipe.float8_current_scaling():
+                input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+                input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
+                weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
+                weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
+                grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
+                grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon
+                if getattr(self, "sequence_parallel", False):
+                    tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
+                    if tensor_parallel_mode == "column":
+                        input_quantizer.with_amax_reduction = True
+                        input_quantizer.amax_reduction_group = self.tensor_parallel_group
+                    elif tensor_parallel_mode == "row":
+                        grad_output_quantizer.with_amax_reduction = True
+                        grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
+            if recipe.nvfp4():
+                if getattr(self, "sequence_parallel", False):
+                    tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
+                    if tensor_parallel_mode == "column":
+                        input_quantizer.with_amax_reduction = True
+                        input_quantizer.amax_reduction_group = self.tensor_parallel_group
+                    elif tensor_parallel_mode == "row":
+                        grad_output_quantizer.with_amax_reduction = True
+                        grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
+
     @staticmethod
     def _functional_forward(
         input: torch.Tensor,  # pylint: disable=redefined-builtin
@@ -731,7 +774,7 @@ class BasicLinear(BasicOperation):
         if with_quantized_compute:
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
-            input_quantizer.set_usage(columnwise=True)
+            input_quantizer.set_usage(rowwise=False, columnwise=True)
             if with_x_all_gather:
                 x, x_async = gather_along_first_dim(
                     x_local,
@@ -912,34 +955,13 @@ class BasicLinear(BasicOperation):
         input_requires_grad = ctx.requires_grad
         weight_requires_grad = ctx.requires_grad and self.weight.requires_grad

-        # FP8 metadata
+        # Quantizers
         input_quantizer = self.get_quantizer("forward", 0)
         weight_quantizer = self.get_quantizer("forward", 1)
         output_quantizer = next_op_input_quantizer
         grad_output_quantizer = self.get_quantizer("backward", 0)
         grad_input_quantizer = prev_op_grad_output_quantizer
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        if with_quantized_compute:
-            # Configure quantizers
-            # Note: We cache the quantized input for backward pass,
-            # but discard the quantized weights.
-            input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
-            weight_quantizer.set_usage(rowwise=True, columnwise=False)
-            recipe = FP8GlobalStateManager.get_fp8_recipe()
-            if recipe.float8_current_scaling():
-                input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                if self.sequence_parallel and self.tensor_parallel_mode == "column":
-                    input_quantizer.with_amax_reduction = True
-                    input_quantizer.amax_reduction_group = self.tensor_parallel_group
-                if self.sequence_parallel and self.tensor_parallel_mode == "row":
-                    grad_output_quantizer.with_amax_reduction = True
-                    grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group

         # Get autocast dtype if needed
         if torch.is_autocast_enabled():
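The new pre_fuser_forward hook configures quantizer usage ahead of the fused forward, replacing the per-call quantizer configuration removed in the last hunk. The rationale: the forward GEMM consumes the row-wise quantized input and weight, while the wgrad GEMM needs the column-wise (transposed) copies of the input and grad output, so column-wise caching is enabled only when the weight actually requires grad. A rough sketch of which GEMM touches which layout, in plain PyTorch:

import torch

x = torch.randn(32, 64)    # input: column-wise copy cached only if wgrad is needed
w = torch.randn(128, 64)   # weight: row-wise only; the transposed copy is discarded
dy = torch.randn(32, 128)  # grad output

y = x @ w.t()      # forward GEMM: row-wise x and w
dx = dy @ w        # dgrad GEMM: row-wise dy and w
dw = dy.t() @ x    # wgrad GEMM: consumes dy and x through their transposes
assert y.shape == dy.shape and dx.shape == x.shape and dw.shape == w.shape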
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py

@@ -85,7 +85,7 @@ class ForwardLinearBiasActivation(FusedOperation):
         input_requires_grad = linear_op_ctx.requires_grad
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

-        # FP8 metadata
+        # Quantizers
         input_quantizer = linear_op.get_quantizer("forward", 0)
         weight_quantizer = linear_op.get_quantizer("forward", 1)
         output_quantizer = next_op_input_quantizer
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py

@@ -79,7 +79,7 @@ class ForwardLinearBiasAdd(FusedOperation):
         input_requires_grad = linear_op_ctx.requires_grad
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

-        # FP8 metadata
+        # Quantizers
         input_quantizer = linear_op.get_quantizer("forward", 0)
         weight_quantizer = linear_op.get_quantizer("forward", 1)
         output_quantizer = None
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py

@@ -58,7 +58,7 @@ class ForwardLinearScaleAdd(FusedOperation):
         input_requires_grad = linear_op_ctx.requires_grad
         weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

-        # FP8 metadata
+        # Quantizers
         input_quantizer = linear_op.get_quantizer("forward", 0)
         weight_quantizer = linear_op.get_quantizer("forward", 1)
         output_quantizer = None
transformer_engine/pytorch/ops/fuser.py

@@ -472,6 +472,10 @@ class OperationFuser:
         # Attempt to fuse operations if necessary
         self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)

+        # Initialization before forward
+        for idx, op in enumerate(self._basic_ops):
+            op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward)
+
         # Fuser forward pass
         if is_grad_enabled:
             forward_func = _OperationFuserAutogradFunction.apply
transformer_engine/pytorch/ops/op.py

@@ -65,6 +65,13 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
     def pre_first_fuser_forward(self) -> None:
         """Preprocessing before first fuser forward pass"""

+    def pre_fuser_forward(
+        self,
+        *,
+        requires_grad: bool,  # pylint: disable=unused-argument
+    ) -> None:
+        """Preprocessing before fuser forward pass"""
+
     def get_input_quantizer(self) -> Optional[Quantizer]:
         """Get builder class for quantized input tensor"""
@@ -710,6 +717,10 @@ class FusedOperation(FusibleOperation):
         for op in self.basic_ops:
             op.pre_first_fuser_forward()

+    def pre_fuser_forward(self, *, requires_grad: bool) -> None:
+        for op in self.basic_ops:
+            op.pre_fuser_forward(requires_grad=requires_grad)
+
     def forward(
         self,
         input: torch.Tensor,  # pylint: disable=redefined-builtin
transformer_engine/pytorch/tensor/__init__.py

@@ -54,6 +54,7 @@ def get_all_tensor_types():
         Float8BlockwiseQTensor,
         Float8BlockwiseQTensorBase,
     )
+    from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor, NVFP4TensorBase

     all_tensor_types = [
         torch.Tensor,
@@ -64,5 +65,7 @@ def get_all_tensor_types():
         MXFP8TensorBase,
         Float8BlockwiseQTensor,
         Float8BlockwiseQTensorBase,
+        NVFP4Tensor,
+        NVFP4TensorBase,
     ]
     return all_tensor_types
transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py (new file, mode 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Mixin class holding data specific for NVFP4Tensor"""

from __future__ import annotations

from collections.abc import Iterable
import functools
import math
from typing import Any, Dict, Optional, Tuple, Union
import warnings

import torch

# import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType

from ..quantized_tensor import QuantizedTensorBase

# from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor


@functools.lru_cache(maxsize=None)
def _fp4_e2m1_vals(device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Values representable in FP4 E2M1 format"""
    return torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        device=device,
        dtype=dtype,
    )


class _FromNVFP4Func(torch.autograd.Function):
    """Cast from NVFP4 to other dtype"""

    @staticmethod
    def forward(
        _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
        tensor: NVFP4TensorBase,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring

        # Dequantize row-wise data
        if tensor._rowwise_data is not None:
            ### TODO(tmoon): Debug dequantize kernel and remove unfused impl
            # return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])

            # Tensor properties
            shape = list(tensor._rowwise_data.size())
            shape[-1] *= 2
            device = tensor._rowwise_data.device

            # Convert FP4E2M1 values to FP32
            data = tensor._rowwise_data.view(torch.uint8).to(torch.int32)
            data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape)
            data = _fp4_e2m1_vals(device, dtype=torch.float32)[data]
            data = data.to(torch.float32).contiguous()

            # Convert FP8E4M3 block scales to FP32
            block_scales = tensor._rowwise_scale_inv
            block_scales = block_scales.reshape(-1, block_scales.size(-1))
            block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16]
            block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32)

            # Convert amax to FP32 tensor scale
            tensor_scale = tensor._amax_rowwise / (6.0 * 448.0)  # Scale by FP4E2M1 and FP8E4M3 max

            # Apply scales
            block_data = data.view(-1, 16)
            block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1)

            return data.to(dtype)

        if tensor._columnwise_data is not None:
            raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!")
        raise ValueError("Attempted to dequantize NVFP4 tensor with no data")

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # pylint: disable=missing-function-docstring
        # Assume that we want gradients in full precision
        return grad, None


class NVFP4TensorBase(QuantizedTensorBase):
    """Mixin class that holds data attributes of NVFP4Tensor.

    NVFP4Tensor inherits from the PyTorch tensor class and this mixin
    class. If this class is instantiated directly, it has the same
    data, lower CPU overhead, and less functionality. It should only
    be instantiated directly for performance-critical internal usage.
    """

    _rowwise_data: Optional[torch.Tensor]
    _columnwise_data: Optional[torch.Tensor]
    _quantizer: Optional[Quantizer]
    _rowwise_scale_inv: torch.Tensor
    _columnwise_scale_inv: torch.Tensor
    _fp4_dtype: TE_DType
    _amax_rowwise: torch.Tensor
    _amax_columnwise: torch.Tensor

    def __new__(
        cls,
        rowwise_data: Optional[torch.Tensor],
        rowwise_scale_inv: torch.Tensor,
        columnwise_data: Optional[torch.Tensor],
        columnwise_scale_inv: torch.Tensor,
        amax_rowwise: torch.Tensor,
        amax_columnwise: torch.Tensor,
        fp4_dtype: TE_DType,
        quantizer: Optional[Quantizer],
        *args,
        **kwargs,
    ):
        instance = super().__new__(cls, *args, **kwargs)
        instance._rowwise_data = rowwise_data
        instance._columnwise_data = columnwise_data
        instance._fp4_dtype = fp4_dtype
        instance._quantizer = quantizer.copy() if quantizer is not None else None
        instance._rowwise_scale_inv = rowwise_scale_inv
        instance._columnwise_scale_inv = columnwise_scale_inv
        instance._amax_rowwise = amax_rowwise
        instance._amax_columnwise = amax_columnwise
        return instance

    def clear(self):
        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
        for t in (
            self._rowwise_data,
            self._columnwise_data,
            self._rowwise_scale_inv,
            self._columnwise_scale_inv,
            self._amax_rowwise,
            self._amax_columnwise,
        ):
            if t is not None:
                t.data = _empty_tensor()

    def get_metadata(self) -> Dict[str, Any]:
        """Get this tensor's metadata."""
        return {
            "rowwise_data": self._rowwise_data,
            "rowwise_scale_inv": self._rowwise_scale_inv,
            "columnwise_data": self._columnwise_data,
            "columnwise_scale_inv": self._columnwise_scale_inv,
            "amax_rowwise": self._amax_rowwise,
            "amax_columnwise": self._amax_columnwise,
            "fp4_dtype": self._fp4_dtype,
            "quantizer": self._quantizer,
        }

    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorBase]:
        """Prepare the tensor base for saving for backward"""
        tensors = [
            self._rowwise_data,
            self._columnwise_data,
            self._rowwise_scale_inv,
            self._columnwise_scale_inv,
            self._amax_rowwise,
            self._amax_columnwise,
        ]
        self._rowwise_data = None
        self._columnwise_data = None
        self._rowwise_scale_inv = None
        self._columnwise_scale_inv = None
        self._amax_rowwise = None
        self._amax_columnwise = None
        return tensors, self

    def restore_from_saved(
        self, tensors: list[Optional[torch.Tensor]]
    ) -> list[Optional[torch.Tensor]]:
        """Restore the tensor base data from the saved tensors list."""
        self._rowwise_data = tensors[0]
        self._columnwise_data = tensors[1]
        self._rowwise_scale_inv = tensors[2]
        self._columnwise_scale_inv = tensors[3]
        self._amax_rowwise = tensors[4]
        self._amax_columnwise = tensors[5]
        return tensors[6:]

    def get_data_tensors(self):
        """Get this Tensor's data."""
        return self._rowwise_data, self._columnwise_data

    def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """Dequantize to a higher precision."""
        return _FromNVFP4Func.forward(None, self, dtype)

    def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]:
        # pylint: disable=missing-function-docstring

        # Infer tensor shape
        shape = None
        if self._rowwise_data is not None:
            byte_shape = list(self._rowwise_data.size())
            shape = byte_shape[:-1] + [byte_shape[-1] * 2]
        elif self._columnwise_data is not None:
            warnings.warn("Attempting to get shape of NVFP4 tensor with only column-wise data.")
            byte_shape = list(self._columnwise_data.size())
            shape = byte_shape[1:-1] + [byte_shape[-1] * 2, byte_shape[0]]
        if shape is None:
            raise RuntimeError("Attempted to get shape of NVFP4 tensor with no data")

        # Return shape or dim
        if dim is None:
            return torch.Size(shape)
        return shape[dim]

    def view(self, shape: torch.Size):
        # pylint: disable=missing-function-docstring

        # Return input tensor if view not needed
        cur_shape = self.size()
        if shape is None or shape == cur_shape:
            return self

        # Canonicalize shape
        if not isinstance(shape, Iterable):
            shape = [shape]
        elif len(shape) == 1 and isinstance(shape[0], Iterable):
            shape = shape[0]
        if -1 in shape:
            shape = list(shape)
            d_inferred = -math.prod(cur_shape) // math.prod(shape)
            for i, d in enumerate(shape):
                if d == -1:
                    shape[i] = d_inferred
                    break
        if shape[-1] != cur_shape[-1]:
            raise RuntimeError(
                "NVFP4Tensor does not support reshaping inner dimension "
                f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})"
            )

        # Reshape data
        new_rowwise_data = None
        new_columnwise_data = None
        if self._rowwise_data is not None:
            if shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent row-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = list(shape[:-1]) + [shape[-1] // 2]
            new_rowwise_data = self._rowwise_data.view(byte_shape)
        if self._columnwise_data is not None:
            columnwise_shape = (shape[-1], math.prod(shape[:-1]))
            if columnwise_shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent column-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
            new_columnwise_data = self._columnwise_data.view(byte_shape)

        # Construct tensor
        return NVFP4TensorBase(
            rowwise_data=new_rowwise_data,
            rowwise_scale_inv=self._rowwise_scale_inv,
            columnwise_data=new_columnwise_data,
            columnwise_scale_inv=self._columnwise_scale_inv,
            amax_rowwise=self._amax_rowwise,
            amax_columnwise=self._amax_columnwise,
            quantizer=self._quantizer,
            fp4_dtype=self._fp4_dtype,
        )

    def __repr__(self):
        data_rowwise = self.dequantize()
        return (
            "NVFP4TensorBase("
            f"rowwise_scaled_data={data_rowwise},"
            f"rowwise_scale_inv={self._rowwise_scale_inv},"
            f"amax_rowwise={self._amax_rowwise},"
            f"amax_columnwise={self._amax_columnwise},"
            ")"
        )

    def update_usage(
        self,
        rowwise_usage: Optional[bool] = None,
        columnwise_usage: Optional[bool] = None,
    ):
        """
        For the NVFP4 format, columnwise scaled output is only produced by x2
        scaling kernels, so this function only disables usages.
        """

        # Default usage is based on available data
        if rowwise_usage is None:
            rowwise_usage = self._rowwise_data is not None
        if columnwise_usage is None:
            columnwise_usage = self._columnwise_data is not None

        # Update row-scaled data
        if rowwise_usage:
            if self._rowwise_data is None:
                raise RuntimeError(
                    "Requested row-wise usage, but NVFP4Tensor is missing row-scaled NVFP4 data"
                )
            if self._rowwise_scale_inv is None:
                raise RuntimeError(
                    "Requested row-wise usage, but NVFP4Tensor is missing row-scaled scale-inverses"
                )
            if self._amax_rowwise is None:
                raise RuntimeError(
                    "Requested row-wise usage, but NVFP4Tensor is missing per tensor"
                    " row-scaled scale-inverse"
                )
        else:
            self._rowwise_data = None
            self._rowwise_scale_inv = None
            self._amax_rowwise = None

        # Update column-scaled data
        if columnwise_usage:
            if self._columnwise_data is None:
                raise RuntimeError(
                    "Requested column-wise usage, but NVFP4Tensor is missing column-scaled FP8 data"
                )
            if self._columnwise_scale_inv is None:
                raise RuntimeError(
                    "Requested column-wise usage, "
                    "but NVFP4Tensor is missing column-scaled scale-inverses"
                )
            if self._amax_columnwise is None:
                raise RuntimeError(
                    "Requested column-wise usage, "
                    "but NVFP4Tensor is missing per tensor column-scaled scale-inverse"
                )
        else:
            self._columnwise_data = None
            self._columnwise_scale_inv = None
            self._amax_columnwise = None
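The unfused dequantize above unpacks two FP4 E2M1 values per byte (low nibble first) and maps each 4-bit code through a lookup table; block scales and the amax-derived tensor scale are then applied per 16-element block. A standalone sketch of just the unpacking step, using the same value table (scaling omitted):

import torch

# FP4 E2M1 code points, indexed by the 4-bit pattern (MSB is the sign).
FP4_E2M1_VALS = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

packed = torch.tensor([0x21, 0x9F], dtype=torch.uint8)  # two bytes = four FP4 values
codes = torch.stack((packed & 0x0F, packed >> 4), dim=-1).reshape(-1).to(torch.int64)
print(FP4_E2M1_VALS[codes])  # tensor([ 0.5000,  1.0000, -6.0000, -0.5000])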
transformer_engine/pytorch/tensor/float8_tensor.py

@@ -216,6 +216,8 @@ class Float8CurrentScalingQuantizer(Quantizer):
     amax: torch.Tensor
     """FP8 datatype"""
     dtype: TE_DType
+    """amax update options"""
+    use_existing_amax: bool
     """amax reduction options"""
     with_amax_reduction: bool
     amax_reduction_group: Optional[dist_group_type]
@@ -230,6 +232,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
         *,
         rowwise: bool = True,
         columnwise: bool = True,
+        use_existing_amax: bool = False,
         with_amax_reduction: bool = False,
         amax_reduction_group: Optional[dist_group_type] = None,
         force_pow_2_scales: bool = False,
@@ -239,6 +242,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
         self.scale = torch.empty(1, dtype=torch.float32, device=device)
         self.amax = torch.empty(1, dtype=torch.float32, device=device)
         self.dtype = tex.DType.kInt8 if int8_simulation_fp8_tensorwise else fp8_dtype
+        self.use_existing_amax = use_existing_amax
         self.with_amax_reduction = with_amax_reduction
         self.amax_reduction_group = amax_reduction_group
         self.force_pow_2_scales = force_pow_2_scales
transformer_engine/pytorch/tensor/mxfp8_tensor.py

@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.

-"""Tensor class with FP8 data"""
+"""Tensor class with MXFP8 data"""
 from __future__ import annotations
 from collections.abc import Iterable
 import math
@@ -186,8 +186,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
                       Reciprocal of the scaling factor applied when
                       casting to FP8, i.e. the scaling factor that must
                       be applied when casting from FP8 to higher
-                      precision. Can be inferred from fp8_meta if
-                      provided.
+                      precision.
     dtype: torch.dtype, default = torch.float32
            Nominal tensor datatype.
transformer_engine/pytorch/tensor/nvfp4_tensor.py
0 → 100644
View file @
53fa872c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor class with NVFP4 data"""
from
__future__
import
annotations
from
collections.abc
import
Iterable
import
math
from
typing
import
Optional
,
Tuple
,
Union
import
functools
import
torch
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
DType
as
TE_DType
from
transformer_engine.common.recipe
import
NVFP4BlockScaling
,
Recipe
from
..constants
import
NVFP4_BLOCK_SCALING_SIZE
,
dist_group_type
from
..utils
import
(
canonicalize_process_group
,
devices_match
,
round_up_to_nearest_multiple
,
)
from
._internal.nvfp4_tensor_base
import
NVFP4TensorBase
,
_FromNVFP4Func
from
.quantized_tensor
import
QuantizedTensor
,
Quantizer
,
_IdentityFunc
aten
=
torch
.
ops
.
aten
def
get_no_random_sign_vector
()
->
torch
.
Tensor
:
"""Non-random sign vector for Hadamard transform."""
return
torch
.
tensor
([
1
],
dtype
=
torch
.
float32
)
def
get_sign_from_vector
(
vector
:
torch
.
Tensor
)
->
int
:
"""Convert sign vector to bitmask.
Used for random Hadamard transform.
"""
mask
=
0
for
i
,
v
in
enumerate
(
vector
):
mask
|=
(
v
==
-
1
)
<<
i
return
mask
def
get_wgrad_sign_vector
()
->
torch
.
Tensor
:
"""Hard-coded random signs for Hadamard transform.
https://xkcd.com/221/
"""
return
torch
.
tensor
(
[
1
,
1
,
1
,
-
1
,
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
-
1
],
dtype
=
torch
.
float32
,
)
def
get_hadamard_matrix
(
hadamard_dimension
:
int
)
->
torch
.
Tensor
:
"""Construct a 16x16 Hadamard matrix."""
assert
hadamard_dimension
==
16
,
"Only hadamard dimension 16 is supported."
hadamard_scale
=
1
/
math
.
sqrt
(
hadamard_dimension
)
return
(
torch
.
tensor
(
[
[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
],
[
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
1
,
-
1
],
[
1
,
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
-
1
],
[
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
],
[
1
,
1
,
1
,
1
,
-
1
,
-
1
,
-
1
,
-
1
,
1
,
1
,
1
,
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
1
,
-
1
,
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
1
,
-
1
,
1
,
-
1
,
-
1
,
1
,
-
1
,
1
],
[
1
,
1
,
-
1
,
-
1
,
-
1
,
-
1
,
1
,
1
,
1
,
1
,
-
1
,
-
1
,
-
1
,
-
1
,
1
,
1
],
[
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
1
,
-
1
,
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
1
,
-
1
],
[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
],
[
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
1
],
[
1
,
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
,
1
],
[
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
,
1
,
-
1
],
[
1
,
1
,
1
,
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
1
,
1
,
1
,
1
],
[
1
,
-
1
,
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
1
,
1
,
-
1
,
1
,
-
1
],
[
1
,
1
,
-
1
,
-
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
,
1
,
1
,
1
,
-
1
,
-
1
],
[
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
1
,
-
1
,
-
1
,
1
],
],
dtype
=
torch
.
float32
,
)
*
hadamard_scale
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
get_rht_matrix
(
with_random_sign_mask
:
bool
)
->
torch
.
Tensor
:
"""Construct matrix used in random Hadamard transform."""
hadamard_dimension
=
16
if
with_random_sign_mask
:
signs
=
get_wgrad_sign_vector
()
else
:
signs
=
get_no_random_sign_vector
()
sign_matrix
=
signs
*
torch
.
eye
(
hadamard_dimension
,
dtype
=
torch
.
float32
)
rht_matrix
=
sign_matrix
@
get_hadamard_matrix
(
hadamard_dimension
)
return
rht_matrix
.
to
(
dtype
=
torch
.
bfloat16
).
cuda
()
@
functools
.
lru_cache
(
maxsize
=
None
)
def
get_random_sign_mask_for_rht
(
with_random_sign_mask
:
bool
)
->
int
:
"""Sign mask for random Hadamard transform."""
if
with_random_sign_mask
:
return
get_sign_from_vector
(
get_wgrad_sign_vector
())
return
0
class
NVFP4Quantizer
(
Quantizer
):
"""Builder class for NVFP4 tensors with NV block scaling"""
dtype
:
TE_DType
"""Random Hadamard Transform"""
with_rht
:
bool
with_post_rht_amax
:
bool
"""amax reduction options"""
with_amax_reduction
:
bool
amax_reduction_group
:
Optional
[
dist_group_type
]
"""2D block scaling, only applicable for weights."""
with_2d_quantization
:
bool
"""Stochastic rounding, only applicable for gradients."""
stochastic_rounding
:
bool
"""RHT matrix random sign mask"""
rht_matrix_random_sign_mask_t
:
int
rht_matrix
:
torch
.
Tensor
def
__init__
(
self
,
fp4_dtype
:
TE_DType
=
tex
.
DType
.
kFloat4E2M1
,
rowwise
:
bool
=
True
,
columnwise
:
bool
=
True
,
with_amax_reduction
:
bool
=
False
,
amax_reduction_group
:
Optional
[
dist_group_type
]
=
None
,
with_rht
:
bool
=
False
,
with_post_rht_amax
:
bool
=
False
,
with_2d_quantization
:
bool
=
False
,
stochastic_rounding
:
bool
=
False
,
with_random_sign_mask
:
bool
=
True
,
)
->
None
:
super
().
__init__
(
rowwise
=
rowwise
,
columnwise
=
columnwise
)
self
.
dtype
=
fp4_dtype
self
.
with_rht
=
with_rht
self
.
with_post_rht_amax
=
with_post_rht_amax
self
.
with_amax_reduction
=
with_amax_reduction
self
.
amax_reduction_group
=
amax_reduction_group
self
.
with_2d_quantization
=
with_2d_quantization
self
.
stochastic_rounding
=
stochastic_rounding
self
.
rht_matrix_random_sign_mask_t
=
get_random_sign_mask_for_rht
(
with_random_sign_mask
)
self
.
rht_matrix
=
get_rht_matrix
(
with_random_sign_mask
)
def
update_quantized
(
self
,
src
:
torch
.
Tensor
,
dst
:
QuantizedTensor
,
*
,
noop_flag
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
QuantizedTensor
:
assert
isinstance
(
dst
,
NVFP4Tensor
),
f
"Cannot store quantized NVFP4 in
{
type
(
dst
)
}
type."
# Make sure input is in expected format
if
not
devices_match
(
src
.
device
,
dst
.
device
):
src
=
src
.
to
(
device
=
dst
.
device
)
if
not
src
.
is_contiguous
():
src
=
src
.
contiguous
()
# Launch cast kernel
tex
.
quantize
(
src
,
self
,
dst
,
noop_flag
)
return
dst
def
is_quantizable
(
self
,
inp
:
torch
.
Tensor
)
->
bool
:
"""Returns whether or not given inp can be quantized"""
if
inp
.
ndim
<
2
:
return
False
if
inp
.
shape
[
-
1
]
%
NVFP4_BLOCK_SCALING_SIZE
!=
0
:
return
False
if
math
.
prod
(
inp
.
shape
[:
-
1
])
%
NVFP4_BLOCK_SCALING_SIZE
!=
0
:
return
False
return
True
def
get_scale_shape
(
self
,
shape
:
Iterable
[
int
],
columnwise
:
bool
)
->
Tuple
[
int
,
int
]:
"""Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization.
This method determines the shape of the scaling tensor needed for blockwise quantization,
taking into account the input tensor shape and whether columnwise scaling is used.
Parameters
----------
shape : Iterable[int]
Shape of the input tensor to be quantized
columnwise : bool
Whether to use columnwise scaling (True) or rowwise scaling (False)
Returns
-------
Tuple[int, int]
Shape of the scaling tensor as (outer_dim, inner_dim)
For NVFP4 1D blockwise quantization, blocksize is 16
- If columnwise: (round_to_multiple(K, 128), round_to_multiple(roundup(M / 16), 4))
- If rowwise: (round_to_multiple(M, 128), round_to_multiple(roundup(K / 16), 4))
Swizzle kernel will be performed before GEMM to suit the need of CuBLAS.
CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
"""
M
,
K
=
1
,
1
M
=
math
.
prod
(
shape
[:
-
1
])
K
=
shape
[
-
1
]
if
columnwise
:
outer
=
round_up_to_nearest_multiple
(
K
,
128
)
inner
=
round_up_to_nearest_multiple
(
math
.
ceil
(
M
/
NVFP4_BLOCK_SCALING_SIZE
),
4
)
return
(
outer
,
inner
)
# rowwise
outer
=
round_up_to_nearest_multiple
(
M
,
128
)
inner
=
round_up_to_nearest_multiple
(
math
.
ceil
(
K
/
NVFP4_BLOCK_SCALING_SIZE
),
4
)
return
(
outer
,
inner
)
@
staticmethod
def
get_columnwise_shape
(
shape
:
Iterable
[
int
])
->
Tuple
[
int
,
...]:
"""Calculate the shape of a tensor after columnwise quantization.
For NVFP4 columnwise quantization, it's performing 16x1 quantization block scaling.
Parameters
----------
shape : Iterable[int]
Original shape of the tensor
Returns
-------
Tuple[int, ...]
New shape with dimensions rearranged for columnwise layout.
For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
Returns empty tuple for empty input shape.
"""
if
len
(
shape
)
==
0
:
return
tuple
()
# and then after AG, a reorganize kernel will be called to restore the shape
colwise_shape
=
[
shape
[
-
1
]]
for
i
in
range
(
len
(
shape
)
-
1
):
colwise_shape
.
append
(
shape
[
i
])
return
tuple
(
colwise_shape
)
@
staticmethod
def
convert_shape_for_fp4
(
shape
:
Iterable
[
int
])
->
Tuple
[
int
,
...]:
"""Convert shape for FP4 data by dividing the last dimension by 2"""
shape
=
list
(
shape
)
shape
[
-
1
]
=
shape
[
-
1
]
//
2
return
tuple
(
shape
)
def
make_empty
(
self
,
shape
:
Iterable
[
int
],
*
,
dtype
:
torch
.
dtype
=
torch
.
float32
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
False
,
)
->
NVFP4Tensor
:
# Canonicalize tensor attributes
if
device
is
None
:
device
=
torch
.
device
(
"cuda"
)
assert
shape
[
-
1
]
%
NVFP4_BLOCK_SCALING_SIZE
==
0
,
(
f
"Incorrect shape
{
shape
}
for NVFP4. Tensor dims must be divisible by"
f
"
{
NVFP4_BLOCK_SCALING_SIZE
}
"
)
flat_first_dim
=
math
.
prod
(
shape
[:
-
1
])
assert
flat_first_dim
%
NVFP4_BLOCK_SCALING_SIZE
==
0
,
(
f
"Incorrect shape
{
shape
}
for NVFP4. Tensor dims must be divisible by"
f
"
{
NVFP4_BLOCK_SCALING_SIZE
}
"
)
# Allocate FP4 data
data
=
None
scale_inv
=
None
amax_rowwise
=
None
if
self
.
rowwise_usage
:
data
=
torch
.
empty
(
self
.
convert_shape_for_fp4
(
shape
),
dtype
=
torch
.
uint8
,
device
=
device
)
scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
False
)
scale_inv
=
torch
.
empty
(
scale_shape
,
dtype
=
torch
.
uint8
,
device
=
device
)
# Allocate per tensor scale inverse. FP32 format.
amax_rowwise
=
torch
.
zeros
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
# Allocate FP8 data transpose if needed
columnwise_data
=
None
columnwise_scale_inv
=
None
amax_columnwise
=
None
if
self
.
columnwise_usage
:
# enforce 2D shape to avoid [S, B, H] shape and B and be 1
# and the transposed shape is [H, S, B], so divide last dim by 2 gives zero
shape_2d
=
tuple
([
flat_first_dim
,
shape
[
-
1
]])
columnwise_data
=
torch
.
empty
(
self
.
convert_shape_for_fp4
(
self
.
get_columnwise_shape
(
shape_2d
)),
dtype
=
torch
.
uint8
,
device
=
device
,
)
columnwise_scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
True
)
columnwise_scale_inv
=
torch
.
empty
(
columnwise_scale_shape
,
dtype
=
torch
.
uint8
,
device
=
device
)
amax_columnwise
=
torch
.
zeros
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
# Construct FP8 tensor
return
NVFP4Tensor
(
shape
=
shape
,
dtype
=
dtype
,
rowwise_data
=
data
,
rowwise_scale_inv
=
scale_inv
,
columnwise_data
=
columnwise_data
,
columnwise_scale_inv
=
columnwise_scale_inv
,
amax_rowwise
=
amax_rowwise
,
amax_columnwise
=
amax_columnwise
,
fp4_dtype
=
self
.
dtype
,
quantizer
=
self
,
requires_grad
=
requires_grad
,
)
def
calibrate
(
self
,
tensor
:
torch
.
Tensor
)
->
None
:
pass
# Calibration is no-op
def
_canonicalized_amax_reduction_group
(
self
)
->
dist_group_type
:
"""Get process group for amax reduction"""
return
canonicalize_process_group
(
self
.
amax_reduction_group
)
def
_get_compatible_recipe
(
self
)
->
Union
[
type
[
Recipe
],
None
]:
return
NVFP4BlockScaling
class
NVFP4Tensor
(
NVFP4TensorBase
,
QuantizedTensor
):
"""Quantized tensor class with FP4 data
The tensor presents as having a standard, higher-precision dtype,
but the data itself is (scaled) FP4. For most tensor operations,
the data will be cast to the nominal dtype before performing the
operation.
Parameters
----------
rowwise_data: torch.Tensor
Raw FP4 data in a uint8 tensor (rowwise layout).
rowwise_scale_inv: torch.Tensor
Reciprocal of the scaling factor applied when
casting to FP4, i.e. the scaling factor that must
be applied when casting from FP4 to higher
precision (rowwise).
columnwise_data: torch.Tensor, optional
Raw FP4 data in a uint8 tensor (columnwise layout).
columnwise_scale_inv: torch.Tensor, optional
Reciprocal of the scaling factor for columnwise FP4 data.
amax_rowwise: torch.Tensor, optional
Rowwise amax tracking tensor.
amax_columnwise: torch.Tensor, optional
Columnwise amax tracking tensor.
fp4_dtype: TE_DType
The FP4 data type used for quantization.
quantizer: Quantizer
The quantizer instance used for this tensor.
dtype: torch.dtype, default = torch.float32
Nominal tensor datatype, used in dequantize.
"""
# NOTE: We reorder the *args so that we can instantiate a NVFP4TensorBase with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def
__new__
(
cls
,
*
args
,
rowwise_data
:
Optional
[
torch
.
Tensor
],
rowwise_scale_inv
:
Optional
[
torch
.
Tensor
],
columnwise_data
:
Optional
[
torch
.
Tensor
],
columnwise_scale_inv
:
Optional
[
torch
.
Tensor
],
amax_rowwise
:
Optional
[
torch
.
Tensor
],
amax_columnwise
:
Optional
[
torch
.
Tensor
],
fp4_dtype
:
TE_DType
,
quantizer
:
Quantizer
,
**
kwargs
,
):
instance
=
super
().
__new__
(
cls
,
rowwise_data
,
rowwise_scale_inv
,
columnwise_data
,
columnwise_scale_inv
,
amax_rowwise
,
amax_columnwise
,
fp4_dtype
,
quantizer
,
*
args
,
**
kwargs
,
)
return
instance
def
__repr__
(
self
,
*
,
tensor_contents
=
None
):
return
f
"NVFP4Tensor, data=
{
self
.
dequantize
(
dtype
=
self
.
dtype
)
}
)"
def
dequantize
(
self
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
)
->
torch
.
Tensor
:
"""
Construct plain PyTorch tensor from NVFP4Tensor
By default the resulting tensor's dtype is the
NVFP4Tensor's nominal dtype.
"""
# Convert PyTorch dtype to TE dtype
if
dtype
is
None
:
dtype
=
self
.
dtype
if
torch
.
is_grad_enabled
():
return
_FromNVFP4Func
.
apply
(
self
,
dtype
)
return
_FromNVFP4Func
.
forward
(
None
,
self
,
dtype
)
def
_get_quantizer
(
self
)
->
Quantizer
:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
if
self
.
_quantizer
is
not
None
:
return
self
.
_quantizer
return
NVFP4Quantizer
()
def
quantize_
(
self
,
tensor
:
torch
.
Tensor
,
*
,
noop_flag
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
NVFP4Tensor
:
"""Update FP8 data
Parameters
----------
tensor: torch.Tensor
Tensor to copy from
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
if
isinstance
(
tensor
,
QuantizedTensor
):
return
self
.
quantize_
(
tensor
.
dequantize
())
self
.
_get_quantizer
().
update_quantized
(
tensor
,
self
,
noop_flag
=
noop_flag
)
return
self
def
detach
(
self
)
->
NVFP4Tensor
:
# pylint: disable=missing-function-docstring
# TODO(ksivamani): Fix the detach bug
return
NVFP4Tensor
.
make_like
(
self
)
def
clone
(
self
)
->
NVFP4Tensor
:
# pylint: disable=missing-function-docstring
assert
self
.
_rowwise_data
is
not
None
rowwise_data
=
self
.
_rowwise_data
.
detach
().
clone
()
columnwise_data
=
None
if
self
.
_columnwise_data
is
not
None
:
columnwise_data
=
self
.
_columnwise_data
.
detach
().
clone
()
return
_IdentityFunc
.
apply
(
self
,
{
"rowwise_data"
:
rowwise_data
,
"columnwise_data"
:
columnwise_data
,
},
)
def
view
(
self
,
*
shape
:
Tuple
[
int
])
->
NVFP4Tensor
:
# pylint: disable=missing-function-docstring
return
_ViewFunc
.
apply
(
self
,
shape
)
def
reshape
(
self
,
*
shape
:
Tuple
[
int
])
->
NVFP4Tensor
:
# pylint: disable=missing-function-docstring
return
_ReshapeFunc
.
apply
(
self
,
shape
)
def
contiguous
(
self
,
memory_format
:
torch
.
memory_format
=
torch
.
contiguous_format
,
)
->
NVFP4Tensor
:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
if
self
.
_rowwise_data
is
not
None
and
self
.
_rowwise_data
.
is_contiguous
(
memory_format
=
memory_format
):
return
self
if
self
.
_columnwise_data
is
not
None
and
self
.
_columnwise_data
.
is_contiguous
(
memory_format
=
memory_format
):
return
self
raise
ValueError
(
"NVFP4Tensor does not support different memory formats!"
)
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):

        # View op
        if func == aten.view.default:
            if len(args) != 2:
                raise RuntimeError(f"Unexpected args for view op (expected 2 args, got {len(args)})")
            tensor = args[0]
            shape = args[1]
            if shape == list(tensor.size()):
                return tensor.detach()
            return tensor.view(shape)

        # NVFP4 dequantize not supported. Add manual support for needed funcs.
        if func in (aten.empty_like.default, aten.zero_.default):
            tensor = args[0]
            data_init_func = torch.zeros_like if func == aten.zero_.default else torch.empty_like
            scale_inv_init_func = (
                torch.ones_like if func == aten.zero_.default else torch.empty_like
            )
            if tensor._rowwise_data is not None:
                rowwise_data = data_init_func(tensor._rowwise_data)
                rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv)
                amax_rowwise = torch.zeros_like(tensor._amax_rowwise)
            else:
                rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None
            if tensor._columnwise_data is not None:
                columnwise_data = data_init_func(tensor._columnwise_data)
                columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv)
                amax_columnwise = torch.zeros_like(tensor._amax_columnwise)
            else:
                columnwise_data, columnwise_scale_inv, amax_columnwise = (
                    None,
                    None,
                    None,
                )
            return NVFP4Tensor(
                shape=tensor.shape,
                dtype=tensor.dtype,
                fp4_dtype=tensor._fp4_dtype,
                rowwise_data=rowwise_data,
                rowwise_scale_inv=rowwise_scale_inv,
                columnwise_data=columnwise_data,
                columnwise_scale_inv=columnwise_scale_inv,
                amax_rowwise=amax_rowwise,
                amax_columnwise=amax_columnwise,
                quantizer=tensor._quantizer,
                requires_grad=tensor.requires_grad,
            )

        # Default case
        return super().__torch_dispatch__(func, types, args, kwargs)
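With these overrides, a couple of factory-style ops work on NVFP4Tensor without dequantizing; a brief sketch (assuming `t` is an existing NVFP4Tensor):

    z = torch.empty_like(t)  # fresh packed data, scale-inv, and amax buffers with t's metadata
    z = z.zero_()            # zeroed data, unit scale-inverses, zeroed amaxes

Note that with this handler zero_ returns a freshly built NVFP4Tensor rather than mutating in place.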
    @classmethod
    def _make_in_reduce_ex(
        cls,
        shape: torch.Size,
        rowwise_data: torch.Tensor,
        rowwise_scale_inv: torch.Tensor,
        columnwise_data: torch.Tensor,
        columnwise_scale_inv: torch.Tensor,
        amax_rowwise: torch.Tensor,
        amax_columnwise: torch.Tensor,
        fp4_dtype: TE_DType,
        dtype: torch.dtype,
        quantizer: Quantizer,
    ) -> NVFP4Tensor:
        """Build NVFP4Tensor, for use in __reduce__

        __reduce_ex__ assumes object constructor has positional
        arguments.
        """
        return NVFP4Tensor(
            shape=shape,
            dtype=dtype,
            fp4_dtype=fp4_dtype,
            rowwise_data=rowwise_data,
            rowwise_scale_inv=rowwise_scale_inv,
            columnwise_data=columnwise_data,
            columnwise_scale_inv=columnwise_scale_inv,
            amax_rowwise=amax_rowwise,
            amax_columnwise=amax_columnwise,
            quantizer=quantizer,
            requires_grad=False,
        )
    def __reduce_ex__(self, protocol: int) -> tuple:
        """Custom pickling"""
        return (
            NVFP4Tensor._make_in_reduce_ex,
            (
                self.shape,
                self._rowwise_data,
                self._rowwise_scale_inv,
                self._columnwise_data,
                self._columnwise_scale_inv,
                self._amax_rowwise,
                self._amax_columnwise,
                self._fp4_dtype,
                self.dtype,
                self._quantizer,
            ),
        )
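Because __reduce_ex__ hands _make_in_reduce_ex plain positional arguments, an NVFP4Tensor survives an ordinary save/load round trip with its packed payloads intact; a sketch:

    torch.save(t, "t.pt")
    t2 = torch.load("t.pt", weights_only=False)  # rebuilt via NVFP4Tensor._make_in_reduce_ex
    assert torch.equal(t2._rowwise_data, t._rowwise_data)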
    def _get_data(self) -> NVFP4Tensor:
        """Get tensor data property"""
        return super().data
    @torch.no_grad()
    def _set_data(self, tensor: torch.Tensor) -> None:
        """Set tensor data property

        Just takes NVFP4 data if setting from an NVFP4Tensor. Otherwise
        casts to NVFP4.
        """

        # Tensor device
        new_device = tensor.device if tensor.is_cuda else self.device
        if not devices_match(new_device, tensor.device):
            tensor = tensor.to(device=new_device)

        # Just copy NVFP4 data if other tensor is NVFP4Tensor
        if isinstance(tensor, NVFP4Tensor):
            if (  # pylint: disable=too-many-boolean-expressions
                self.size() != tensor.size()
                or self.stride() != tensor.stride()
                or self.storage_offset() != tensor.storage_offset()
                or self.dtype != tensor.dtype
                or self.layout != tensor.layout
                or not devices_match(self.device, new_device)
            ):
                dummy_tensor = torch.Tensor._make_wrapper_subclass(
                    NVFP4Tensor,
                    tensor.size(),
                    strides=tensor.stride(),
                    storage_offset=tensor.storage_offset(),
                    dtype=tensor.dtype,
                    layout=tensor.layout,
                    requires_grad=tensor.requires_grad,
                    device=new_device,
                )
                # pylint: disable=unnecessary-dunder-call
                super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor)
            self._rowwise_data = tensor._rowwise_data
            self._columnwise_data = tensor._columnwise_data
            self._quantizer = tensor._quantizer
            self._rowwise_scale_inv = tensor._rowwise_scale_inv
            self._columnwise_scale_inv = tensor._columnwise_scale_inv
            self._amax_rowwise = tensor._amax_rowwise
            self._amax_columnwise = tensor._amax_columnwise
            return

        # Quantize to NVFP4
        assert self._quantizer is not None, "Can't quantize without a quantizer"
        self._quantizer.update_quantized(tensor, self)
        if self.requires_grad != tensor.requires_grad:
            self.requires_grad_(requires_grad=tensor.requires_grad)

    # Cast to NVFP4 when setting NVFP4Tensor.data
    data = property(_get_data, _set_data)
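Given the property wiring, assignments to .data behave like quantizing copies; a sketch, assuming `t` already carries a quantizer and `other_nvfp4_tensor` is another NVFP4Tensor:

    t.data = torch.randn_like(t.dequantize())  # plain tensor: re-quantized through t._quantizer
    t.data = other_nvfp4_tensor                # NVFP4Tensor: adopts its packed buffers directly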
class _ViewFunc(torch.autograd.Function):
    """View function

    View the NVFP4Tensor using the provided shape.
    """

    @staticmethod
    def forward(
        ctx,
        tensor: NVFP4Tensor,
        shape: Optional[list[int]] = None,
    ) -> NVFP4Tensor:
        # pylint: disable=missing-function-docstring

        # Return input tensor if shape is not provided
        cur_shape = tensor.shape
        if ctx is not None:
            ctx.shape = cur_shape
        if shape is None:
            return tensor

        # Canonicalize shape
        if not isinstance(shape, Iterable):
            shape = [shape]
        elif len(shape) == 1 and isinstance(shape[0], Iterable):
            shape = shape[0]
        if -1 in shape:
            shape = list(shape)
            d_inferred = -math.prod(cur_shape) // math.prod(shape)
            for i, d in enumerate(shape):
                if d == -1:
                    shape[i] = d_inferred
                    break
        if shape[-1] != cur_shape[-1]:
            raise RuntimeError(
                "NVFP4Tensor does not support reshaping inner dimension "
                f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
            )

        # Reshape data
        new_rowwise_data = None
        new_columnwise_data = None
        if tensor._rowwise_data is not None:
            if shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent row-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = list(shape[:-1]) + [shape[-1] // 2]
            new_rowwise_data = tensor._rowwise_data.view(byte_shape)
        if tensor._columnwise_data is not None:
            columnwise_shape = (shape[-1], math.prod(shape[:-1]))
            if columnwise_shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent column-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
            new_columnwise_data = tensor._columnwise_data.view(byte_shape)

        # Construct tensor
        return NVFP4Tensor(
            shape,
            tensor.dtype,
            rowwise_data=new_rowwise_data,
            rowwise_scale_inv=tensor._rowwise_scale_inv,
            columnwise_data=new_columnwise_data,
            columnwise_scale_inv=tensor._columnwise_scale_inv,
            amax_rowwise=tensor._amax_rowwise,
            amax_columnwise=tensor._amax_columnwise,
            quantizer=tensor._quantizer,
            fp4_dtype=tensor._fp4_dtype,
            requires_grad=tensor.requires_grad,
        )

    @staticmethod
    def backward(
        ctx,
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # pylint: disable=missing-function-docstring
        if isinstance(grad, NVFP4Tensor):
            new_rowwise_data = None
            new_columnwise_data = None
            if grad._rowwise_data is not None:
                if ctx.shape[-1] % 2 != 0:
                    raise ValueError(
                        "Cannot represent row-wise data for NVFP4 tensor "
                        f"with shape={ctx.shape} as byte array."
                    )
                byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2]
                new_rowwise_data = grad._rowwise_data.view(byte_shape)
            if grad._columnwise_data is not None:
                columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1]))
                if columnwise_shape[-1] % 2 != 0:
                    raise ValueError(
                        "Cannot represent column-wise data for NVFP4 tensor "
                        f"with shape={ctx.shape} as byte array."
                    )
                byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
                new_columnwise_data = grad._columnwise_data.view(byte_shape)
            dgrad = NVFP4Tensor(
                ctx.shape,
                grad.dtype,
                rowwise_data=new_rowwise_data,
                rowwise_scale_inv=grad._rowwise_scale_inv,
                columnwise_data=new_columnwise_data,
                columnwise_scale_inv=grad._columnwise_scale_inv,
                amax_rowwise=grad._amax_rowwise,
                amax_columnwise=grad._amax_columnwise,
                quantizer=grad._quantizer,
                fp4_dtype=grad._fp4_dtype,
                requires_grad=grad.requires_grad,
            )
            return dgrad, None
        return grad.view(ctx.shape), None
class _ReshapeFunc(torch.autograd.Function):
    """Reshape function

    Reshape the NVFP4Tensor using the provided shape.
    """

    @staticmethod
    def forward(
        ctx,
        tensor: NVFP4Tensor,
        shape: Optional[list[int]] = None,
    ) -> NVFP4Tensor:
        # pylint: disable=missing-function-docstring

        # Return input tensor if shape is not provided
        cur_shape = tensor.shape
        if ctx is not None:
            ctx.shape = cur_shape
        if shape is None:
            return tensor

        # Canonicalize shape
        if not isinstance(shape, Iterable):
            shape = [shape]
        elif len(shape) == 1 and isinstance(shape[0], Iterable):
            shape = shape[0]
        if -1 in shape:
            shape = list(shape)
            d_inferred = -math.prod(cur_shape) // math.prod(shape)
            for i, d in enumerate(shape):
                if d == -1:
                    shape[i] = d_inferred
                    break
        if shape[-1] != cur_shape[-1]:
            raise RuntimeError(
                "NVFP4Tensor does not support reshaping inner dimension "
                f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
            )

        # Reshape data
        new_rowwise_data = None
        new_columnwise_data = None
        if tensor._rowwise_data is not None:
            if shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent row-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = list(shape[:-1]) + [shape[-1] // 2]
            new_rowwise_data = tensor._rowwise_data.reshape(byte_shape)
        if tensor._columnwise_data is not None:
            columnwise_shape = (shape[-1], math.prod(shape[:-1]))
            if columnwise_shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent column-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
            new_columnwise_data = tensor._columnwise_data.reshape(byte_shape)

        # Construct tensor
        return NVFP4Tensor(
            shape,
            tensor.dtype,
            rowwise_data=new_rowwise_data,
            rowwise_scale_inv=tensor._rowwise_scale_inv,
            columnwise_data=new_columnwise_data,
            columnwise_scale_inv=tensor._columnwise_scale_inv,
            amax_rowwise=tensor._amax_rowwise,
            amax_columnwise=tensor._amax_columnwise,
            quantizer=tensor._quantizer,
            fp4_dtype=tensor._fp4_dtype,
            requires_grad=tensor.requires_grad,
        )

    @staticmethod
    def backward(
        ctx,
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # pylint: disable=missing-function-docstring
        if isinstance(grad, NVFP4Tensor):
            new_rowwise_data = None
            new_columnwise_data = None
            if grad._rowwise_data is not None:
                if ctx.shape[-1] % 2 != 0:
                    raise ValueError(
                        "Cannot represent row-wise data for NVFP4 tensor "
                        f"with shape={ctx.shape} as byte array."
                    )
                byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2]
                new_rowwise_data = grad._rowwise_data.reshape(byte_shape)
            if grad._columnwise_data is not None:
                columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1]))
                if columnwise_shape[-1] % 2 != 0:
                    raise ValueError(
                        "Cannot represent column-wise data for NVFP4 tensor "
                        f"with shape={ctx.shape} as byte array."
                    )
                byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
                new_columnwise_data = grad._columnwise_data.reshape(byte_shape)
            dgrad = NVFP4Tensor(
                ctx.shape,
                grad.dtype,
                rowwise_data=new_rowwise_data,
                rowwise_scale_inv=grad._rowwise_scale_inv,
                columnwise_data=new_columnwise_data,
                columnwise_scale_inv=grad._columnwise_scale_inv,
                amax_rowwise=grad._amax_rowwise,
                amax_columnwise=grad._amax_columnwise,
                quantizer=grad._quantizer,
                fp4_dtype=grad._fp4_dtype,
                requires_grad=grad.requires_grad,
            )
            return dgrad, None
        return grad.view(ctx.shape), None
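Both autograd functions above enforce the same packing constraint: the inner dimension must be preserved (and stay even) so that two FP4 values keep fitting per byte. Assuming `t` is an NVFP4Tensor of shape (4, 8, 32):

    a = t.reshape(32, 32)  # OK: inner dimension preserved, leading dims flattened
    b = t.view(-1, 32)     # OK: -1 inferred over the leading dimensions
    c = t.reshape(4, 256)  # raises RuntimeError: inner dimension would change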
transformer_engine/pytorch/tensor/quantized_tensor.py
View file @ 53fa872c
...
...
@@ -264,6 +264,10 @@ class Quantizer(abc.ABC):
        """Returns True if the quantizer supports only rowwise all-gather"""
        return False

    def is_quantizable(self, inp: torch.Tensor) -> bool:
        # pylint: disable=unused-argument
        """Returns whether or not given tensor can be quantized"""
        return True


class _QuantizeFunc(torch.autograd.Function):
    """Cast to FP8 from other dtype"""
...
...
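The is_quantizable hook added above gives a quantizer a veto over inputs before casts and collectives rely on it. A hypothetical override (illustrative only, not part of this commit):

    class EvenInnerDimQuantizer(Quantizer):  # hypothetical subclass; other abstract methods omitted
        def is_quantizable(self, inp: torch.Tensor) -> bool:
            # packed low-precision formats typically need an even inner dimension
            return inp.shape[-1] % 2 == 0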
transformer_engine/pytorch/tensor/utils.py
View file @ 53fa872c
...
...
@@ -4,12 +4,14 @@
"""Helper functions for using fp8 tensors as weights"""
import os
from typing import Optional, Union

import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv

from .quantized_tensor import QuantizedTensor
from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorBase
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
...
...
@@ -455,3 +457,20 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
        tex.fp8_block_scaling_partial_cast(
            master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype
        )
def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) -> bool:
    """Check if an environment or object is using experimental Kitchen middleware.

    Returns False if x is a torch.Tensor.
    """
    # Detect if the environment is experimental
    if x is None:
        return int(os.getenv("QAT_PARAMS", "0")) > 0
    # Detect if the object is experimental
    if isinstance(x, torch.Tensor):
        return False
    if not isinstance(x, (Quantizer, QuantizedTensorBase)):
        raise AssertionError("Object must be a Quantizer or QuantizedTensorBase instance")
    return hasattr(x, "experimental") and x.experimental
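Both call modes in a short sketch, using only behavior visible above:

    os.environ["QAT_PARAMS"] = "1"
    assert is_experimental()                        # environment check via QAT_PARAMS
    assert not is_experimental(torch.randn(2, 2))   # plain tensors are never experimental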
transformer_engine/pytorch/transformer.py
View file @ 53fa872c
...
...
@@ -191,6 +191,17 @@ class TransformerLayer(torch.nn.Module):
and `DotProductAttention` modules.
name: str, default = `None`
name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
Parallelism parameters
----------------------
...
...
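As a plain-PyTorch reference for the 'off-by-one' formula quoted in the docstring above (a sketch for clarity; the layer itself dispatches to fused attention kernels):

    import torch

    def off_by_one_softmax(scores: torch.Tensor) -> torch.Tensor:
        """exp(S) / (1 + sum(exp(S), dim=-1)); the +1 acts as a zero attention sink."""
        m = scores.amax(dim=-1, keepdim=True).clamp(min=0.0)  # shift keeps exp() bounded
        e = torch.exp(scores - m)
        return e / (torch.exp(-m) + e.sum(dim=-1, keepdim=True))  # exp(-m) is the shifted +1

The 'learnable' variant replaces the implicit 1 with exp(alpha[j]) per head, as described above.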
@@ -306,6 +317,7 @@ class TransformerLayer(torch.nn.Module):
        qk_norm_type: Optional[str] = None,
        qk_norm_eps: float = 1e-6,
        qk_norm_before_rope: bool = False,
        softmax_type: str = "vanilla",
    ) -> None:
        super().__init__()
...
...
@@ -362,6 +374,7 @@ class TransformerLayer(torch.nn.Module):
        self.get_rng_state_tracker = get_rng_state_tracker
        self.attn_input_format = attn_input_format
        self.softmax_type = softmax_type
        self.name = name
...
...
@@ -397,6 +410,7 @@ class TransformerLayer(torch.nn.Module):
            "qkv_format": self.attn_input_format,
            "seq_length": seq_length,
            "micro_batch_size": micro_batch_size,
            "softmax_type": self.softmax_type,
        }
        self.self_attention = MultiheadAttention(
...
...
transformer_engine/pytorch/triton/pad.py
0 → 100644
View file @ 53fa872c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""NVFP4 padding kernels
TODO(ksivamani): Documentation
"""
import torch
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1),
    ],
    key=["out_dim0", "out_dim1"],
)
@triton.jit
def zero_pad_kernel(
    inp_ptr,
    out_ptr,
    in_dim0: tl.constexpr,
    in_dim1: tl.constexpr,
    out_dim0: tl.constexpr,
    out_dim1: tl.constexpr,
    in_s0,
    in_s1,
    out_s0,
    out_s1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Pads a tensor assuming it's a columnwise scaling inverse."""
    # tile over OUTPUT coordinates
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # output rows
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # output cols
    om = offs_m[:, None]
    on = offs_n[None, :]

    # edge masking for output
    out_mask = (om < out_dim0) & (on < out_dim1)

    # valid input region is simply top-left (no offsets)
    in_mask = (om < in_dim0) & (on < in_dim1)

    # load valid input, else zero (masked load touches memory only where True)
    x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0)

    # store to output (only within bounds of the output tile)
    tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask)
def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor:
    """Pads a tensor assuming it's a columnwise scaling inverse."""
    assert inp.ndim == 2
    dim0, dim1 = inp.shape
    pad_x = (128 - dim0 % 128) % 128
    pad_y = (4 - dim1 % 4) % 4
    out_x = dim0 + pad_x
    out_y = dim1 + pad_y
    out = torch.empty((out_x, out_y), device=inp.device, dtype=inp.dtype)
    in_s0, in_s1 = inp.stride()
    out_s0, out_s1 = out.stride()
    BLOCK_M, BLOCK_N = 128, 128
    grid = (triton.cdiv(out_x, BLOCK_M), triton.cdiv(out_y, BLOCK_N))
    zero_pad_kernel[grid](
        inp,
        out,
        dim0,
        dim1,
        out_x,
        out_y,
        in_s0,
        in_s1,
        out_s0,
        out_s1,
    )
    return out
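For example, a (100, 3) columnwise scale-inverse pads up to (128, 4) with zero fill; a sketch (requires a CUDA device with Triton):

    scale_inv = torch.randint(0, 255, (100, 3), device="cuda", dtype=torch.uint8)
    padded = pad_columnwise_scale_inv(scale_inv)
    assert padded.shape == (128, 4)
    assert int(padded[100:, :].sum()) == 0 and int(padded[:, 3:].sum()) == 0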
transformer_engine/pytorch/utils.py
View file @ 53fa872c
...
...
@@ -10,8 +10,9 @@ import os
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch

import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version
from .tensor.quantized_tensor import Quantizer
from torch.utils.cpp_extension import IS_HIP_EXTENSION


def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
...
...
@@ -455,6 +456,16 @@ if IS_HIP_EXTENSION:
        return (
            re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None
        )


def assert_dim_for_all_gather(
    tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer
) -> None:
    """Assert that tensor dimensions are supported for all-gather"""
    if with_all_gather:
        assert quantizer.is_quantizable(tensor), (
            "All-gather requires quantizable tensor for quantizer "
            + quantizer.__class__.__name__
        )
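A sketch of the intended call site (names illustrative): the check is a no-op unless an all-gather is actually requested.

    # Raises AssertionError when `quantizer` cannot quantize `inp` (per is_quantizable above).
    assert_dim_for_all_gather(inp, with_all_gather=True, quantizer=quantizer)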
def is_bf16_compatible() -> None:
    """Replaces torch.cuda.is_bf16_compatible() with an explicit
    check on device compute capability to enforce sm_80 or higher.
...
...
@@ -486,6 +497,8 @@ def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
    """Runtime cuDNN version (major, minor, patch)"""
    import transformer_engine.pytorch.cpp_extensions as ext
    # ROCm fused attn does not use cudnn, return high numbers to avoid tests filtering out
    if IS_HIP_EXTENSION:
        return (99, 0, 0)
...
...