OpenDAS / TransformerEngine · Commits

Commit 063ef88d, authored Dec 03, 2025 by wenjh

Merge nv main up to v2.10.0.dev0

Signed-off-by: wenjh <wenjh@sugon.com>

Parents: 91670b05, 5624dbb4
Changes: 298 files; showing 18 changed files with 3076 additions and 133 deletions (+3076, -133)
- transformer_engine/pytorch/quantization.py (+1415, -0)
- transformer_engine/pytorch/setup.py (+2, -1)
- transformer_engine/pytorch/tensor/__init__.py (+37, -12)
- transformer_engine/pytorch/tensor/float8_blockwise_tensor.py (+41, -15)
- transformer_engine/pytorch/tensor/float8_tensor.py (+22, -20)
- transformer_engine/pytorch/tensor/mxfp8_tensor.py (+19, -21)
- transformer_engine/pytorch/tensor/nvfp4_tensor.py (+902, -0)
- transformer_engine/pytorch/tensor/quantized_tensor.py (+70, -24)
- transformer_engine/pytorch/tensor/storage/__init__.py (+9, -0)
- transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py (+5, -5)
- transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py (+7, -7)
- transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py (+7, -7)
- transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py (+348, -0)
- transformer_engine/pytorch/tensor/utils.py (+20, -1)
- transformer_engine/pytorch/transformer.py (+14, -0)
- transformer_engine/pytorch/triton/pad.py (+94, -0)
- transformer_engine/pytorch/triton/permutation.py (+14, -10)
- transformer_engine/pytorch/utils.py (+50, -10)
transformer_engine/pytorch/quantization.py (new file, mode 100644)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Quantization utilities for TransformerEngine"""
from __future__ import annotations
import abc
import itertools
import functools
import warnings
import os
from contextlib import contextmanager
from collections import deque
from typing import Callable, List, Optional, Dict, Any, Tuple, Union

import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
    Recipe,
    DelayedScaling,
    Format,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    Float8BlockScaling,
    NVFP4BlockScaling,
    CustomRecipe,
)
from .constants import dist_group_type
from .utils import get_device_compute_capability
from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION

int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))

__all__ = [
    "autocast",
    "quantized_model_init",
    "is_fp8_available",
    "is_mxfp8_available",
    "is_fp8_block_scaling_available",
    "is_nvfp4_available",
    "get_default_recipe",
]

if IS_HIP_EXTENSION:
    from transformer_engine.pytorch.utils import is_K100_AI, is_BW


@functools.lru_cache(maxsize=None)
def check_fp8_support() -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    if IS_HIP_EXTENSION:
        if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
            return True, "DCU enables FP8 simulation with INT8"
        return False, "DCU does not support FP8 for now"
    if get_device_compute_capability() >= (9, 0):  # hopper and above
        return True, ""
    if get_device_compute_capability() < (8, 9):  # pre-ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    if tex.get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if float(torch.version.cuda) < 12.1:
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


@functools.lru_cache(maxsize=None)
def check_mxfp8_support() -> Tuple[bool, str]:
    """Return if mxfp8 support is available"""
    if get_device_compute_capability() >= (12, 0):
        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
    if get_device_compute_capability() >= (10, 0):  # blackwell and above
        return True, ""
    return False, "Device compute capability 10.0 or higher required for MXFP8 execution."


@functools.lru_cache(maxsize=None)
def check_nvfp4_support() -> Tuple[bool, str]:
    """Return if nvfp4 support is available"""
    if IS_HIP_EXTENSION:
        return False, "NVFP4 is not supported on the ROCm platform."
    if get_device_compute_capability() >= (10, 0):  # blackwell and above
        return True, ""
    return False, "Device compute capability 10.0 or higher required for NVFP4 execution."


@functools.lru_cache(maxsize=None)
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
    """Return if fp8 block scaling support is available"""
    if IS_HIP_EXTENSION:
        # Parenthesized to match check_fp8_support; the unparenthesized form
        # `is_K100_AI() or is_BW() and int8_simulation_fp8` binds `and` tighter.
        if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
            return True, ""
        return False, "DCU does not support block-scaling FP8 for now"
    if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9:
        return True, ""
    return (
        False,
        "FP8 block scaled GEMM requires compute capability 9.0 or higher and CUDA >= 12.9.",
    )


def check_recipe_support(recipe: Recipe) -> None:
    """Check if the given recipe is supported."""
    recipe_supported = True
    unsupported_reason = ""
    if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)):
        recipe_supported, unsupported_reason = check_fp8_support()
    elif isinstance(recipe, Float8BlockScaling):
        recipe_supported, unsupported_reason = check_fp8_block_scaling_support()
    elif isinstance(recipe, MXFP8BlockScaling):
        recipe_supported, unsupported_reason = check_mxfp8_support()
    assert recipe_supported, unsupported_reason


def get_default_fp8_recipe() -> Recipe:
    """FP8 recipe with default args."""
    if check_mxfp8_support()[0]:
        return MXFP8BlockScaling()
    if get_device_compute_capability() >= (12, 0):
        # This is a temporary restriction until MXFP8 is supported for all gemm layouts.
        return Float8CurrentScaling()
    return DelayedScaling()


def get_default_recipe() -> Recipe:
    """Returns the default training recipe based on available device."""
    return get_default_fp8_recipe()


def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
    """Get fp8 data type according to recipe and tensor"""
    if fp8_recipe.fp8_format == Format.E4M3 or (
        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
    ):
        return torch.float8_e4m3fn
    return torch.float8_e5m2


def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
    """Get fp8 data type according to recipe and tensor"""
    if fp8_recipe.fp8_format == Format.E4M3 or (
        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
    ):
        return tex.DType.kFloat8E4M3
    return tex.DType.kFloat8E5M2
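

# Illustrative sketch (not part of the original module): how the recipe format
# selects the FP8 dtype. With Format.HYBRID, forward-pass tensors use E4M3 and
# gradient tensors use E5M2, per get_fp8_torch_dtype/get_fp8_te_dtype above.
def _example_hybrid_format_dtypes() -> None:
    """Sketch: HYBRID maps fprop tensors to E4M3 and grads to E5M2."""
    recipe = DelayedScaling(fp8_format=Format.HYBRID)
    assert get_fp8_torch_dtype(recipe, fprop_tensor=True) is torch.float8_e4m3fn
    assert get_fp8_torch_dtype(recipe, fprop_tensor=False) is torch.float8_e5m2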


def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType:
    """Get fp4 data type according to recipe and tensor"""
    if fp4_recipe.fp4_format == Format.E2M1:
        return tex.DType.kFloat4E2M1
    raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}")


def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> float:
    """Get max representable FP8 value."""
    if fp8_recipe.fp8_format == Format.E4M3 or (
        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
    ):
        return Format.E4M3.value.max_fwd
    return Format.E5M2.value.max_fwd


def is_fp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
    """
    Determine if FP8 support is available for the delayed
    scaling and per-tensor current scaling recipes.

    Parameters
    ----------
    return_reason : bool, optional
        If ``False`` (default), return only a boolean indicating availability.
        If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides
        a human-readable explanation when required support is not available. The reason
        will be an empty string if support for FP8 is available.
    """
    if return_reason:
        return check_fp8_support()
    return check_fp8_support()[0]


def is_mxfp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
    """
    Determine if support is available for the MXFP8 recipe.

    Parameters
    ----------
    return_reason : bool, optional
        If ``False`` (default), return only a boolean indicating availability.
        If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides
        a human-readable explanation when required support is not available. The reason
        will be an empty string if support for MXFP8 is available.
    """
    if return_reason:
        return check_mxfp8_support()
    return check_mxfp8_support()[0]


def is_fp8_block_scaling_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
    """
    Determine if support is available for the FP8 block scaling recipe.

    Parameters
    ----------
    return_reason : bool, optional
        If ``False`` (default), return only a boolean indicating availability.
        If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides
        a human-readable explanation when required support is not available. The reason
        will be an empty string if support for FP8 block scaling is available.
    """
    if return_reason:
        return check_fp8_block_scaling_support()
    return check_fp8_block_scaling_support()[0]


def is_nvfp4_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
    """
    Determine if support is available for the NVFP4 recipe.

    Parameters
    ----------
    return_reason : bool, optional
        If ``False`` (default), return only a boolean indicating availability.
        If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides
        a human-readable explanation when required support is not available. The reason
        will be an empty string if support for NVFP4 is available.
    """
    if return_reason:
        return check_nvfp4_support()
    return check_nvfp4_support()[0]
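

# Illustrative sketch (not part of the original module): guarding a low
# precision code path with the availability helpers defined above.
def _example_availability_check() -> None:
    """Sketch: query FP8 availability and surface the reason when missing."""
    available, reason = is_fp8_available(return_reason=True)
    if not available:
        warnings.warn(f"FP8 unavailable, running in high precision: {reason}")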


class FP8GlobalStateManager:
    """Class to keep track of and manipulate the global
    FP8 state at different stages of execution.
    """

    FP8_ENABLED = False
    FP8_CALIBRATION = False
    FP8_RECIPE = None
    FP8_DISTRIBUTED_GROUP = None
    FP8_PARAMETERS = False
    HIGH_PRECISION_INIT_VAL = False
    IS_FIRST_FP8_MODULE = False
    FP8_GRAPH_CAPTURING = False
    AUTOCAST_DEPTH = 0
    global_amax_buffer = {}
    global_amax_history_buffer = {}
    global_scale_buffer = {}
    fp8_tensors_recompute_buffer = []
    fp8_available = None
    reason_for_no_fp8 = ""
    autocast_arguments = {}
    skip_fp8_weight_update_tensor = None
    mxfp8_available = None
    reason_for_no_mxfp8 = ""
    fp8_block_scaling_available = None
    reason_for_no_fp8_block_scaling = ""
    nvfp4_available = None
    reason_for_no_nvfp4 = ""

    @classmethod
    def reset(cls) -> None:
        """Reset the global state"""
        cls.FP8_ENABLED = False
        cls.FP8_CALIBRATION = False
        cls.FP8_RECIPE = None
        cls.FP8_DISTRIBUTED_GROUP = None
        cls.FP8_PARAMETERS = False
        cls.HIGH_PRECISION_INIT_VAL = False
        cls.IS_FIRST_FP8_MODULE = False
        cls.FP8_GRAPH_CAPTURING = False
        cls.AUTOCAST_DEPTH = 0
        cls.global_amax_buffer = {}
        cls.global_amax_history_buffer = {}
        cls.global_scale_buffer = {}
        cls.fp8_tensors_recompute_buffer = []
        cls.fp8_available = None
        cls.reason_for_no_fp8 = ""
        cls.autocast_arguments = {}
        cls.skip_fp8_weight_update_tensor = None
        cls.mxfp8_available = None
        cls.reason_for_no_mxfp8 = ""
        cls.fp8_block_scaling_available = None
        cls.reason_for_no_fp8_block_scaling = ""
        cls.nvfp4_available = None
        cls.reason_for_no_nvfp4 = ""

    @classmethod
    def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
        """`skip_fp8_weight_update_tensor` inplace setter."""
        if cls.skip_fp8_weight_update_tensor is None:
            cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
        cls.skip_fp8_weight_update_tensor.fill_(skip)

    @classmethod
    def get_skip_fp8_weight_update_tensor(cls) -> Optional[torch.Tensor]:
        """`skip_fp8_weight_update_tensor` getter."""
        return cls.skip_fp8_weight_update_tensor

    @classmethod
    def is_fp8_available(cls) -> Tuple[bool, str]:
        """Return if fp8 support is available"""
        return check_fp8_support()

    @classmethod
    def is_mxfp8_available(cls) -> Tuple[bool, str]:
        """Return if MXFP8 support is available."""
        return check_mxfp8_support()

    @classmethod
    def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]:
        """Return if Float8 block scaling support is available."""
        return check_fp8_block_scaling_support()

    @classmethod
    def is_nvfp4_available(cls) -> Tuple[bool, str]:
        """Return if NVFP4 support is available."""
        return check_nvfp4_support()

    @staticmethod
    def get_meta_tensor_key(forward: bool = True) -> str:
        """Returns scaling key in `fp8_meta`."""
        if forward:
            return "scaling_fwd"
        return "scaling_bwd"

    @staticmethod
    def get_fwd_bwd_key(forward: bool = True) -> str:
        """Convert bool `forward` to string."""
        return "forward" if forward else "backward"

    @classmethod
    def get_buffer_info(cls) -> str:
        """
        Returns a key for `fp8_meta` that stores the module's index
        in the global buffers along with autocast information.
        """
        return "buffer_index_and_autocast_key"

    @classmethod
    def get_key_in_buffer(
        cls,
        forward: bool,
        fp8_recipe: Recipe,
        fp8_group: dist_group_type,
    ) -> str:
        """Returns a key into the global FP8 buffers."""
        autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
        fwd_bwd_key = cls.get_fwd_bwd_key(forward)
        return f"{fwd_bwd_key}_{autocast_key}"

    @classmethod
    def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
        """Splits buffer key into relevant parts."""
        forward, autocast_key = key.split("_", 1)
        forward = forward == "forward"
        return forward, autocast_key

    @classmethod
    def add_fp8_tensors_to_global_buffer(
        cls,
        fp8_meta: Dict[str, Any],
    ) -> None:
        """
        Delayed scaling only.

        The amax reduction process happens completely outside the FP8 modules.
        To participate in the reduction, the only role played by a module is
        to call this function in order to append its FP8 tensors into a global
        buffer. There are 3 global buffers maintained, one each for amax, amax
        history, and scale. Each buffer has keys that hold FP8 tensors. Keys
        have a `forward_` or `backward_` prefix to indicate the type of FP8
        tensor, since the forward and backward reductions happen separately.

        Note: For CG capture, this method is called from the graphed
        wrapper. For the non-CG case, it's called from within the module.
        """
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        # Every module must call this function exactly once since
        # the amax tensors are static. Ensures that compatibility
        # with non-graphed modules is maintained.
        index_in_buffer = cls.get_buffer_info()
        # Same index for fwd/bwd fp8 tensors.
        if index_in_buffer in fp8_meta:
            return

        fp8_meta[index_in_buffer] = []
        for forward in (True, False):
            fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
            if fp8_meta_tensor_key not in fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue
            key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])

            if key not in cls.global_amax_buffer:
                cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
                cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history]
                cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale]
            else:
                cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])
                cls.global_amax_history_buffer[key].append(
                    fp8_meta[fp8_meta_tensor_key].amax_history
                )
                cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale)
            fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1)
            fp8_meta[index_in_buffer].append(key)

    @classmethod
    def is_fp8_enabled(cls) -> bool:
        """Is FP8 enabled"""
        return cls.FP8_ENABLED

    @classmethod
    def is_fp8_calibration(cls) -> bool:
        """Is FP8 calibration"""
        return cls.FP8_CALIBRATION

    @classmethod
    def with_fp8_parameters(cls) -> bool:
        """Should the parameters be stored as FP8"""
        return cls.FP8_PARAMETERS

    @classmethod
    def with_high_precision_init_val(cls) -> bool:
        """Should the high precision initial values be stored with FP8 parameters"""
        return cls.HIGH_PRECISION_INIT_VAL

    @classmethod
    def fp8_graph_capturing(cls) -> bool:
        """Is CUDA graph capture under way?"""
        return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing()

    @classmethod
    def is_first_fp8_module(cls):
        """Returns `True` only the first time when called multiple
        times from within the same `autocast` context.
        """
        tmp = cls.IS_FIRST_FP8_MODULE
        cls.IS_FIRST_FP8_MODULE = False
        return tmp

    @classmethod
    def get_fp8_recipe(cls) -> Recipe:
        """Return the fp8 recipe"""
        if cls.FP8_RECIPE is not None:
            return cls.FP8_RECIPE
        return get_default_fp8_recipe()

    @classmethod
    def get_fp8_group(cls) -> Union[dist_group_type, None]:
        """Return the fp8 group for scale/amax comm"""
        return cls.FP8_DISTRIBUTED_GROUP

    @classmethod
    def get_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool, bool]:
        """FP8 autocast state getter"""
        return (
            cls.FP8_ENABLED,
            cls.FP8_CALIBRATION,
            cls.FP8_RECIPE,
            cls.FP8_DISTRIBUTED_GROUP,
            cls.IS_FIRST_FP8_MODULE,
            cls.FP8_GRAPH_CAPTURING,
        )

    @classmethod
    def set_autocast_state(
        cls, fp8_state: Tuple[bool, bool, Recipe, dist_group_type, bool, bool]
    ) -> None:
        """FP8 autocast state setter"""
        (
            cls.FP8_ENABLED,
            cls.FP8_CALIBRATION,
            cls.FP8_RECIPE,
            cls.FP8_DISTRIBUTED_GROUP,
            cls.IS_FIRST_FP8_MODULE,
            cls.FP8_GRAPH_CAPTURING,
        ) = fp8_state

    @staticmethod
    def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None:
        """Reduce tensor across given group."""
        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MAX,
                group=group,
                async_op=False,
            )

    @classmethod
    def reduce_and_update_fp8_tensors(
        cls,
        forward: bool = True,
    ) -> None:
        """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer."""
        # global_amax_buffer should only be non-empty for fp8 delayed scaling
        for buffer_key, amax_buffer in cls.global_amax_buffer.items():
            # Check for forward or backward reduction.
            fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
            if fwd_update != forward:
                continue
            if len(amax_buffer) == 0:
                continue

            # Retrieve autocast specific args and concat amaxes.
            recipe, group = cls.autocast_arguments[autocast_key]
            contiguous_amax = torch.cat(amax_buffer)

            # Reduction.
            if (
                recipe.reduce_amax
                and torch.distributed.is_initialized()
                and torch.distributed.get_world_size(group=group) > 1
            ):
                cls.reduce_tensor_across_group_op_max(contiguous_amax, group)

            # Amax and scale update.
            unfused_update = (
                bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0")))
                or callable(recipe.amax_compute_algo)
                or callable(recipe.scaling_factor_compute_algo)
            )
            if not unfused_update:
                tex.fused_amax_and_scale_update_after_reduction(
                    contiguous_amax,
                    cls.global_amax_history_buffer[buffer_key],
                    cls.global_scale_buffer[buffer_key],
                    recipe.amax_compute_algo,
                    get_fp8_te_dtype(recipe, forward),
                    recipe.margin,
                )
            else:
                split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer])
                for amax_history, scale in zip(
                    cls.global_amax_history_buffer[buffer_key],
                    cls.global_scale_buffer[buffer_key],
                ):
                    _amax_and_scale_update(
                        amax_history, scale, get_fp8_max(recipe, forward), recipe
                    )

    @classmethod
    def get_unique_autocast_key(
        cls,
        recipe: Optional[Recipe] = None,
        group: Optional[dist_group_type] = None,
    ):
        """
        For FP8, each autocast can be uniquely identified by the recipe and fp8 group.
        Safely using `hash` as we never cross checkpoint boundaries.
        """
        return f"{str(recipe)}:{hash(group)}"

    @classmethod
    def autocast_enter(
        cls,
        enabled: bool = False,
        calibrating: bool = False,
        fp8_recipe: Optional[Recipe] = None,
        fp8_group: Optional[dist_group_type] = None,
        _graph: bool = False,
    ) -> None:
        """Set state and tracking variables for entry into FP8 region."""
        fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
        autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
        cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group)

        cls.FP8_ENABLED = enabled
        cls.FP8_CALIBRATION = calibrating
        cls.FP8_RECIPE = fp8_recipe
        cls.FP8_DISTRIBUTED_GROUP = fp8_group
        cls.FP8_GRAPH_CAPTURING = _graph

        if cls.AUTOCAST_DEPTH == 0:
            cls.IS_FIRST_FP8_MODULE = True
        cls.AUTOCAST_DEPTH += 1

        if enabled:
            fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
            assert fp8_available, reason_for_no_fp8
            if isinstance(fp8_recipe, MXFP8BlockScaling):
                mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
                assert mxfp8_available, reason_for_no_mxfp8
            if isinstance(fp8_recipe, Float8BlockScaling):
                fp8_block_available, reason_for_no_fp8_block = (
                    cls.is_fp8_block_scaling_available()
                )
                assert fp8_block_available, reason_for_no_fp8_block
            if isinstance(fp8_recipe, NVFP4BlockScaling):
                nvfp4_available, reason_for_no_nvfp4 = cls.is_nvfp4_available()
                assert nvfp4_available, reason_for_no_nvfp4

    @classmethod
    def autocast_exit(cls, enabled: bool, _graph: bool) -> None:
        """Set state and tracking variables for exit from FP8 region."""
        cls.AUTOCAST_DEPTH -= 1

        # Reduce only the non-FP8 weight modules here.
        # FP8 weight modules are reduced at the end of the optimizer
        # step after the weight amax is populated.
        if enabled and cls.AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
            # Delayed scaling only: for any other recipe (current scaling with any
            # granularity) this is a noop because cls.global_amax_buffer is empty.
            cls.reduce_and_update_fp8_tensors(forward=True)

    @classmethod
    def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Copy the scaling factors and amaxes for recompute forward phase
        to ensure both forward steps are numerically same.
        """
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"

        to_copy = [
            fp8_meta["scaling_fwd"].amax_history.clone(),
            fp8_meta["scaling_fwd"].scale.clone(),
        ]

        if buffer_position_key in fp8_meta:
            cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
        else:
            if len(cls.fp8_tensors_recompute_buffer) == 0:
                cls.fp8_tensors_recompute_buffer = [deque()]
            else:
                cls.fp8_tensors_recompute_buffer.append(deque())
            cls.fp8_tensors_recompute_buffer[-1].append(to_copy)
            fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1

    @classmethod
    def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
        """Switch to the copied scaling factors and amaxes from phase
        1 forward for identical numerical outputs.
        """
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return

        # Store updated amaxes and scales from phase 1 post forward.
        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone()
        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone()

        # Retrieve stashed amaxes and scales from phase 1 pre forward.
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
        stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[
            fp8_meta[buffer_position_key]
        ].popleft()

        # Replace amaxes and scales with stashed values for phase 2 forward
        fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0])
        fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1])

    @staticmethod
    def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
        """Restore latest scaling factors and amaxes after recompute forward run."""
        # delayed scaling only function, noop for any other recipe
        if not fp8_meta["recipe"].delayed():
            return
        fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
        fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"])
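

# Illustrative sketch (not part of the original module): the one-shot semantics
# of `is_first_fp8_module` within a single autocast region, assuming a CUDA
# device so that the default recipe can be constructed.
def _example_first_module_flag() -> None:
    """Sketch: only the first caller in an autocast region sees `True`."""
    FP8GlobalStateManager.autocast_enter(enabled=False)
    try:
        assert FP8GlobalStateManager.is_first_fp8_module() is True
        assert FP8GlobalStateManager.is_first_fp8_module() is False
    finally:
        FP8GlobalStateManager.autocast_exit(enabled=False, _graph=False)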


@contextmanager
def fp8_model_init(
    enabled: bool = True,
    recipe: Optional[Recipe] = None,
    preserve_high_precision_init_val: bool = False,
) -> None:
    """
    .. warning::

        fp8_model_init is deprecated and will be removed in a future release. Use
        quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...)
        instead.
    """
    warnings.warn(
        "fp8_model_init is deprecated and will be removed in a future release. "
        "Use quantized_model_init("
        "enabled=..., recipe=..., preserve_high_precision_init_val=...) instead.",
        category=DeprecationWarning,
        stacklevel=2,
    )

    # Call new implementation.
    with quantized_model_init(
        enabled=enabled,
        recipe=recipe,
        preserve_high_precision_init_val=preserve_high_precision_init_val,
    ):
        yield


@contextmanager
def quantized_model_init(
    enabled: bool = True,
    recipe: Optional[Recipe] = None,
    preserve_high_precision_init_val: bool = False,
) -> None:
    """
    Context manager for initialization of quantized parameters.

    Example usage:

    .. code-block:: python

        with quantized_model_init(enabled=True):
            model = transformer_engine.pytorch.Linear(768, 768)

        # Preserving high precision initial value to initialize master weight
        with quantized_model_init(enabled=True, preserve_high_precision_init_val=True):
            model = transformer_engine.pytorch.Linear(768, 768)
        master_weight = model.weight.get_high_precision_init_val()
        model.weight.clear_high_precision_init_val()

    Parameters
    ----------
    enabled: bool, default = `True`
        when enabled, Transformer Engine modules created inside this `quantized_model_init`
        region will hold only quantized copies of their parameters, as opposed to the default
        behavior where both higher precision and quantized copies are present. Setting this
        option to `True` may result in lower memory consumption and is especially
        useful for scenarios like:

        * full model training using an optimizer with master weights, where the high
          precision copies of weights are already present in the optimizer.
        * inference, where only the quantized copies of the parameters are used.
        * LoRA-like fine-tuning, where the main parameters of the model do not change.

    recipe: transformer_engine.common.recipe.Recipe, default = `None`
        Recipe used to create the parameters. If left to None, it uses the default recipe.
    preserve_high_precision_init_val: bool, default = `False`
        when enabled, store the high precision tensor used to initialize quantized parameters
        in CPU memory, and add two function attributes named `get_high_precision_init_val()`
        and `clear_high_precision_init_val()` to quantized parameters to get/clear this high
        precision tensor. The purpose is that users can use this high-precision copy
        to initialize master weights, avoiding the loss of precision that can occur when
        using quantized parameters directly. Note that after the master weights are
        initialized, users should call `clear_high_precision_init_val()` to release this
        CPU memory. This functionality is *EXPERIMENTAL*.
    """
    _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
    _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE
    _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL

    FP8GlobalStateManager.FP8_PARAMETERS = enabled
    FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe
    FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
    try:
        yield
    finally:
        FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
        FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe
        FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val


@contextmanager
def fp8_autocast(
    enabled: bool = True,
    calibrating: bool = False,
    fp8_recipe: Optional[Recipe] = None,
    fp8_group: Optional[dist_group_type] = None,
    _graph: bool = False,
) -> None:
    """
    .. warning::

        fp8_autocast is deprecated and will be removed in a future release.
        Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.
    """
    warnings.warn(
        "fp8_autocast is deprecated and will be removed in a future release. "
        "Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.",
        category=DeprecationWarning,
        stacklevel=2,
    )

    # Call new implementation.
    with autocast(
        enabled=enabled,
        calibrating=calibrating,
        recipe=fp8_recipe,
        amax_reduction_group=fp8_group,
        _graph=_graph,
    ):
        yield


@contextmanager
def autocast(
    enabled: bool = True,
    calibrating: bool = False,
    recipe: Optional["Recipe"] = None,
    amax_reduction_group: Optional["dist_group_type"] = None,
    _graph: bool = False,
) -> None:
    """
    Context manager for quantization schemes like FP8 or FP4.

    .. code-block:: python

        with autocast(enabled=True):
            out = model(inp)

    .. note::

        Support for FP8 in the Linear layer of Transformer Engine is currently limited
        to tensors with shapes where both dimensions are divisible by 16. In terms of
        the input to the full Transformer network, this typically requires padding the
        sequence length to a multiple of 16.

    .. note::

        When :attr:`recipe.reduce_amax==True`, any module must not be invoked more than once
        inside a single `autocast` region. This is unsupported behavior because the amax
        reduction is handled during the exit of the `autocast` context. Calling the same
        module more than once inside an `autocast` region overrides the amax tensors
        before reduction can occur.

    Parameters
    ----------
    enabled: bool, default = `True`
        whether or not to enable low precision quantization (FP8/FP4).
    calibrating: bool, default = `False`
        calibration mode allows collecting statistics such as amax and scale
        data of quantized tensors even when executing without quantization enabled.
        This is useful for saving an inference-ready checkpoint while training
        using a higher precision.
    recipe: recipe.Recipe, default = `None`
        recipe used for low precision quantization.
    amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
        distributed group over which amaxes for the quantized tensors
        are reduced at the end of each training step.
    """
    if enabled:
        check_recipe_support(recipe)

    # Save current state so we always restore it on exit.
    fp8_state = FP8GlobalStateManager.get_autocast_state()
    FP8GlobalStateManager.autocast_enter(
        enabled=enabled,
        calibrating=calibrating,
        fp8_recipe=recipe,
        fp8_group=amax_reduction_group,
        _graph=_graph,
    )
    try:
        yield
    finally:
        FP8GlobalStateManager.set_autocast_state(fp8_state)
        FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)
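

# Illustrative sketch (not part of the original module): entering autocast with
# an explicit recipe. `model` and `inp` are assumed to be a Transformer Engine
# module and a compatible input on a CUDA device with FP8 support.
def _example_autocast(model: torch.nn.Module, inp: torch.Tensor) -> torch.Tensor:
    """Sketch: one forward pass under FP8 delayed scaling."""
    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)
    with autocast(enabled=True, recipe=fp8_recipe):
        return model(inp)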


def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
    """Update amax history and set next amax to zero."""
    if amax_history.shape[0] > 1:
        new_amax_history = torch.roll(amax_history, -1, 0)
        amax_history.copy_(new_amax_history)
    amax_history[0].fill_(0.0)
    return amax_history


@torch.jit.script
def _default_get_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Default function to obtain amax from history."""
    if amax_compute_algo == "max":
        amax = torch.max(amax_history, dim=0).values
    else:  # amax_compute_algo == "most_recent"
        amax = amax_history[0].clone()

    amax_history = _update_amax_history(amax_history)
    return amax_history, amax


@jit_fuser
def _default_sf_compute(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    margin: int,
    _fp32_max: float = torch.finfo(torch.float32).max,  # finfo not available in the jitted function
) -> torch.Tensor:
    """Default function to convert amax to scaling factor.

    Computing the scaling factor requires consideration of the following scenarios:

    1. amax == 0:
       No action is possible, set scale to the previous scale (or 1).
    2. 0 < amax < tiny_amax:
       The amax is so tiny that the scale becomes infinite in FP32.
       Set scale = FP32_max.
    3. tiny_amax <= amax < FP32_max:
       Set scale = FP8_max (or scaled_max) / amax.
    4. amax == inf or amax == nan:
       No action is possible, set scale to the previous scale (or 1).
    """
    sf = (fp8_max / amax) / (2**margin)
    sf = torch.where(amax > 0.0, sf, scale)
    sf = torch.where(torch.isfinite(amax), sf, scale)
    sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf)
    scale.copy_(sf)
    return scale
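

# Illustrative sketch (not part of the original module): a worked example of
# the default rule on a CUDA device. With amax = 2.0, fp8_max = 448 (the E4M3
# max) and margin = 0, the new scale is 448 / 2.0 = 224.
def _example_default_sf() -> None:
    """Sketch: scale = (fp8_max / amax) / 2**margin when amax is finite."""
    amax = torch.tensor([2.0], device="cuda")
    scale = torch.ones(1, device="cuda")
    new_scale = _default_sf_compute(amax, scale, 448.0, 0)
    assert torch.allclose(new_scale, torch.tensor([224.0], device="cuda"))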


def _compute_amax_and_update_history(
    amax_history: torch.Tensor,
    amax_compute_algo: Union[Callable, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Obtain the amax from the history."""
    if callable(amax_compute_algo):
        amax = amax_compute_algo(amax_history)
        amax_history = _update_amax_history(amax_history)
        return amax_history, amax
    return _default_get_amax_and_update_history(
        amax_history,
        amax_compute_algo,
    )


def _compute_scaling_factor(
    amax: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> torch.Tensor:
    """Convert amax to scaling factor."""
    if recipe.scaling_factor_compute_algo is None:
        return _default_sf_compute(
            amax,
            scale,
            fp8_max,
            recipe.margin,
        )
    return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe)


def _amax_and_scale_update(
    amax_history: torch.Tensor,
    scale: torch.Tensor,
    fp8_max: float,
    recipe: DelayedScaling,
) -> None:
    """Updates FP8 meta tensors."""
    new_amax_history, amax = _compute_amax_and_update_history(
        amax_history,
        recipe.amax_compute_algo,
    )
    new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe)
    scale.copy_(new_scale)
    amax_history.copy_(new_amax_history)
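

# Illustrative sketch (not part of the original module): one full delayed
# scaling update on toy tensors. The newest amax lands in row 0 of the
# history; the update rolls the history and refreshes the scale in place.
def _example_amax_and_scale_update() -> None:
    """Sketch: drive `_amax_and_scale_update` with a default recipe on CUDA."""
    recipe = DelayedScaling()  # amax_compute_algo="max", margin=0
    amax_history = torch.zeros(recipe.amax_history_len, 1, device="cuda")
    amax_history[0, 0] = 2.0
    scale = torch.ones(1, device="cuda")
    _amax_and_scale_update(amax_history, scale, 448.0, recipe)
    assert torch.allclose(scale, torch.tensor([224.0], device="cuda"))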


def split_and_copy(
    buffer: torch.Tensor,
    outputs: List[torch.Tensor],
    chunk_sizes: List[int],
) -> None:
    """Split `buffer` by `chunk_sizes` and copy into `outputs`."""
    splits = buffer.split(chunk_sizes)
    torch._foreach_copy_(outputs, splits)
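

# Illustrative sketch (not part of the original module): scattering a reduced
# flat amax buffer back into per-module tensors, as in the unfused update path.
def _example_split_and_copy() -> None:
    """Sketch: split a length-3 buffer into chunks of sizes 1 and 2."""
    buffer = torch.arange(3.0, device="cuda")
    outputs = [torch.zeros(1, device="cuda"), torch.zeros(2, device="cuda")]
    split_and_copy(buffer, outputs, [1, 2])
    assert outputs[1][1].item() == 2.0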


class RecipeState(abc.ABC):
    """Configuration and state for a quantization recipe.

    This is a builder class for quantizers, which are in turn builder
    classes for quantized tensors.

    This class may pack together the state for multiple quantizers,
    which is helpful for applying fused kernels with less overhead.
    """

    @staticmethod
    def create(
        recipe: Recipe,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> RecipeState:
        """Factory method to create the state for a quantization recipe.

        Parameters
        ----------
        recipe: Recipe
            Quantization recipe.
        mode: {"forward", "backward"}
            Training stage where quantization will be performed.
        num_quantizers: int, default = 1
            Number of quantizers to create state for.
        device: torch.device, default = default CUDA device
            Device for quantized tensors.

        Returns
        -------
        RecipeState:
            Quantization recipe state.
        """
        cls = None
        if recipe.delayed():
            cls = DelayedScalingRecipeState
        elif recipe.mxfp8():
            cls = MXFP8BlockScalingRecipeState
        elif recipe.float8_current_scaling():
            cls = Float8CurrentScalingRecipeState
        elif recipe.float8_block_scaling():
            cls = Float8BlockScalingRecipeState
        elif recipe.nvfp4():
            cls = NVFP4BlockScalingRecipeState
        elif recipe.custom():
            cls = CustomRecipeState
        else:
            raise ValueError(f"{recipe.__class__.__name__} is not supported")
        return cls(
            recipe,
            mode=mode,
            num_quantizers=num_quantizers,
            device=device,
        )

    @abc.abstractmethod
    def make_quantizers(self) -> list:
        """Convert recipe state to quantizers.

        Quantizers are builder classes for quantized tensors. They are
        typically used to convert a high-precision tensor (e.g. in
        FP32 or BF16) into a quantized tensor (e.g. in FP8).
        """


class DelayedScalingRecipeState(RecipeState):
    """State for FP8 quantization with per-tensor delayed scaling.

    Delayed scaling recipe requires a scaling factor (applied when
    casting to FP8) and a history of max-abs values ("amax") from
    recent FP8 casts for updating the scaling factor. The scale update
    is handled externally by `FP8GlobalStateManager`.
    """

    recipe: DelayedScaling
    mode: str
    dtype: tex.DType
    scale: torch.Tensor
    amax_history: torch.Tensor

    def __init__(
        self,
        recipe: DelayedScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        # Allocate buffers
        if device is None:
            device = torch.device("cuda")
        self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device)
        self.amax_history = torch.zeros(
            recipe.amax_history_len,
            num_quantizers,
            dtype=torch.float32,
            device=device,
        )

    def make_quantizers(self) -> list:
        # TODO(ksivamani): Find better design for this, adding here to avoid circular import.
        from .tensor.float8_tensor import Float8Quantizer

        return [
            Float8Quantizer(
                self.scale[i],
                self.amax_history[0][i].reshape((1,)),
                self.dtype,
            )
            for i in range(self.num_quantizers)
        ]
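

# Illustrative sketch (not part of the original module): building quantizers
# through the factory, assuming a CUDA device with FP8 support.
def _example_delayed_scaling_state() -> None:
    """Sketch: a forward-mode state packing two delayed-scaling quantizers."""
    state = RecipeState.create(DelayedScaling(), mode="forward", num_quantizers=2)
    quantizers = state.make_quantizers()
    assert len(quantizers) == 2 and state.scale.numel() == 2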


class Float8CurrentScalingRecipeState(RecipeState):
    """Configuration for per-tensor current scaling quantization.

    Per-tensor current scaling quantization does not require state.
    """

    recipe: Float8CurrentScaling
    mode: str
    dtype: tex.DType
    device: torch.device

    def __init__(
        self,
        recipe: Float8CurrentScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        # Allocate buffers
        if device is None:
            device = torch.device("cuda")
        self.device = device

    def make_quantizers(self) -> list:
        from .tensor.float8_tensor import Float8CurrentScalingQuantizer

        return [
            Float8CurrentScalingQuantizer(
                self.dtype,
                device=self.device,
                force_pow_2_scales=self.recipe.use_power_2_scales,
            )
            for _ in range(self.num_quantizers)
        ]


class MXFP8BlockScalingRecipeState(RecipeState):
    """Configuration for MXFP8 quantization.

    MXFP8 quantization does not require state.
    """

    recipe: MXFP8BlockScaling
    mode: str
    dtype: tex.DType

    def __init__(
        self,
        recipe: MXFP8BlockScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp8_te_dtype(recipe, mode == "forward")

        # Allocate buffers
        if device is None:
            device = torch.device("cuda")

    def make_quantizers(self) -> list:
        # TODO(ksivamani): Find better design for this, adding here to avoid circular import.
        from .tensor.mxfp8_tensor import MXFP8Quantizer

        return [MXFP8Quantizer(self.dtype) for _ in range(self.num_quantizers)]


class Float8BlockScalingRecipeState(RecipeState):
    """Configuration for Float8BlockScaling quantization.

    Float8BlockScaling quantization does not require state,
    but different quantizers use different modes.
    """

    recipe: Float8BlockScaling
    mode: str
    qx_dtype: tex.DType
    qw_dtype: tex.DType
    qgrad_dtype: tex.DType

    def __init__(
        self,
        recipe: Float8BlockScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.qx_dtype = get_fp8_te_dtype(recipe, True)
        self.qw_dtype = get_fp8_te_dtype(recipe, True)
        self.qgrad_dtype = get_fp8_te_dtype(recipe, False)

        # Allocate buffers
        if device is None:
            device = torch.device("cuda")
        self.device = device

    def make_quantizers(self) -> list:
        # TODO(ksivamani): Find better design for this, adding here to avoid circular import.
        from .tensor.float8_blockwise_tensor import Float8BlockQuantizer

        if self.mode == "forward":
            # The index convention (coming from base.py set_meta_tensor)
            # is somewhat awkward, and doesn't play nicely with QuantizeOp,
            # which is not associated with a GEMM.
            assert self.num_quantizers % 3 == 0  # x, w, output per gemm
            return list(
                itertools.chain.from_iterable(
                    [
                        [
                            Float8BlockQuantizer(
                                fp8_dtype=self.qx_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
                                block_scaling_dim=self.recipe.x_block_scaling_dim,
                            ),
                            Float8BlockQuantizer(
                                fp8_dtype=self.qw_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale,
                                block_scaling_dim=self.recipe.w_block_scaling_dim,
                            ),
                            Float8BlockQuantizer(
                                fp8_dtype=self.qx_dtype,
                                rowwise=True,
                                columnwise=True,
                                amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon,
                                force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale,
                                block_scaling_dim=self.recipe.x_block_scaling_dim,
                            ),
                        ]
                        for _ in range(self.num_quantizers // 3)
                    ]
                )
            )

        assert self.mode == "backward", f"Unexpected mode {self.mode}"
        assert self.num_quantizers % 2 == 0  # grad_output and grad_input per gemm
        return list(
            itertools.chain.from_iterable(
                [
                    [
                        Float8BlockQuantizer(
                            fp8_dtype=self.qgrad_dtype,
                            rowwise=True,
                            columnwise=True,
                            amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
                            force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
                            block_scaling_dim=self.recipe.grad_block_scaling_dim,
                        ),
                        Float8BlockQuantizer(
                            fp8_dtype=self.qgrad_dtype,
                            rowwise=True,
                            columnwise=True,
                            amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon,
                            force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale,
                            block_scaling_dim=self.recipe.grad_block_scaling_dim,
                        ),
                    ]
                    for _ in range(self.num_quantizers // 2)
                ]
            )
        )
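

# Illustrative sketch (not part of the original module): the forward-mode
# quantizer ordering is (input, weight, output) per GEMM, matching the
# `num_quantizers % 3 == 0` assertion above.
def _example_blockwise_quantizer_order() -> None:
    """Sketch: three blockwise quantizers for a single GEMM."""
    state = RecipeState.create(Float8BlockScaling(), mode="forward", num_quantizers=3)
    q_input, q_weight, q_output = state.make_quantizers()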


class NVFP4BlockScalingRecipeState(RecipeState):
    """Configuration for NVFP4 quantization.

    NVFP4 quantization does not require state.
    """

    recipe: NVFP4BlockScaling
    mode: str
    dtype: tex.DType

    def __init__(
        self,
        recipe: NVFP4BlockScaling,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        self.dtype = get_fp4_te_dtype(recipe)

        # Allocate buffers
        if device is None:
            device = torch.device("cuda")

    def make_quantizers(self) -> list:
        from .tensor.nvfp4_tensor import NVFP4Quantizer

        # The index convention (coming from base.py set_meta_tensor)
        # is somewhat awkward. It assumes forward quantizers are
        # ordered [input, weight, output, ...] and backward quantizers
        # are ordered [grad_output, grad_input, ...]. This doesn't
        # play nicely with fusible ops: Linear op doesn't own output
        # or grad input quantizers, Quantize op only owns input and
        # grad output quantizers.
        if self.mode == "forward":

            def _make_quantizer(idx: int) -> NVFP4Quantizer:
                qparams = (
                    self.recipe.fp4_quant_fwd_weight
                    if idx % 3 == 1
                    else self.recipe.fp4_quant_fwd_inp
                )
                return NVFP4Quantizer(
                    fp4_dtype=self.dtype,
                    rowwise=True,
                    columnwise=True,
                    with_rht=qparams.random_hadamard_transform,
                    with_post_rht_amax=qparams.random_hadamard_transform,
                    with_2d_quantization=qparams.fp4_2d_quantization,
                    stochastic_rounding=qparams.stochastic_rounding,
                )

            return [_make_quantizer(idx) for idx in range(self.num_quantizers)]
        if self.mode == "backward":
            return [
                NVFP4Quantizer(
                    fp4_dtype=self.dtype,
                    rowwise=True,
                    columnwise=True,
                    with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
                    with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
                    with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization,
                    stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding,
                )
                for _ in range(self.num_quantizers)
            ]
        raise RuntimeError(f"Unexpected recipe mode ({self.mode})")


class CustomRecipeState(RecipeState):
    """State for CustomRecipe: produce quantizers per tensor."""

    recipe: CustomRecipe
    mode: str
    num_quantizers: int
    device: Optional[torch.device]

    def __init__(
        self,
        recipe: CustomRecipe,
        *,
        mode: str,
        num_quantizers: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        self.recipe = recipe
        self.mode = mode
        self.num_quantizers = num_quantizers
        if device is None:
            device = torch.device("cuda")
        self.device = device

        if getattr(recipe, "qfactory", None) is None:
            raise ValueError("CustomRecipe requires `qfactory`.")

    def make_quantizers(self) -> list:
        qfactory = self.recipe.qfactory
        out = []
        # TODO(negvet): make_quantizers() should take roles from the operation.
        # Hardcode linear-specific roles for now.
        roles: List[str]
        if self.mode == "forward":
            roles = [
                ("linear_input", "linear_weight", "linear_output")[i % 3]
                for i in range(self.num_quantizers)
            ]
        elif self.mode == "backward":
            roles = [
                ("linear_grad_output", "linear_grad_input")[i % 2]
                for i in range(self.num_quantizers)
            ]
        else:
            roles = ["unknown"] * self.num_quantizers
        for i in range(self.num_quantizers):
            # Get quantizer from the user defined factory
            quantizer = qfactory(roles[i])
            out.append(quantizer)
        return out
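

# Illustrative sketch (not part of the original module): a minimal `qfactory`
# for CustomRecipe. It assumes CustomRecipe accepts a `qfactory` callable (as
# required by `CustomRecipeState` above) and a CUDA device for the quantizer
# buffers; the role-to-quantizer mapping here is purely illustrative.
def _example_custom_recipe() -> None:
    """Sketch: produce a current-scaling FP8 quantizer for every role."""
    from .tensor.float8_tensor import Float8CurrentScalingQuantizer

    def qfactory(role: str):
        # Roles are e.g. "linear_input", "linear_weight", "linear_output".
        return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")

    state = RecipeState.create(
        CustomRecipe(qfactory=qfactory), mode="forward", num_quantizers=3
    )
    assert len(state.make_quantizers()) == 3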

transformer_engine/pytorch/setup.py

@@ -45,7 +45,7 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
 from build_tools.build_ext import get_build_ext
-from build_tools.utils import copy_common_headers
+from build_tools.utils import copy_common_headers, min_python_version_str
 from build_tools.te_version import te_version
 from build_tools.pytorch import (
     setup_pytorch_extension,

@@ -152,6 +152,7 @@ if __name__ == "__main__":
         description="Transformer acceleration library - Torch Lib",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand},
+        python_requires=f">={min_python_version_str()}",
         install_requires=install_requirements(),
         tests_require=test_requirements(),
     )

transformer_engine/pytorch/tensor/__init__.py

@@ -6,12 +6,42 @@
 import torch

-from .quantized_tensor import QuantizedTensor, Quantizer
+from .quantized_tensor import (
+    QuantizedTensorStorage,
+    QuantizedTensor,
+    Quantizer,
+    prepare_for_saving,
+    restore_from_saved,
+)
+from .storage.float8_tensor_storage import Float8TensorStorage
+from .storage.mxfp8_tensor_storage import MXFP8TensorStorage
+from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
+from .storage.nvfp4_tensor_storage import NVFP4TensorStorage
 from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
 from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
+from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer
 from .utils import cast_master_weights_to_fp8, replace_raw_data

 __all__ = [
     "QuantizedTensor",
     "Quantizer",
+    "Float8Quantizer",
+    "Float8CurrentScalingQuantizer",
+    "MXFP8Quantizer",
+    "Float8BlockQuantizer",
+    "NVFP4Quantizer",
+    "QuantizedTensorStorage",
+    "Float8TensorStorage",
+    "MXFP8TensorStorage",
+    "Float8BlockwiseQTensorStorage",
+    "NVFP4TensorStorage",
+    "Float8Tensor",
+    "MXFP8Tensor",
+    "Float8BlockwiseQTensor",
+    "NVFP4Tensor",
+    "prepare_for_saving",
+    "restore_from_saved",
 ]

@@ -48,21 +78,16 @@ def get_all_tensor_types():
     """
     Get all tensor-like types that can be used in TE.
     """
-    from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase
-    from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase
-    from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
-        Float8BlockwiseQTensor,
-        Float8BlockwiseQTensorBase,
-    )
     all_tensor_types = [
         torch.Tensor,
         torch.nn.Parameter,
         Float8Tensor,
-        Float8TensorBase,
+        Float8TensorStorage,
         MXFP8Tensor,
-        MXFP8TensorBase,
+        MXFP8TensorStorage,
         Float8BlockwiseQTensor,
-        Float8BlockwiseQTensorBase,
+        Float8BlockwiseQTensorStorage,
+        NVFP4Tensor,
+        NVFP4TensorStorage,
     ]
     return all_tensor_types

transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

@@ -14,8 +14,12 @@ from transformer_engine_torch import DType as TE_DType
 from transformer_engine_torch import Float8BlockScaleTensorFormat
 from transformer_engine.common.recipe import Float8BlockScaling, Recipe
-from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
-from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
+from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
+from .quantized_tensor import (
+    QuantizedTensor,
+    Quantizer,
+    _IdentityFunc,
+)
 from ..utils import devices_match, round_up_to_nearest_multiple
+from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len

@@ -104,6 +108,10 @@ class Float8BlockQuantizer(Quantizer):
         dst._fp8_dtype = self.dtype
         return dst

+    def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
+        """Quantize tensor implementation"""
+        return tex.quantize(tensor, self)
+
     def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
         """Calculate the shape of the scaling tensor for blockwise quantization.

@@ -273,7 +281,7 @@ class Float8BlockQuantizer(Quantizer):
         return Float8BlockScaling


-class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
+class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
     """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.

     The tensor presents as having a standard, higher-precision dtype,

@@ -298,7 +306,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
     holds configuration about quantization and dequantization modes.
     """

-    # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args,
+    # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorStorage with positional args,
     # which significantly reduces the Pybind11 overhead when calling the constructor from C++.
     def __new__(
         cls,

@@ -337,15 +345,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
             f" data_format={self._data_format}"
         )

-    def _get_quantizer(self) -> Quantizer:
-        """Get builder for quantized tensor
-
-        Quantizer can be used for in-place operations.
-
-        """
-        assert self._quantizer is not None
-        return self._quantizer
-
     def quantize_(
         self,
         tensor: torch.Tensor,

@@ -364,8 +363,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
         """
         if isinstance(tensor, QuantizedTensor):
             return self.quantize_(tensor.dequantize())
-        self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
-        return self
+        return super().quantize_(tensor, noop_flag=noop_flag)

     def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """

@@ -408,6 +406,21 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
         # pylint: disable=missing-function-docstring
         return _ReshapeFunc.apply(self, shape)

+    def untyped_storage(self) -> torch.UntypedStorage:
+        """Return the underlying UntypedStorage of the FP8 data.
+
+        Note that FP8 block-scaled tensor may involve multiple
+        buffers: row-wise FP8 data, row-wise scales, column-wise FP8
+        data, column-wise scales. The UntypedStorage of the row-wise
+        FP8 data is returned if it exists, and otherwise the
+        UntypedStorage of the column-wise FP8 data.
+
+        """
+        data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data
+        if data is not None:
+            return data.untyped_storage()
+        return torch.UntypedStorage(0, device=self.device)
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):

@@ -432,6 +445,19 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
             )
             return Float8BlockwiseQTensor.make_like(tensor)

+        # record stream op
+        if func == torch.ops.aten.record_stream.default:
+            qt, stream = args
+            for t in (
+                qt._rowwise_data,
+                qt._columnwise_data,
+                qt._rowwise_scale_inv,
+                qt._columnwise_scale_inv,
+            ):
+                if t is not None and t.is_cuda:
+                    t.record_stream(stream)
+            return None
+
         # Default case
         return super().__torch_dispatch__(func, types, args, kwargs)
transformer_engine/pytorch/tensor/float8_tensor.py
View file @
063ef88d
...
...
@@ -13,8 +13,12 @@ from transformer_engine_torch import DType as TE_DType
from
transformer_engine.common.recipe
import
DelayedScaling
,
Float8CurrentScaling
,
Recipe
from
..utils
import
canonicalize_process_group
,
devices_match
from
._internal.float8_tensor_base
import
Float8TensorBase
,
_FromFloat8Func
from
.quantized_tensor
import
QuantizedTensor
,
Quantizer
,
_IdentityFunc
from
.storage.float8_tensor_storage
import
Float8TensorStorage
,
_FromFloat8Func
from
.quantized_tensor
import
(
QuantizedTensor
,
Quantizer
,
_IdentityFunc
,
)
from
..constants
import
dist_group_type
from
transformer_engine.pytorch.fp8
import
int8_simulation_fp8_tensorwise
...
...
@@ -90,6 +94,10 @@ class Float8Quantizer(Quantizer):
return
dst
def
quantize_impl
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
"""Quantize tensor implementation"""
return
tex
.
quantize
(
tensor
,
self
)
def
make_empty
(
self
,
shape
:
Iterable
[
int
],
...
...
@@ -148,7 +156,7 @@ class Float8Quantizer(Quantizer):
torch
.
float8_e5m2fnuz
,
]
if
internal
:
return
Float8Tensor
Bas
e
(
return
Float8Tensor
Storag
e
(
data
=
data
,
fp8_scale_inv
=
1
/
self
.
scale
,
fp8_dtype
=
self
.
dtype
,
...
...
@@ -216,6 +224,8 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax
:
torch
.
Tensor
"""FP8 datatype"""
dtype
:
TE_DType
"""amax update options"""
use_existing_amax
:
bool
"""amax reduction options"""
with_amax_reduction
:
bool
amax_reduction_group
:
Optional
[
dist_group_type
]
...
...
@@ -230,6 +240,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
*
,
rowwise
:
bool
=
True
,
columnwise
:
bool
=
True
,
use_existing_amax
:
bool
=
False
,
with_amax_reduction
:
bool
=
False
,
amax_reduction_group
:
Optional
[
dist_group_type
]
=
None
,
force_pow_2_scales
:
bool
=
False
,
...
...
@@ -239,6 +250,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
self
.
scale
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
self
.
amax
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
self
.
dtype
=
tex
.
DType
.
kInt8
if
int8_simulation_fp8_tensorwise
else
fp8_dtype
self
.
use_existing_amax
=
use_existing_amax
self
.
with_amax_reduction
=
with_amax_reduction
self
.
amax_reduction_group
=
amax_reduction_group
self
.
force_pow_2_scales
=
force_pow_2_scales
...
...
@@ -268,6 +280,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
return
dst
def
quantize_impl
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
"""Quantize tensor implementation"""
return
tex
.
quantize
(
tensor
,
self
)
def
make_empty
(
self
,
shape
:
Iterable
[
int
],
...
...
@@ -330,7 +346,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
             torch.float8_e5m2fnuz,
         ]
         if internal:
-            return Float8TensorBase(
+            return Float8TensorStorage(
                 data=data,
                 fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device),
                 fp8_dtype=self.dtype,
...
@@ -385,7 +401,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
         return True


-class Float8Tensor(Float8TensorBase, QuantizedTensor):
+class Float8Tensor(Float8TensorStorage, QuantizedTensor):
     """Experimental tensor class with FP8 data

     The tensor presents as having a standard, higher-precision dtype,
...
@@ -440,19 +456,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
             return _FromFloat8Func.apply(self, dtype)
         return _FromFloat8Func.forward(None, self, dtype)

-    def _get_quantizer(self) -> Quantizer:
-        """Get builder for quantized tensor
-
-        Quantizer can be used for in-place operations.
-
-        """
-        if self._quantizer is not None:
-            return self._quantizer
-        # Now the quantizer for Float8Tensor can be not just Float8Quantizer (delayed scaling)
-        raise ValueError(
-            "Float8Tensor's quantizer is None, cannot get a quantizer from Float8Tensor variable"
-        )
-
     def quantize_(
         self,
         tensor: torch.Tensor,
...
@@ -471,8 +474,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         """
         if isinstance(tensor, QuantizedTensor):
             return self.quantize_(tensor.dequantize(), noop_flag=noop_flag)
-        self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
-        return self
+        return super().quantize_(tensor, noop_flag=noop_flag)

     def detach(self) -> Float8Tensor:
         # pylint: disable=missing-function-docstring
...
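With this change, `Float8Tensor.quantize_` only special-cases quantized inputs and defers plain tensors to the shared storage-level implementation. A hedged usage sketch (all names hypothetical):

# --- illustrative sketch, not part of the diff; names are hypothetical ---
import torch

def refresh_fp8_weight(fp8_weight, master_weight: torch.Tensor,
                       skip_update: torch.Tensor) -> None:
    """Requantize an FP8 weight copy in place from a master weight.

    `fp8_weight` is assumed to be a Float8Tensor with an attached quantizer;
    `skip_update` is a float32 scalar flag on the GPU that, when nonzero,
    lets the underlying kernel leave the FP8 data untouched.
    """
    fp8_weight.quantize_(master_weight, noop_flag=skip_update)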
transformer_engine/pytorch/tensor/mxfp8_tensor.py
View file @ 063ef88d
...
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.

-"""Tensor class with FP8 data"""
+"""Tensor class with MXFP8 data"""
 from __future__ import annotations
 from collections.abc import Iterable
 import math
...
from
..constants
import
MXFP8_BLOCK_SCALING_SIZE
from
..utils
import
devices_match
,
round_up_to_nearest_multiple
from
._internal.mxfp8_tensor_base
import
MXFP8TensorBase
,
_FromMXFP8Func
from
.quantized_tensor
import
QuantizedTensor
,
Quantizer
,
_IdentityFunc
from
.storage.mxfp8_tensor_storage
import
MXFP8TensorStorage
,
_FromMXFP8Func
from
.quantized_tensor
import
(
QuantizedTensor
,
Quantizer
,
_IdentityFunc
,
)
aten
=
torch
.
ops
.
aten
...
...
@@ -67,6 +71,10 @@ class MXFP8Quantizer(Quantizer):
         return dst

+    def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
+        """Quantize tensor implementation"""
+        return tex.quantize(tensor, self)
+
     def is_quantizable(self, inp: torch.Tensor) -> bool:
         """Returns whether or not given inp can be quantized"""
         if inp.ndim < 2:
...
@@ -161,14 +169,14 @@ class MXFP8Quantizer(Quantizer):
         data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor)
         return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)

-    def onnx_dequantize(self, tensor: Union[MXFP8TensorBase, MXFP8Tensor]) -> torch.Tensor:
+    def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor:
         return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv)

     def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
         return MXFP8BlockScaling


-class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
+class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
     """Experimental tensor class with FP8 data

     The tensor presents as having a standard, higher-precision dtype,
...
@@ -186,14 +194,13 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
              Reciprocal of the scaling factor applied when
              casting to FP8, i.e. the scaling factor that must
              be applied when casting from FP8 to higher
-             precision. Can be inferred from fp8_meta if
-             provided.
+             precision.
     dtype: torch.dtype, default = torch.float32
              Nominal tensor datatype.

     """

-    # NOTE: We reorder the *args so that we can instantiate a MXFP8TensorBase with positional args,
+    # NOTE: We reorder the *args so that we can instantiate a MXFP8TensorStorage with positional args,
     # which significantly reduces the Pybind11 overhead when calling the constructor from C++.
     def __new__(
         cls,
...
@@ -237,17 +244,9 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
             return _FromMXFP8Func.apply(self, dtype)
         return _FromMXFP8Func.forward(None, self, dtype)

-    def _get_quantizer(self) -> Quantizer:
-        """Get builder for quantized tensor
-
-        Quantizer can be used for in-place operations.
-
-        """
-        if self._quantizer is not None:
-            return self._quantizer
-        return MXFP8Quantizer(
-            fp8_dtype=self._fp8_dtype,
-        )
+    def _build_default_quantizer(self) -> Optional[Quantizer]:
+        """Build default quantizer for the tensor"""
+        return MXFP8Quantizer(fp8_dtype=self._fp8_dtype)

     def quantize_(
         self,
...
@@ -267,8 +266,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
         """
         if isinstance(tensor, QuantizedTensor):
             return self.quantize_(tensor.dequantize())
-        self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
-        return self
+        return super().quantize_(tensor, noop_flag=noop_flag)

     def detach(self) -> MXFP8Tensor:
         # pylint: disable=missing-function-docstring
...
transformer_engine/pytorch/tensor/nvfp4_tensor.py
0 → 100644
View file @ 063ef88d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Tensor class with NVFP4 data"""
from __future__ import annotations
from collections.abc import Iterable
import math
from typing import Optional, Tuple, Union
import functools

import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType

from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe
from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type
from ..utils import (
    canonicalize_process_group,
    devices_match,
    round_up_to_nearest_multiple,
)
from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc

aten = torch.ops.aten
def get_no_random_sign_vector() -> torch.Tensor:
    """Non-random sign vector for Hadamard transform."""
    return torch.tensor([1], dtype=torch.float32)
def get_sign_from_vector(vector: torch.Tensor) -> int:
    """Convert sign vector to bitmask.

    Used for random Hadamard transform.

    """
    mask = 0
    for i, v in enumerate(vector):
        mask |= (v == -1) << i
    return mask
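Bit i of the returned mask is set exactly when the i-th sign is -1, so a 16-entry sign vector packs into a 16-bit integer. A quick self-contained check of the convention:

# --- illustrative sketch, not part of the diff ---
signs = [1, -1, 1, -1]
mask = 0
for i, v in enumerate(signs):
    mask |= (v == -1) << i
# Entries 1 and 3 are -1, so bits 1 and 3 are set: 0b1010 == 10.
assert mask == 0b1010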
def get_wgrad_sign_vector() -> torch.Tensor:
    """Hard-coded random signs for Hadamard transform.

    https://xkcd.com/221/

    """
    return torch.tensor(
        [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1],
        dtype=torch.float32,
    )
def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
    """Construct a 16x16 Hadamard matrix."""
    assert hadamard_dimension == 16, "Only hadamard dimension 16 is supported."
    hadamard_scale = 1 / math.sqrt(hadamard_dimension)
    return (
        torch.tensor(
            [
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1],
                [1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1],
                [1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1],
                [1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1],
                [1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1],
                [1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1],
                [1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1],
                [1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1],
                [1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1],
                [1, 1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1],
                [1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1],
                [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1],
                [1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1],
                [1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1, 1, 1, -1, -1],
                [1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1],
            ],
            dtype=torch.float32,
        )
        * hadamard_scale
    )
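Because distinct rows of a Hadamard matrix are orthogonal and each row has squared norm 16, the 1/sqrt(16) scaling makes the returned matrix orthonormal; a small sanity check:

# --- illustrative sketch, not part of the diff ---
import torch

H = get_hadamard_matrix(16)
# Orthonormal rows: H @ H.T should be the 16x16 identity.
assert torch.allclose(H @ H.T, torch.eye(16), atol=1e-6)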
@functools.lru_cache(maxsize=None)
def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor:
    """Construct matrix used in random Hadamard transform."""
    hadamard_dimension = 16
    if with_random_sign_mask:
        signs = get_wgrad_sign_vector()
    else:
        signs = get_no_random_sign_vector()
    sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32)
    rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension)
    return rht_matrix.to(dtype=torch.bfloat16).cuda()
@functools.lru_cache(maxsize=None)
def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int:
    """Sign mask for random Hadamard transform."""
    if with_random_sign_mask:
        return get_sign_from_vector(get_wgrad_sign_vector())
    return 0
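The transform itself runs inside fused quantization kernels; purely as a reference, and assuming the kernel applies the 16x16 matrix to contiguous 16-element blocks of the last dimension, an unfused equivalent might look like:

# --- illustrative sketch, not part of the diff; the blockwise layout is an assumption ---
import torch

def reference_rht(x: torch.Tensor, with_random_sign_mask: bool = True) -> torch.Tensor:
    """Unfused reference for the random Hadamard transform."""
    rht = get_rht_matrix(with_random_sign_mask)  # bfloat16, on the CUDA device
    blocks = x.reshape(-1, 16).to(rht.dtype)
    return (blocks @ rht.T).reshape(x.shape)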
class NVFP4Quantizer(Quantizer):
    """Builder class for NVFP4 tensors with NV block scaling"""

    dtype: TE_DType
    """Random Hadamard Transform"""
    with_rht: bool
    with_post_rht_amax: bool
    """amax reduction options"""
    with_amax_reduction: bool
    amax_reduction_group: Optional[dist_group_type]
    """2D block scaling, only applicable for weights."""
    with_2d_quantization: bool
    """Stochastic rounding, only applicable for gradients."""
    stochastic_rounding: bool
    """RHT matrix random sign mask"""
    rht_matrix_random_sign_mask_t: int
    rht_matrix: torch.Tensor

    def __init__(
        self,
        fp4_dtype: TE_DType = tex.DType.kFloat4E2M1,
        rowwise: bool = True,
        columnwise: bool = True,
        with_amax_reduction: bool = False,
        amax_reduction_group: Optional[dist_group_type] = None,
        with_rht: bool = False,
        with_post_rht_amax: bool = False,
        with_2d_quantization: bool = False,
        stochastic_rounding: bool = False,
        with_random_sign_mask: bool = True,
    ) -> None:
        super().__init__(rowwise=rowwise, columnwise=columnwise)
        self.dtype = fp4_dtype
        self.with_rht = with_rht
        self.with_post_rht_amax = with_post_rht_amax
        self.with_amax_reduction = with_amax_reduction
        self.amax_reduction_group = amax_reduction_group
        self.with_2d_quantization = with_2d_quantization
        self.stochastic_rounding = stochastic_rounding
        self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(with_random_sign_mask)
        self.rht_matrix = get_rht_matrix(with_random_sign_mask)
    def update_quantized(
        self,
        src: torch.Tensor,
        dst: QuantizedTensor,
        *,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> QuantizedTensor:
        assert isinstance(dst, NVFP4Tensor), f"Cannot store quantized NVFP4 in {type(dst)} type."

        # Make sure input is in expected format
        if not devices_match(src.device, dst.device):
            src = src.to(device=dst.device)
        if not src.is_contiguous():
            src = src.contiguous()

        # Launch cast kernel
        tex.quantize(src, self, dst, noop_flag)
        return dst
    def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
        """Quantize tensor implementation"""
        return tex.quantize(tensor, self)
    def is_quantizable(self, inp: torch.Tensor) -> bool:
        """Returns whether or not given inp can be quantized"""
        if inp.ndim < 2:
            return False
        if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0:
            return False
        if math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0:
            return False
        return True
    def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
        """Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization.

        This method determines the shape of the scaling tensor needed for blockwise quantization,
        taking into account the input tensor shape and whether columnwise scaling is used.

        Parameters
        ----------
        shape : Iterable[int]
            Shape of the input tensor to be quantized
        columnwise : bool
            Whether to use columnwise scaling (True) or rowwise scaling (False)

        Returns
        -------
        Tuple[int, int]
            Shape of the scaling tensor as (outer_dim, inner_dim)
            For NVFP4 1D blockwise quantization, blocksize is 16
            - If columnwise: (round_to_multiple(K, 128), round_to_multiple(roundup(M / 16), 4))
            - If rowwise: (round_to_multiple(M, 128), round_to_multiple(roundup(K / 16), 4))
            Swizzle kernel will be performed before GEMM to suit the need of CuBLAS.
            CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
        """
        M, K = 1, 1
        M = math.prod(shape[:-1])
        K = shape[-1]
        if columnwise:
            outer = round_up_to_nearest_multiple(K, 128)
            inner = round_up_to_nearest_multiple(math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4)
            return (outer, inner)
        # rowwise
        outer = round_up_to_nearest_multiple(M, 128)
        inner = round_up_to_nearest_multiple(math.ceil(K / NVFP4_BLOCK_SCALING_SIZE), 4)
        return (outer, inner)
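As a worked example of the formula above, a (256, 1024) input with block size 16 gives a (256, 64) rowwise scale tensor and a (1024, 16) columnwise one (constructing the quantizer assumes a CUDA device, since the RHT matrix is moved to the GPU):

# --- illustrative sketch, not part of the diff; requires a CUDA device ---
quantizer = NVFP4Quantizer()
# Rowwise: M=256 is already a multiple of 128; ceil(1024 / 16) = 64 is a multiple of 4.
assert quantizer.get_scale_shape((256, 1024), columnwise=False) == (256, 64)
# Columnwise: K=1024 is a multiple of 128; ceil(256 / 16) = 16 is a multiple of 4.
assert quantizer.get_scale_shape((256, 1024), columnwise=True) == (1024, 16)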
    @staticmethod
    def get_columnwise_shape(shape: Iterable[int]) -> Tuple[int, ...]:
        """Calculate the shape of a tensor after columnwise quantization.

        For NVFP4 columnwise quantization, it's performing 16x1 quantization block scaling.

        Parameters
        ----------
        shape : Iterable[int]
            Original shape of the tensor

        Returns
        -------
        Tuple[int, ...]
            New shape with dimensions rearranged for columnwise layout.
            For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
            Returns empty tuple for empty input shape.
        """
        if len(shape) == 0:
            return tuple()
        # and then after AG, a reorganize kernel will be called to restore the shape
        colwise_shape = [shape[-1]]
        for i in range(len(shape) - 1):
            colwise_shape.append(shape[i])
        return tuple(colwise_shape)
    @staticmethod
    def convert_shape_for_fp4(shape: Iterable[int]) -> Tuple[int, ...]:
        """Convert shape for FP4 data by dividing the last dimension by 2"""
        shape = list(shape)
        shape[-1] = shape[-1] // 2
        return tuple(shape)
    def make_empty(
        self,
        shape: Iterable[int],
        *,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
        requires_grad: bool = False,
    ) -> NVFP4Tensor:

        # Canonicalize tensor attributes
        if device is None:
            device = torch.device("cuda")

        assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, (
            f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
            f" {NVFP4_BLOCK_SCALING_SIZE}"
        )
        flat_first_dim = math.prod(shape[:-1])
        assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, (
            f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
            f" {NVFP4_BLOCK_SCALING_SIZE}"
        )

        # Allocate FP4 data
        data = None
        scale_inv = None
        amax_rowwise = None
        if self.rowwise_usage:
            data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device)
            scale_shape = self.get_scale_shape(shape, columnwise=False)
            scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device)
            # Allocate per tensor scale inverse. FP32 format.
            amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device)

        # Allocate FP4 data transpose if needed
        columnwise_data = None
        columnwise_scale_inv = None
        amax_columnwise = None
        if self.columnwise_usage:
            # Enforce a 2D shape to avoid an [S, B, H] shape where B can be 1:
            # the transposed shape would be [H, S, B], so dividing the last dim by 2 gives zero.
            shape_2d = tuple([flat_first_dim, shape[-1]])
            columnwise_data = torch.empty(
                self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)),
                dtype=torch.uint8,
                device=device,
            )
            columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
            columnwise_scale_inv = torch.empty(
                columnwise_scale_shape, dtype=torch.uint8, device=device
            )
            amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device)

        # Construct NVFP4 tensor
        return NVFP4Tensor(
            shape=shape,
            dtype=dtype,
            rowwise_data=data,
            rowwise_scale_inv=scale_inv,
            columnwise_data=columnwise_data,
            columnwise_scale_inv=columnwise_scale_inv,
            amax_rowwise=amax_rowwise,
            amax_columnwise=amax_columnwise,
            fp4_dtype=self.dtype,
            quantizer=self,
            requires_grad=requires_grad,
        )
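A minimal sketch of how make_empty pairs with update_quantized: allocate uninitialized FP4 storage, then quantize a high-precision source into it (CUDA assumed):

# --- illustrative sketch, not part of the diff; requires a CUDA device ---
import torch

quantizer = NVFP4Quantizer(rowwise=True, columnwise=False)
src = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

# uint8 data with the last dim halved, plus scale and amax tensors.
out = quantizer.make_empty((128, 256), dtype=torch.bfloat16, device=torch.device("cuda"))
quantizer.update_quantized(src, out)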
    def calibrate(self, tensor: torch.Tensor) -> None:
        pass  # Calibration is no-op

    def _canonicalized_amax_reduction_group(self) -> dist_group_type:
        """Get process group for amax reduction"""
        return canonicalize_process_group(self.amax_reduction_group)

    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
        return NVFP4BlockScaling
class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
    """Quantized tensor class with FP4 data

    The tensor presents as having a standard, higher-precision dtype,
    but the data itself is (scaled) FP4. For most tensor operations,
    the data will be cast to the nominal dtype before performing the
    operation.

    Parameters
    ----------
    rowwise_data: torch.Tensor
              Raw FP4 data in a uint8 tensor (rowwise layout).
    rowwise_scale_inv: torch.Tensor
              Reciprocal of the scaling factor applied when
              casting to FP4, i.e. the scaling factor that must
              be applied when casting from FP4 to higher
              precision (rowwise).
    columnwise_data: torch.Tensor, optional
              Raw FP4 data in a uint8 tensor (columnwise layout).
    columnwise_scale_inv: torch.Tensor, optional
              Reciprocal of the scaling factor for columnwise FP4 data.
    amax_rowwise: torch.Tensor, optional
              Rowwise amax tracking tensor.
    amax_columnwise: torch.Tensor, optional
              Columnwise amax tracking tensor.
    fp4_dtype: TE_DType
              The FP4 data type used for quantization.
    quantizer: Quantizer
              The quantizer instance used for this tensor.
    dtype: torch.dtype, default = torch.float32
              Nominal tensor datatype, used in dequantize.

    """
    # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorStorage with positional args,
    # which significantly reduces the Pybind11 overhead when calling the constructor from C++.
    def __new__(
        cls,
        *args,
        rowwise_data: Optional[torch.Tensor],
        rowwise_scale_inv: Optional[torch.Tensor],
        columnwise_data: Optional[torch.Tensor],
        columnwise_scale_inv: Optional[torch.Tensor],
        amax_rowwise: Optional[torch.Tensor],
        amax_columnwise: Optional[torch.Tensor],
        fp4_dtype: TE_DType,
        quantizer: Quantizer,
        **kwargs,
    ):
        instance = super().__new__(
            cls,
            rowwise_data,
            rowwise_scale_inv,
            columnwise_data,
            columnwise_scale_inv,
            amax_rowwise,
            amax_columnwise,
            fp4_dtype,
            quantizer,
            *args,
            **kwargs,
        )
        return instance
    def __repr__(self, *, tensor_contents=None):
        return f"NVFP4Tensor(data={self.dequantize(dtype=self.dtype)})"
    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """
        Construct plain PyTorch tensor from NVFP4Tensor

        By default the resulting tensor's dtype is the
        NVFP4Tensor's nominal dtype.
        """
        # Convert PyTorch dtype to TE dtype
        if dtype is None:
            dtype = self.dtype
        if torch.is_grad_enabled():
            return _FromNVFP4Func.apply(self, dtype)
        return _FromNVFP4Func.forward(None, self, dtype)
    def _get_quantizer(self) -> Quantizer:
        """Get builder for quantized tensor

        Quantizer can be used for in-place operations.

        """
        if self._quantizer is not None:
            return self._quantizer
        return NVFP4Quantizer()
    def quantize_(
        self,
        tensor: torch.Tensor,
        *,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> NVFP4Tensor:
        """Update FP4 data

        Parameters
        ----------
        tensor: torch.Tensor
            Tensor to copy from
        noop_flag: torch.Tensor, optional
            float32 flag indicating whether to avoid performing update

        """
        if isinstance(tensor, QuantizedTensor):
            return self.quantize_(tensor.dequantize())
        self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
        return self
    def detach(self) -> NVFP4Tensor:
        # pylint: disable=missing-function-docstring
        # TODO(ksivamani): Fix the detach bug
        return NVFP4Tensor.make_like(self)
    def clone(self) -> NVFP4Tensor:
        # pylint: disable=missing-function-docstring
        assert self._rowwise_data is not None
        rowwise_data = self._rowwise_data.detach().clone()
        columnwise_data = None
        if self._columnwise_data is not None:
            columnwise_data = self._columnwise_data.detach().clone()
        return _IdentityFunc.apply(
            self,
            {
                "rowwise_data": rowwise_data,
                "columnwise_data": columnwise_data,
            },
        )
    def view(self, *shape: Tuple[int]) -> NVFP4Tensor:
        # pylint: disable=missing-function-docstring
        return _ViewFunc.apply(self, shape)

    def reshape(self, *shape: Tuple[int]) -> NVFP4Tensor:
        # pylint: disable=missing-function-docstring
        return _ReshapeFunc.apply(self, shape)
    def contiguous(
        self,
        memory_format: torch.memory_format = torch.contiguous_format,
    ) -> NVFP4Tensor:
        """Returns tensor with data in provided memory format

        Returns `self` if data is already in correct memory format.

        """
        if self._rowwise_data is not None and self._rowwise_data.is_contiguous(
            memory_format=memory_format
        ):
            return self
        if self._columnwise_data is not None and self._columnwise_data.is_contiguous(
            memory_format=memory_format
        ):
            return self
        raise ValueError("NVFP4Tensor does not support different memory formats!")
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):

        # View op
        if func == aten.view.default:
            if len(args) != 2:
                raise RuntimeError(
                    f"Unexpected args for view op (expected 2 args, got {len(args)})"
                )
            tensor = args[0]
            shape = args[1]
            if shape == list(tensor.size()):
                return tensor.detach()
            return tensor.view(shape)

        # NVFP4 dequantize not supported. Add manual support for needed funcs.
        if func in (aten.empty_like.default, aten.zero_.default):
            tensor = args[0]
            data_init_func = torch.zeros_like if func == aten.zero_.default else torch.empty_like
            scale_inv_init_func = (
                torch.ones_like if func == aten.zero_.default else torch.empty_like
            )
            if tensor._rowwise_data is not None:
                rowwise_data = data_init_func(tensor._rowwise_data)
                rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv)
                amax_rowwise = torch.zeros_like(tensor._amax_rowwise)
            else:
                rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None
            if tensor._columnwise_data is not None:
                columnwise_data = data_init_func(tensor._columnwise_data)
                columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv)
                amax_columnwise = torch.zeros_like(tensor._amax_columnwise)
            else:
                columnwise_data, columnwise_scale_inv, amax_columnwise = (
                    None,
                    None,
                    None,
                )
            return NVFP4Tensor(
                shape=tensor.shape,
                dtype=tensor.dtype,
                fp4_dtype=tensor._fp4_dtype,
                rowwise_data=rowwise_data,
                rowwise_scale_inv=rowwise_scale_inv,
                columnwise_data=columnwise_data,
                columnwise_scale_inv=columnwise_scale_inv,
                amax_rowwise=amax_rowwise,
                amax_columnwise=amax_columnwise,
                quantizer=tensor._quantizer,
                requires_grad=tensor.requires_grad,
            )

        # Default case
        return super().__torch_dispatch__(func, types, args, kwargs)
    @classmethod
    def _make_in_reduce_ex(
        cls,
        shape: torch.Size,
        rowwise_data: torch.Tensor,
        rowwise_scale_inv: torch.Tensor,
        columnwise_data: torch.Tensor,
        columnwise_scale_inv: torch.Tensor,
        amax_rowwise: torch.Tensor,
        amax_columnwise: torch.Tensor,
        fp4_dtype: TE_DType,
        dtype: torch.dtype,
        quantizer: Quantizer,
    ) -> NVFP4Tensor:
        """Build NVFP4Tensor, for use in __reduce__

        __reduce_ex__ assumes object constructor has positional
        arguments.

        """
        return NVFP4Tensor(
            shape=shape,
            dtype=dtype,
            fp4_dtype=fp4_dtype,
            rowwise_data=rowwise_data,
            rowwise_scale_inv=rowwise_scale_inv,
            columnwise_data=columnwise_data,
            columnwise_scale_inv=columnwise_scale_inv,
            amax_rowwise=amax_rowwise,
            amax_columnwise=amax_columnwise,
            quantizer=quantizer,
            requires_grad=False,
        )
    def __reduce_ex__(self, protocol: int) -> tuple:
        """Custom pickling"""
        return (
            NVFP4Tensor._make_in_reduce_ex,
            (
                self.shape,
                self._rowwise_data,
                self._rowwise_scale_inv,
                self._columnwise_data,
                self._columnwise_scale_inv,
                self._amax_rowwise,
                self._amax_columnwise,
                self._fp4_dtype,
                self.dtype,
                self._quantizer,
            ),
        )
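Since __reduce_ex__ rebuilds the tensor from its component buffers via _make_in_reduce_ex, torch.save/torch.load round trips should preserve the quantized payload. A sketch (recent PyTorch versions may require weights_only=False for custom classes):

# --- illustrative sketch, not part of the diff ---
import io
import torch

def roundtrip(t: NVFP4Tensor) -> NVFP4Tensor:
    """Serialize and restore an NVFP4Tensor through its custom pickling."""
    buf = io.BytesIO()
    torch.save(t, buf)
    buf.seek(0)
    return torch.load(buf, weights_only=False)  # rebuilt via _make_in_reduce_ex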
    def _get_data(self) -> NVFP4Tensor:
        """Get tensor data property"""
        return super().data
    @torch.no_grad()
    def _set_data(self, tensor: torch.Tensor) -> None:
        """Set tensor data property

        Just takes the NVFP4 data if setting from an NVFP4Tensor. Otherwise
        quantizes to NVFP4.

        """
        # Tensor device
        new_device = tensor.device if tensor.is_cuda else self.device
        if not devices_match(new_device, tensor.device):
            tensor = tensor.to(device=new_device)

        # Just copy NVFP4 data if other tensor is NVFP4Tensor
        if isinstance(tensor, NVFP4Tensor):
            if (  # pylint: disable=too-many-boolean-expressions
                self.size() != tensor.size()
                or self.stride() != tensor.stride()
                or self.storage_offset() != tensor.storage_offset()
                or self.dtype != tensor.dtype
                or self.layout != tensor.layout
                or not devices_match(self.device, new_device)
            ):
                dummy_tensor = torch.Tensor._make_wrapper_subclass(
                    NVFP4Tensor,
                    tensor.size(),
                    strides=tensor.stride(),
                    storage_offset=tensor.storage_offset(),
                    dtype=tensor.dtype,
                    layout=tensor.layout,
                    requires_grad=tensor.requires_grad,
                    device=new_device,
                )
                # pylint: disable=unnecessary-dunder-call
                super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor)
            self._rowwise_data = tensor._rowwise_data
            self._columnwise_data = tensor._columnwise_data
            self._quantizer = tensor._quantizer
            self._rowwise_scale_inv = tensor._rowwise_scale_inv
            self._columnwise_scale_inv = tensor._columnwise_scale_inv
            self._amax_rowwise = tensor._amax_rowwise
            self._amax_columnwise = tensor._amax_columnwise
            return

        # Quantize to NVFP4
        assert self._quantizer is not None, "Can't quantize without a quantizer"
        self._quantizer.update_quantized(tensor, self)
        if self.requires_grad != tensor.requires_grad:
            self.requires_grad_(requires_grad=tensor.requires_grad)

    # Cast to NVFP4 when setting NVFP4Tensor.data
    data = property(_get_data, _set_data)
class _ViewFunc(torch.autograd.Function):
    """View function

    View the NVFP4Tensor using the provided shape.

    """

    @staticmethod
    def forward(
        ctx,
        tensor: NVFP4Tensor,
        shape: Optional[list[int]] = None,
    ) -> NVFP4Tensor:
        # pylint: disable=missing-function-docstring

        # Return input tensor if shape is not provided
        cur_shape = tensor.shape
        if ctx is not None:
            ctx.shape = cur_shape
        if shape is None:
            return tensor

        # Canonicalize shape
        if not isinstance(shape, Iterable):
            shape = [shape]
        elif len(shape) == 1 and isinstance(shape[0], Iterable):
            shape = shape[0]
        if -1 in shape:
            shape = list(shape)
            d_inferred = -math.prod(cur_shape) // math.prod(shape)
            for i, d in enumerate(shape):
                if d == -1:
                    shape[i] = d_inferred
                    break
        if shape[-1] != cur_shape[-1]:
            raise RuntimeError(
                "NVFP4Tensor does not support reshaping inner dimension "
                f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
            )

        # Reshape data
        new_rowwise_data = None
        new_columnwise_data = None
        if tensor._rowwise_data is not None:
            if shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent row-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = list(shape[:-1]) + [shape[-1] // 2]
            new_rowwise_data = tensor._rowwise_data.view(byte_shape)
        if tensor._columnwise_data is not None:
            columnwise_shape = (shape[-1], math.prod(shape[:-1]))
            if columnwise_shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent column-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
            new_columnwise_data = tensor._columnwise_data.view(byte_shape)

        # Construct tensor
        return NVFP4Tensor(
            shape,
            tensor.dtype,
            rowwise_data=new_rowwise_data,
            rowwise_scale_inv=tensor._rowwise_scale_inv,
            columnwise_data=new_columnwise_data,
            columnwise_scale_inv=tensor._columnwise_scale_inv,
            amax_rowwise=tensor._amax_rowwise,
            amax_columnwise=tensor._amax_columnwise,
            quantizer=tensor._quantizer,
            fp4_dtype=tensor._fp4_dtype,
            requires_grad=tensor.requires_grad,
        )

    @staticmethod
    def backward(
        ctx,
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # pylint: disable=missing-function-docstring
        if isinstance(grad, NVFP4Tensor):
            new_rowwise_data = None
            new_columnwise_data = None
            if grad._rowwise_data is not None:
                if ctx.shape[-1] % 2 != 0:
                    raise ValueError(
                        "Cannot represent row-wise data for NVFP4 tensor "
                        f"with shape={ctx.shape} as byte array."
                    )
                byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2]
                new_rowwise_data = grad._rowwise_data.view(byte_shape)
            if grad._columnwise_data is not None:
                columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1]))
                if columnwise_shape[-1] % 2 != 0:
                    raise ValueError(
                        "Cannot represent column-wise data for NVFP4 tensor "
                        f"with shape={ctx.shape} as byte array."
                    )
                byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
                new_columnwise_data = grad._columnwise_data.view(byte_shape)
            dgrad = NVFP4Tensor(
                ctx.shape,
                grad.dtype,
                rowwise_data=new_rowwise_data,
                rowwise_scale_inv=grad._rowwise_scale_inv,
                columnwise_data=new_columnwise_data,
                columnwise_scale_inv=grad._columnwise_scale_inv,
                amax_rowwise=grad._amax_rowwise,
                amax_columnwise=grad._amax_columnwise,
                quantizer=grad._quantizer,
                fp4_dtype=grad._fp4_dtype,
                requires_grad=grad.requires_grad,
            )
            return dgrad, None
        return grad.view(ctx.shape), None
class _ReshapeFunc(torch.autograd.Function):
    """Reshape function

    Reshape the NVFP4Tensor using the provided shape.

    """

    @staticmethod
    def forward(
        ctx,
        tensor: NVFP4Tensor,
        shape: Optional[list[int]] = None,
    ) -> NVFP4Tensor:
        # pylint: disable=missing-function-docstring

        # Return input tensor if shape is not provided
        cur_shape = tensor.shape
        if ctx is not None:
            ctx.shape = cur_shape
        if shape is None:
            return tensor

        # Canonicalize shape
        if not isinstance(shape, Iterable):
            shape = [shape]
        elif len(shape) == 1 and isinstance(shape[0], Iterable):
            shape = shape[0]
        if -1 in shape:
            shape = list(shape)
            d_inferred = -math.prod(cur_shape) // math.prod(shape)
            for i, d in enumerate(shape):
                if d == -1:
                    shape[i] = d_inferred
                    break
        if shape[-1] != cur_shape[-1]:
            raise RuntimeError(
                "NVFP4Tensor does not support reshaping inner dimension "
                f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
            )

        # Reshape data
        new_rowwise_data = None
        new_columnwise_data = None
        if tensor._rowwise_data is not None:
            if shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent row-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = list(shape[:-1]) + [shape[-1] // 2]
            new_rowwise_data = tensor._rowwise_data.reshape(byte_shape)
        if tensor._columnwise_data is not None:
            columnwise_shape = (shape[-1], math.prod(shape[:-1]))
            if columnwise_shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent column-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
            new_columnwise_data = tensor._columnwise_data.reshape(byte_shape)

        # Construct tensor
        return NVFP4Tensor(
            shape,
            tensor.dtype,
            rowwise_data=new_rowwise_data,
            rowwise_scale_inv=tensor._rowwise_scale_inv,
            columnwise_data=new_columnwise_data,
            columnwise_scale_inv=tensor._columnwise_scale_inv,
            amax_rowwise=tensor._amax_rowwise,
            amax_columnwise=tensor._amax_columnwise,
            quantizer=tensor._quantizer,
            fp4_dtype=tensor._fp4_dtype,
            requires_grad=tensor.requires_grad,
        )

    @staticmethod
    def backward(
        ctx,
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # pylint: disable=missing-function-docstring
        if isinstance(grad, NVFP4Tensor):
            new_rowwise_data = None
            new_columnwise_data = None
            if grad._rowwise_data is not None:
                if ctx.shape[-1] % 2 != 0:
                    raise ValueError(
                        "Cannot represent row-wise data for NVFP4 tensor "
                        f"with shape={ctx.shape} as byte array."
                    )
                byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2]
                new_rowwise_data = grad._rowwise_data.reshape(byte_shape)
            if grad._columnwise_data is not None:
                columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1]))
                if columnwise_shape[-1] % 2 != 0:
                    raise ValueError(
                        "Cannot represent column-wise data for NVFP4 tensor "
                        f"with shape={ctx.shape} as byte array."
                    )
                byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
                new_columnwise_data = grad._columnwise_data.reshape(byte_shape)
            dgrad = NVFP4Tensor(
                ctx.shape,
                grad.dtype,
                rowwise_data=new_rowwise_data,
                rowwise_scale_inv=grad._rowwise_scale_inv,
                columnwise_data=new_columnwise_data,
                columnwise_scale_inv=grad._columnwise_scale_inv,
                amax_rowwise=grad._amax_rowwise,
                amax_columnwise=grad._amax_columnwise,
                quantizer=grad._quantizer,
                fp4_dtype=grad._fp4_dtype,
                requires_grad=grad.requires_grad,
            )
            return dgrad, None
        return grad.view(ctx.shape), None
transformer_engine/pytorch/tensor/quantized_tensor.py
View file @ 063ef88d
...
@@ -5,7 +5,7 @@
 """Tensor with quantized data"""
 from __future__ import annotations

-from typing import Optional, Tuple, Iterable, Any, Dict, Union
+from typing import Callable, Optional, Tuple, Iterable, Any, Dict, Union

 import abc
 import copy
 import warnings
...
@@ -13,12 +13,11 @@ import warnings
 import torch
 from torch.utils._pytree import tree_map

-import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe


-class QuantizedTensorBase:
-    r"""Base class for all *TensorBase classes.
+class QuantizedTensorStorage:
+    r"""Base class for all *TensorStorage classes.

     This class (and its subclasses) are an optimization for when
     the full QuantizedTensor is not needed (when it is fully
...
@@ -26,9 +25,9 @@ class QuantizedTensorBase:
     PyTorch's autograd).

     When creating a new tensor type X one should create both
-    XTensorBase class inheriting from QuantizedTensorBase and
-    XTensor inheriting from XTensorBase and QuantizedTensor.
-    XTensorBase should contain all data members needed to
+    XTensorStorage class inheriting from QuantizedTensorStorage and
+    XTensor inheriting from XTensorStorage and QuantizedTensor.
+    XTensorStorage should contain all data members needed to
     implement the functionality of the tensor, while
     XTensor should only implement the functionality needed
     to behave like regular torch.Tensor (like __torch_dispatch__)."""
...
@@ -59,7 +58,7 @@ class QuantizedTensorBase:
             f"{self.__class__.__name__} class does not implement update_usage function"
         )

-    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
+    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
         """Prepare the tensor base for saving for backward"""
         raise NotImplementedError(
             f"{self.__class__.__name__} class does not implement prepare_for_saving function"
...
@@ -73,6 +72,30 @@ class QuantizedTensorBase:
             f"{self.__class__.__name__} class does not implement restore_from_saved function"
         )

+    def _get_quantizer(self) -> Quantizer:
+        """Get builder for quantized tensor
+
+        Quantizer can be used for in-place operations.
+
+        """
+        if self._quantizer is not None:
+            return self._quantizer
+        return self._build_default_quantizer()
+
+    def _build_default_quantizer(self) -> Quantizer:
+        """Build default quantizer for the tensor"""
+        raise ValueError(
+            f"{self.__class__.__name__} has no quantizer "
+            "and no default quantizer is defined in the subclass."
+        )
+
+    def quantize_(
+        self, tensor: torch.Tensor, *, noop_flag: Optional[torch.Tensor] = None
+    ) -> QuantizedTensor:
+        """Quantize tensor in-place"""
+        self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
+        return self
+
     def update_quantizer(self, quantizer: Quantizer):
         """Update quantizer for the tensor"""
         if self._quantizer is None:
...
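The added hooks form a small template method: the shared quantize_ calls _get_quantizer, which falls back to _build_default_quantizer, so a storage subclass only needs to override the fallback. A hypothetical subclass sketch:

# --- illustrative sketch, not part of the diff; MyQuantizer and self._dtype are hypothetical ---
class MyTensorStorage(QuantizedTensorStorage):
    """quantize_() and _get_quantizer() are inherited; only the
    default-quantizer hook is overridden."""

    def _build_default_quantizer(self) -> Quantizer:
        # Construct a quantizer from metadata the storage already carries.
        return MyQuantizer(dtype=self._dtype)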
@@ -83,13 +106,13 @@ class QuantizedTensorBase:
 def prepare_for_saving(
-    *tensors: Union[torch.Tensor, QuantizedTensorBase],
+    *tensors: Union[torch.Tensor, QuantizedTensorStorage],
 ) -> Tuple[
-    list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
-    list[Optional[QuantizedTensorBase]]
+    list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
+    list[Optional[QuantizedTensorStorage]]
 ]:
     """Prepare tensors for saving. Needed because save_for_backward accepts only
     torch.Tensor/torch.nn.Parameter types, while we want to be able to save
-    the internal TensorBase types too."""
+    the internal *TensorStorage types too."""
     tensor_list, tensor_objects_list = [], []
     for tensor in tensors:
...
@@ -104,12 +127,12 @@ def prepare_for_saving(
 def restore_from_saved(
-    tensors: list[Optional[Union[torch.Tensor, QuantizedTensorBase]]],
+    tensors: list[Optional[Union[torch.Tensor, QuantizedTensorStorage]]],
     saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
     return_saved_tensors: bool = False,
 ) -> (
-    list[Optional[torch.Tensor | QuantizedTensorBase]]
-    | tuple[list[Optional[torch.Tensor | QuantizedTensorBase]], list[Optional[torch.Tensor]]]
+    list[Optional[torch.Tensor | QuantizedTensorStorage]]
+    | tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]]
 ):
     """Recombine the tensor data and metadata during backward pass."""
     tensor_objects = []
...
@@ -178,7 +201,6 @@ class Quantizer(abc.ABC):
             ")"
         )

-    @abc.abstractmethod
     def update_quantized(
         self,
         src: torch.Tensor,
...
@@ -187,6 +209,9 @@ class Quantizer(abc.ABC):
         noop_flag: Optional[torch.Tensor] = None,
     ) -> QuantizedTensor:
         """Quantize tensor in-place"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement update_quantized"
+        )

     def quantize(
         self,
...
@@ -199,8 +224,14 @@ class Quantizer(abc.ABC):
         if out is not None:
             return self.update_quantized(tensor, out)
         if (not self.internal) and torch.is_grad_enabled():
-            return _QuantizeFunc.apply(tensor, self)
-        return _QuantizeFunc.forward(None, tensor, self)
+            return _QuantizeFunc.apply(tensor, self.quantize_impl)
+        return _QuantizeFunc.forward(None, tensor, self.quantize_impl)
+
+    def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
+        """Quantize tensor implementation"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement quantize_impl function"
+        )

     def multi_quantize(self, list_of_tensors):
         """Quantize multiple tensors"""
...
@@ -213,7 +244,6 @@ class Quantizer(abc.ABC):
         """Quantize tensor"""
         return self.quantize(tensor)

-    @abc.abstractmethod
     def make_empty(
         self,
         shape: Iterable[int],
...
@@ -222,8 +252,11 @@ class Quantizer(abc.ABC):
         device: Optional[torch.device] = None,
     ) -> QuantizedTensor:
         """Construct quantized tensor with uninitialized data"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement make_empty function, "
+            "required for construction of uninitialized quantized tensor"
+        )

     @abc.abstractmethod
     def calibrate(self, tensor: torch.Tensor) -> None:
         """Calibrate quantizer state
...
@@ -252,34 +285,47 @@ class Quantizer(abc.ABC):
     def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
         """Symbolic function for ONNX export"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement onnx_quantize"
+        )

     def onnx_dequantize(self, tensor) -> torch.Tensor:
         """Symbolic function for ONNX export"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement onnx_dequantize"
+        )

     @abc.abstractmethod
     def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
         """Returns recipe class that is compatible with this quantizer"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement _get_compatible_recipe"
+        )

     def supports_only_rowwise_all_gather(self) -> bool:
         """Returns True if the quantizer supports only rowwise all-gather"""
         return False

     def is_quantizable(self, inp: torch.Tensor) -> bool:
         # pylint: disable=unused-argument
         """Returns whether or not given tensor can be quantized"""
         return True


 class _QuantizeFunc(torch.autograd.Function):
-    """Cast to FP8 from other dtype"""
+    """Quantize tensor"""

     @staticmethod
     def forward(
         _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
         tensor: torch.Tensor,
-        quantizer: Quantizer,
+        quantize_impl: Callable,
     ) -> QuantizedTensor:
         # pylint: disable=missing-function-docstring
-        return tex.quantize(tensor, quantizer)
+        return quantize_impl(tensor)

     @staticmethod
     def backward(
-        _ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor  # unused
+        _ctx: torch.autograd.function.FunctionCtx,  # unused
+        grad: torch.Tensor,
     ) -> Tuple[Optional[torch.Tensor], ...]:
         # pylint: disable=missing-function-docstring
         # Assume that we want gradients in full precision
...
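Passing the bound quantize_impl callable into the autograd Function decouples it from a single tex.quantize entry point: each quantizer now supplies its own kernel while the straight-through backward stays shared. Roughly:

# --- illustrative sketch, not part of the diff ---
import torch

class _ExampleQuantizeFunc(torch.autograd.Function):
    """Simplified model of the pattern above."""

    @staticmethod
    def forward(ctx, tensor, quantize_impl):
        # Any callable mapping a torch.Tensor to a quantized tensor works here.
        return quantize_impl(tensor)

    @staticmethod
    def backward(ctx, grad):
        # Straight-through: full-precision gradient, None for the callable.
        return grad, None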
transformer_engine/pytorch/tensor/storage/__init__.py
0 → 100644
View file @ 063ef88d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Storage for quantized tensors."""

from .float8_tensor_storage import Float8TensorStorage  # noqa: F401
from .mxfp8_tensor_storage import MXFP8TensorStorage  # noqa: F401
from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage  # noqa: F401
from .nvfp4_tensor_storage import NVFP4TensorStorage  # noqa: F401
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py → transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py
View file @ 063ef88d
...
@@ -14,7 +14,7 @@ from transformer_engine_torch import DType as TE_DType
 from transformer_engine_torch import Float8BlockScaleTensorFormat
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
-from ..quantized_tensor import QuantizedTensorBase
+from ..quantized_tensor import QuantizedTensorStorage
 from ...constants import TE_DType_To_Torch
...
@@ -23,7 +23,7 @@ from ..quantized_tensor import Quantizer
 from ...utils import _empty_tensor


-class Float8BlockwiseQTensorBase(QuantizedTensorBase):
+class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
     """Mixin class that holds data attributes of Float8BlockwiseQTensor.

     Float8BlockwiseQTensor inherits from the PyTorch tensor class and this
...
@@ -54,7 +54,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         *args,
         **kwargs,
     ):
-        if cls is Float8BlockwiseQTensorBase:
+        if cls is Float8BlockwiseQTensorStorage:
             instance = object.__new__(cls)
         else:
             instance = super().__new__(cls, *args, **kwargs)
...
@@ -99,7 +99,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
     def prepare_for_saving(
         self,
-    ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
+    ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]:
         """
         Prepare the tensor base for saving for backward
         """
...
@@ -367,7 +367,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
             data = self.dequantize()
             descriptor = "columnwise"
         return (
-            "Float8BlockwiseQTensorBase("
+            "Float8BlockwiseQTensorStorage("
             f"fp8_dtype={self._fp8_dtype}, "
             f"{descriptor}_scaled_data={data}"
         )
...
transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py → transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
View file @ 063ef88d
...
@@ -12,7 +12,7 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

-from ..quantized_tensor import QuantizedTensorBase
+from ..quantized_tensor import QuantizedTensorStorage
 from ...constants import TE_DType as torch_to_transformer_engine_dtype
...
@@ -27,7 +27,7 @@ class _FromFloat8Func(torch.autograd.Function):
     @staticmethod
     def forward(
         _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
-        tensor: Float8TensorBase,
+        tensor: Float8TensorStorage,
         dtype: torch.dtype,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
...
@@ -52,7 +52,7 @@ class _FromFloat8Func(torch.autograd.Function):
         return grad, None


-class Float8TensorBase(QuantizedTensorBase):
+class Float8TensorStorage(QuantizedTensorStorage):
     """Mixin class that holds data attributes of Float8Tensor.

     Float8Tensor inherits from the PyTorch tensor class and this mixin
...
@@ -81,7 +81,7 @@ class Float8TensorBase(QuantizedTensorBase):
         quantizer: Optional[Quantizer] = None,
         **kwargs,
     ):
-        if cls is Float8TensorBase:
+        if cls is Float8TensorStorage:
             instance = object.__new__(cls)
         else:
             instance = super().__new__(cls, *args, **kwargs)
...
@@ -116,7 +116,7 @@ class Float8TensorBase(QuantizedTensorBase):
             "quantizer": self._quantizer,
         }

-    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
+    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
         """Prepare the tensor base for saving for backward"""
         tensors = [self._data, self._transpose, self._scale_inv]
         self._data = None
...
@@ -163,7 +163,7 @@ class Float8TensorBase(QuantizedTensorBase):
         if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]:
             out_transpose = None
-        return Float8TensorBase(
+        return Float8TensorStorage(
             data=out_data,
             fp8_scale_inv=self._scale_inv,
             fp8_dtype=self._fp8_dtype,
...
@@ -173,7 +173,7 @@ class Float8TensorBase(QuantizedTensorBase):
     def __repr__(self):
         return (
-            "Float8TensorBase("
+            "Float8TensorStorage("
             f"fp8_dtype={self._fp8_dtype}, "
             f"scale_inv={self._scale_inv.item()}, "
             f"data={self.dequantize()}"
...
transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py → transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
View file @ 063ef88d
...
@@ -13,7 +13,7 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

-from ..quantized_tensor import QuantizedTensorBase
+from ..quantized_tensor import QuantizedTensorStorage
 from ...constants import TE_DType as torch_to_transformer_engine_dtype
...
@@ -28,7 +28,7 @@ class _FromMXFP8Func(torch.autograd.Function):
     @staticmethod
     def forward(
         _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
-        tensor: MXFP8TensorBase,
+        tensor: MXFP8TensorStorage,
         dtype: torch.dtype,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
...
@@ -49,7 +49,7 @@ class _FromMXFP8Func(torch.autograd.Function):
         return grad, None


-class MXFP8TensorBase(QuantizedTensorBase):
+class MXFP8TensorStorage(QuantizedTensorStorage):
     """Mixin class that holds data attributes of MXFP8Tensor.

     MXFP8Tensor inherits from the PyTorch tensor class and this mixin
...
@@ -77,7 +77,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
         *args,
         **kwargs,
     ):
-        if cls is MXFP8TensorBase:
+        if cls is MXFP8TensorStorage:
             instance = object.__new__(cls)
         else:
             instance = super().__new__(cls, *args, **kwargs)
...
@@ -112,7 +112,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
             "quantizer": self._quantizer,
         }

-    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
+    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]:
         """Prepare the tensor base for saving for backward"""
         tensors = [
             self._rowwise_data,
...
@@ -192,7 +192,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
         if cur_columnwise_data is not None:
             new_columnwise_data = cur_columnwise_data.view(*shape)
-        return MXFP8TensorBase(
+        return MXFP8TensorStorage(
             rowwise_data=new_rowwise_data,
             rowwise_scale_inv=self._rowwise_scale_inv,
             columnwise_data=new_columnwise_data,
...
@@ -205,7 +205,7 @@ class MXFP8TensorBase(QuantizedTensorBase):
         data_rowwise = self.dequantize()
         return (
-            "MXFP8TensorBase("
+            "MXFP8TensorStorage("
             f"fp8_dtype={self._fp8_dtype}, "
             f"rowwise_scaled_data={data_rowwise}"
             f"rowwise_scale_inv={self._rowwise_scale_inv}, "
...
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
0 → 100644
View file @ 063ef88d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Mixin class holding data specific for NVFP4Tensor"""

from __future__ import annotations
from collections.abc import Iterable
import functools
import math
from typing import Any, Dict, Optional, Tuple, Union
import warnings

import torch

# import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType

from ..quantized_tensor import QuantizedTensorStorage

# from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...utils import _empty_tensor


@functools.lru_cache(maxsize=None)
def _fp4_e2m1_vals(device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Values representable in FP4 E2M1 format"""
    return torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        device=device,
        dtype=dtype,
    )
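Each uint8 packs two FP4 codes, low nibble first, and each code indexes the E2M1 value table. For instance, the byte 0x21 decodes to (0.5, 1.0):

# --- illustrative sketch, not part of the diff ---
import torch

byte = torch.tensor([0x21], dtype=torch.uint8)  # two packed FP4 codes
low, high = int(byte & 0x0F), int(byte >> 4)    # codes 0x1 and 0x2
vals = _fp4_e2m1_vals(byte.device, torch.float32)
assert vals[low].item() == 0.5 and vals[high].item() == 1.0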
class _FromNVFP4Func(torch.autograd.Function):
    """Cast from NVFP4 to other dtype"""

    @staticmethod
    def forward(
        _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
        tensor: NVFP4TensorStorage,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring

        # Dequantize row-wise data
        if tensor._rowwise_data is not None:
            ### TODO(tmoon): Debug dequantize kernel and remove unfused impl
            # return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])

            # Tensor properties
            shape = list(tensor._rowwise_data.size())
            shape[-1] *= 2
            device = tensor._rowwise_data.device

            # Convert FP4E2M1 values to FP32
            data = tensor._rowwise_data.view(torch.uint8).to(torch.int32)
            data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape)
            data = _fp4_e2m1_vals(device, dtype=torch.float32)[data]
            data = data.to(torch.float32).contiguous()

            # Convert FP8E4M3 block scales to FP32
            block_scales = tensor._rowwise_scale_inv
            block_scales = block_scales.reshape(-1, block_scales.size(-1))
            block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16]
            block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32)

            # Convert amax to FP32 tensor scale
            tensor_scale = tensor._amax_rowwise / (6.0 * 448.0)  # Scale by FP4E2M1 and FP8E4M3 max

            # Apply scales
            block_data = data.view(-1, 16)
            block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1)

            return data.to(dtype)
        if tensor._columnwise_data is not None:
            raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!")
        raise ValueError("Attempted to dequantize NVFP4 tensor with no data")

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # pylint: disable=missing-function-docstring
        # Assume that we want gradients in full precision
        return grad, None
class NVFP4TensorStorage(QuantizedTensorStorage):
    """Mixin class that holds data attributes of NVFP4Tensor.

    NVFP4Tensor inherits from the PyTorch tensor class and this mixin
    class. If this class is instantiated directly, it has the same
    data, lower CPU overhead, and less functionality. It should only
    be instantiated directly for performance-critical internal usage.

    """

    _rowwise_data: Optional[torch.Tensor]
    _columnwise_data: Optional[torch.Tensor]
    _quantizer: Optional[Quantizer]
    _rowwise_scale_inv: torch.Tensor
    _columnwise_scale_inv: torch.Tensor
    _fp4_dtype: TE_DType
    _amax_rowwise: torch.Tensor
    _amax_columnwise: torch.Tensor

    def __new__(
        cls,
        rowwise_data: Optional[torch.Tensor],
        rowwise_scale_inv: torch.Tensor,
        columnwise_data: Optional[torch.Tensor],
        columnwise_scale_inv: torch.Tensor,
        amax_rowwise: torch.Tensor,
        amax_columnwise: torch.Tensor,
        fp4_dtype: TE_DType,
        quantizer: Optional[Quantizer],
        *args,
        **kwargs,
    ):
        instance = super().__new__(cls, *args, **kwargs)
        instance._rowwise_data = rowwise_data
        instance._columnwise_data = columnwise_data
        instance._fp4_dtype = fp4_dtype
        instance._quantizer = quantizer.copy() if quantizer is not None else None
        instance._rowwise_scale_inv = rowwise_scale_inv
        instance._columnwise_scale_inv = columnwise_scale_inv
        instance._amax_rowwise = amax_rowwise
        instance._amax_columnwise = amax_columnwise

        return instance
    def clear(self):
        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
        for t in (
            self._rowwise_data,
            self._columnwise_data,
            self._rowwise_scale_inv,
            self._columnwise_scale_inv,
            self._amax_rowwise,
            self._amax_columnwise,
        ):
            if t is not None:
                t.data = _empty_tensor()
    def get_metadata(self) -> Dict[str, Any]:
        """Get this tensor's metadata."""
        return {
            "rowwise_data": self._rowwise_data,
            "rowwise_scale_inv": self._rowwise_scale_inv,
            "columnwise_data": self._columnwise_data,
            "columnwise_scale_inv": self._columnwise_scale_inv,
            "amax_rowwise": self._amax_rowwise,
            "amax_columnwise": self._amax_columnwise,
            "fp4_dtype": self._fp4_dtype,
            "quantizer": self._quantizer,
        }
    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]:
        """Prepare the tensor base for saving for backward"""
        tensors = [
            self._rowwise_data,
            self._columnwise_data,
            self._rowwise_scale_inv,
            self._columnwise_scale_inv,
            self._amax_rowwise,
            self._amax_columnwise,
        ]
        self._rowwise_data = None
        self._columnwise_data = None
        self._rowwise_scale_inv = None
        self._columnwise_scale_inv = None
        self._amax_rowwise = None
        self._amax_columnwise = None
        return tensors, self

    def restore_from_saved(
        self, tensors: list[Optional[torch.Tensor]]
    ) -> list[Optional[torch.Tensor]]:
        """Restore the tensor base data from the saved tensors list."""
        self._rowwise_data = tensors[0]
        self._columnwise_data = tensors[1]
        self._rowwise_scale_inv = tensors[2]
        self._columnwise_scale_inv = tensors[3]
        self._amax_rowwise = tensors[4]
        self._amax_columnwise = tensors[5]
        return tensors[6:]
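These two hooks let autograd's save_for_backward see only plain tensors; the intended round trip, simplified (`nvfp4_storage` is a hypothetical instance):

# --- illustrative sketch, not part of the diff ---
# Detach the six raw buffers for save_for_backward, then reattach them.
tensors, stub = nvfp4_storage.prepare_for_saving()  # fields are now None
# ... autograd stores `tensors` and later hands them back ...
leftover = stub.restore_from_saved(tensors)          # buffers reattached
assert leftover == []                                # all six entries consumed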
    def get_data_tensors(self):
        """Get this Tensor's data."""
        return self._rowwise_data, self._columnwise_data

    def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """Dequantize to a higher precision."""
        return _FromNVFP4Func.forward(None, self, dtype)
    def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]:
        # pylint: disable=missing-function-docstring

        # Infer tensor shape
        shape = None
        if self._rowwise_data is not None:
            byte_shape = list(self._rowwise_data.size())
            shape = byte_shape[:-1] + [byte_shape[-1] * 2]
        elif self._columnwise_data is not None:
            warnings.warn("Attempting to get shape of NVFP4 tensor with only column-wise data.")
            byte_shape = list(self._columnwise_data.size())
            shape = byte_shape[1:-1] + [byte_shape[-1] * 2, byte_shape[0]]
        if shape is None:
            raise RuntimeError("Attempted to get shape of NVFP4 tensor with no data")

        # Return shape or dim
        if dim is None:
            return torch.Size(shape)
        return shape[dim]
    def view(self, shape: torch.Size):
        # pylint: disable=missing-function-docstring

        # Return input tensor if view not needed
        cur_shape = self.size()
        if shape is None or shape == cur_shape:
            return self

        # Canonicalize shape
        if not isinstance(shape, Iterable):
            shape = [shape]
        elif len(shape) == 1 and isinstance(shape[0], Iterable):
            shape = shape[0]
        if -1 in shape:
            shape = list(shape)
            d_inferred = -math.prod(cur_shape) // math.prod(shape)
            for i, d in enumerate(shape):
                if d == -1:
                    shape[i] = d_inferred
                    break
        if shape[-1] != cur_shape[-1]:
            raise RuntimeError(
                "NVFP4Tensor does not support reshaping inner dimension "
                f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})"
            )

        # Reshape data
        new_rowwise_data = None
        new_columnwise_data = None
        if self._rowwise_data is not None:
            if shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent row-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = list(shape[:-1]) + [shape[-1] // 2]
            new_rowwise_data = self._rowwise_data.view(byte_shape)
        if self._columnwise_data is not None:
            columnwise_shape = (shape[-1], math.prod(shape[:-1]))
            if columnwise_shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent column-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
            new_columnwise_data = self._columnwise_data.view(byte_shape)

        # Construct tensor
        return NVFP4TensorStorage(
            rowwise_data=new_rowwise_data,
            rowwise_scale_inv=self._rowwise_scale_inv,
            columnwise_data=new_columnwise_data,
            columnwise_scale_inv=self._columnwise_scale_inv,
            amax_rowwise=self._amax_rowwise,
            amax_columnwise=self._amax_columnwise,
            quantizer=self._quantizer,
            fp4_dtype=self._fp4_dtype,
        )
    def __repr__(self):
        data_rowwise = self.dequantize()
        return (
            "NVFP4TensorStorage("
            f"rowwise_scaled_data={data_rowwise},"
            f"rowwise_scale_inv={self._rowwise_scale_inv},"
            f"amax_rowwise={self._amax_rowwise},"
            f"amax_columnwise={self._amax_columnwise},"
            ")"
        )
    def update_usage(
        self,
        rowwise_usage: Optional[bool] = None,
        columnwise_usage: Optional[bool] = None,
    ):
        """
        For the NVFP4 format, columnwise scaled output is only produced by x2
        scaling kernels, so this function only disables usages.
        """
        # Default usage is based on available data
        if rowwise_usage is None:
            rowwise_usage = self._rowwise_data is not None
        if columnwise_usage is None:
            columnwise_usage = self._columnwise_data is not None

        # Update row-scaled data
        if rowwise_usage:
            if self._rowwise_data is None:
                raise RuntimeError(
                    "Requested row-wise usage, but NVFP4Tensor is missing row-scaled NVFP4 data"
                )
            if self._rowwise_scale_inv is None:
                raise RuntimeError(
                    "Requested row-wise usage, but NVFP4Tensor is missing row-scaled"
                    " scale-inverses"
                )
            if self._amax_rowwise is None:
                raise RuntimeError(
                    "Requested row-wise usage, but NVFP4Tensor is missing per-tensor"
                    " row-scaled scale-inverse"
                )
        else:
            self._rowwise_data = None
            self._rowwise_scale_inv = None
            self._amax_rowwise = None

        # Update column-scaled data
        if columnwise_usage:
            if self._columnwise_data is None:
                raise RuntimeError(
                    "Requested column-wise usage, but NVFP4Tensor is missing column-scaled"
                    " NVFP4 data"
                )
            if self._columnwise_scale_inv is None:
                raise RuntimeError(
                    "Requested column-wise usage, "
                    "but NVFP4Tensor is missing column-scaled scale-inverses"
                )
            if self._amax_columnwise is None:
                raise RuntimeError(
                    "Requested column-wise usage, "
                    "but NVFP4Tensor is missing per-tensor column-scaled scale-inverse"
                )
        else:
            self._columnwise_data = None
            self._columnwise_scale_inv = None
            self._amax_columnwise = None
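Since NVFP4 packs two 4-bit values into each byte, the logical shape reported by size() above is the packed byte shape with the innermost dimension doubled. A minimal sketch of that arithmetic (the shapes here are illustrative, not taken from the code above):

import torch

logical_shape = (32, 64)                                     # hypothetical tensor shape
rowwise_bytes = torch.empty(32, 64 // 2, dtype=torch.uint8)  # packed row-wise data

# Mirror of the row-wise branch of size(): double the last byte dimension.
byte_shape = list(rowwise_bytes.size())
recovered = byte_shape[:-1] + [byte_shape[-1] * 2]
assert tuple(recovered) == logical_shape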
transformer_engine/pytorch/tensor/utils.py
View file @
063ef88d
...
...
@@ -4,12 +4,14 @@
"""Helper functions for using fp8 tensors as weights"""
import os
from typing import Optional, Union

import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv

-from .quantized_tensor import QuantizedTensor
+from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
...
...
@@ -455,3 +457,20 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
        tex.fp8_block_scaling_partial_cast(
            master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype
        )
def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
    """Check if an environment or object is using experimental Kitchen middleware.

    Returns False if x is a torch.Tensor.
    """
    # Detect if the environment is experimental
    if x is None:
        return int(os.getenv("QAT_PARAMS", "0")) > 0

    # Detect if the object is experimental
    if isinstance(x, torch.Tensor):
        return False
    if not isinstance(x, (Quantizer, QuantizedTensorStorage)):
        raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance")
    return hasattr(x, "experimental") and x.experimental
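A short usage sketch of the new is_experimental helper, covering its three dispatch paths (the QAT_PARAMS value is illustrative):

import os
import torch

os.environ["QAT_PARAMS"] = "1"
assert is_experimental() is True                  # x is None: environment check
assert is_experimental(torch.zeros(2)) is False   # plain torch.Tensor is never experimental
# For a Quantizer or QuantizedTensorStorage, the result depends on whether the
# object defines an `experimental` attribute that evaluates to True.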
transformer_engine/pytorch/transformer.py
View file @
063ef88d
...
...
@@ -191,6 +191,17 @@ class TransformerLayer(torch.nn.Module):
and `DotProductAttention` modules.
name: str, default = `None`
name of the module, currently used for debugging purposes.
softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
softmax type as described in this paper:
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
Parallelism parameters
----------------------
...
...
@@ -306,6 +317,7 @@ class TransformerLayer(torch.nn.Module):
        qk_norm_type: Optional[str] = None,
        qk_norm_eps: float = 1e-6,
        qk_norm_before_rope: bool = False,
+        softmax_type: str = "vanilla",
    ) -> None:
        super().__init__()
...
...
@@ -362,6 +374,7 @@ class TransformerLayer(torch.nn.Module):
        self.get_rng_state_tracker = get_rng_state_tracker
        self.attn_input_format = attn_input_format
+        self.softmax_type = softmax_type
        self.name = name
...
...
@@ -397,6 +410,7 @@ class TransformerLayer(torch.nn.Module):
"qkv_format"
:
self
.
attn_input_format
,
"seq_length"
:
seq_length
,
"micro_batch_size"
:
micro_batch_size
,
"softmax_type"
:
self
.
softmax_type
,
}
self
.
self_attention
=
MultiheadAttention
(
...
...
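The softmax_type variants documented above differ only in the denominator of the normalization. A naive, unfused reference implementation (not TransformerEngine's kernel; sink_softmax is a hypothetical helper, and it skips max-subtraction, so it is only numerically safe for small score magnitudes):

import torch
from typing import Optional

def sink_softmax(scores: torch.Tensor, softmax_type: str = "vanilla",
                 alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
    """scores: [b, h, s_q, s_kv]; alpha: learnable per-head sink logits of shape [h]."""
    exp_s = torch.exp(scores)
    denom = exp_s.sum(dim=-1, keepdim=True)
    if softmax_type == "vanilla":
        return exp_s / denom
    if softmax_type == "off-by-one":
        return exp_s / (1.0 + denom)  # implicit sink token with a fixed zero logit
    if softmax_type == "learnable":
        return exp_s / (torch.exp(alpha).view(1, -1, 1, 1) + denom)
    raise ValueError(f"Unknown softmax_type: {softmax_type}")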
transformer_engine/pytorch/triton/pad.py
0 → 100644
View file @
063ef88d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""NVFP4 padding kernels
TODO(ksivamani): Documentation
"""
import torch
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1),
    ],
    key=["out_dim0", "out_dim1"],
)
@triton.jit
def zero_pad_kernel(
    inp_ptr,
    out_ptr,
    in_dim0: tl.constexpr,
    in_dim1: tl.constexpr,
    out_dim0: tl.constexpr,
    out_dim1: tl.constexpr,
    in_s0,
    in_s1,
    out_s0,
    out_s1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
"""Pads a tensor assuming it's a columnwise scaling inverse."""
# tile over OUTPUT coordinates
pid_m
=
tl
.
program_id
(
0
)
pid_n
=
tl
.
program_id
(
1
)
offs_m
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
# output rows
offs_n
=
pid_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
# output cols
om
=
offs_m
[:,
None
]
on
=
offs_n
[
None
,
:]
# edge masking for output
out_mask
=
(
om
<
out_dim0
)
&
(
on
<
out_dim1
)
# valid input region is simply top-left (no offsets)
in_mask
=
(
om
<
in_dim0
)
&
(
on
<
in_dim1
)
# load valid input, else zero (masked load touches memory only where True)
x
=
tl
.
load
(
inp_ptr
+
om
*
in_s0
+
on
*
in_s1
,
mask
=
in_mask
,
other
=
0
)
# store to output (only within bounds of the output tile)
tl
.
store
(
out_ptr
+
om
*
out_s0
+
on
*
out_s1
,
x
,
mask
=
out_mask
)
def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor:
    """Pads a tensor assuming it's a columnwise scaling inverse."""
    assert inp.ndim == 2
    dim0, dim1 = inp.shape
    pad_x = (128 - dim0 % 128) % 128
    pad_y = (4 - dim1 % 4) % 4
    out_x = dim0 + pad_x
    out_y = dim1 + pad_y

    out = torch.empty((out_x, out_y), device=inp.device, dtype=inp.dtype)

    in_s0, in_s1 = inp.stride()
    out_s0, out_s1 = out.stride()

    BLOCK_M, BLOCK_N = 128, 128
    grid = (triton.cdiv(out_x, BLOCK_M), triton.cdiv(out_y, BLOCK_N))
    zero_pad_kernel[grid](
        inp,
        out,
        dim0,
        dim1,
        out_x,
        out_y,
        in_s0,
        in_s1,
        out_s0,
        out_s1,
    )
    return out
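A usage sketch of pad_columnwise_scale_inv, checked against a plain torch.nn.functional.pad reference (the input shape is arbitrary; a CUDA device is required since the kernel is Triton):

import torch
import torch.nn.functional as F

x = torch.randn(130, 6, device="cuda")
padded = pad_columnwise_scale_inv(x)
assert padded.shape == (256, 8)  # rows padded to a multiple of 128, cols to a multiple of 4

# Reference: zero-pad the last dim by 2 and the first dim by 126.
ref = F.pad(x, (0, 2, 0, 126))
torch.testing.assert_close(padded, ref)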
transformer_engine/pytorch/triton/permutation.py
View file @
063ef88d
...
...
@@ -324,7 +324,8 @@ def _permute_kernel(
    pid_h = tl.program_id(1)
    cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = cur_off < hidden_size
-    input_off = pid_t * stride_input_token + cur_off * stride_input_hidden
+    src_row = pid_t.to(tl.int64)
+    input_off = src_row * stride_input_token + cur_off * stride_input_hidden
    inp = tl.load(input_ptr + input_off, mask=mask)
    if PERMUTE_SCALE:
        mask_scale = cur_off < scale_hidden_dim
...
...
@@ -338,7 +339,7 @@ def _permute_kernel(
    for idx in tl.range(n_routed):
        dst_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
-        )
+        ).to(tl.int64)
        output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
        if PERMUTE_SCALE:
            permuted_scale_off = (
...
...
@@ -519,7 +520,7 @@ def _unpermute_kernel(
    for idx in tl.range(n_routed):
        src_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
-        )
+        ).to(tl.int64)
        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
        inp = tl.load(input_ptr + input_off, mask=mask)
        inp = inp.to(compute_type)
...
...
@@ -550,7 +551,8 @@ def _unpermute_kernel(
            prob = tl.load(permuted_probs_ptr + permuted_prob_off)
            tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
    accumulator = accumulator.to(data_type)
-    output_off = pid_t * stride_output_token + current_offset * stride_output_hidden
+    dst_row = pid_t.to(tl.int64)
+    output_off = dst_row * stride_output_token + current_offset * stride_output_hidden
    tl.store(output_ptr + output_off, accumulator, mask=mask)
...
...
@@ -681,7 +683,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
    for idx in tl.range(n_routed):
        dst_row = tl.load(
            row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
-        )
+        ).to(tl.int64)
        expert_idx = tl.load(
            row_id_map_ptr + pid * stride_row_id_map_token
...
@@ -692,8 +694,10 @@ def _unpermute_bwd_with_merging_probs_kernel(
        while current_start < hidden_size:
            current_offset = current_start + tl.arange(0, BLOCK_SIZE)
            mask = current_offset < hidden_size
+            src_row = pid.to(tl.int64)
            input_off = (
-                pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden
+                src_row * stride_fwd_output_grad_token
+                + current_offset * stride_fwd_output_grad_hidden
            )
            inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
            inp = inp.to(compute_type)
...
...
@@ -902,11 +906,11 @@ def _sort_chunks_by_map_kernel(
    pid_t = tl.program_id(0)
    pid_h = tl.program_id(1)
    if FORWARD:
-        src_row = pid_t
-        dst_row = tl.load(row_id_map_ptr + pid_t)
+        src_row = pid_t.to(tl.int64)
+        dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
    else:
-        src_row = tl.load(row_id_map_ptr + pid_t)
-        dst_row = pid_t
+        src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
+        dst_row = pid_t.to(tl.int64)
    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = current_offset < hidden_size
    input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
...
...
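The common thread in the permutation.py hunks above is casting row indices to tl.int64 before multiplying by a stride: with 32-bit indices, the flattened offset wraps around once it passes 2**31 - 1. A back-of-the-envelope check (the sizes are illustrative):

rows, hidden = 300_000, 8_192   # plausible MoE activation buffer
offset = rows * hidden          # 2_457_600_000
assert offset > 2**31 - 1       # would overflow a signed 32-bit offset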
transformer_engine/pytorch/utils.py
View file @
063ef88d
...
...
@@ -10,10 +10,14 @@ import os
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import transformer_engine.pytorch.cpp_extensions as ext

from . import torch_version
from .tensor.quantized_tensor import Quantizer
from torch.utils.cpp_extension import IS_HIP_EXTENSION

__all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"]
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
    """Check if any of the given tensors require gradient."""
    for tensor in tensors:
...
...
@@ -182,7 +186,7 @@ def combine_tensors(
    num_tensors = len(tensors)
    new_shape = list(tensors[0].shape)
    new_shape.insert(dim, num_tensors)
-    from transformer_engine.pytorch.float8_tensor import Float8Tensor
+    from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor

    if isinstance(tensors[0], Float8Tensor):
        new_stride = list(tensors[0]._data.stride())
...
...
@@ -222,14 +226,16 @@ class SplitAlongDim(torch.autograd.Function):
        # pylint: disable=missing-function-docstring
        ctx.split_dim = split_dim
        ctx.split_size_or_sections = split_size_or_sections
-        from transformer_engine.pytorch.float8_tensor import Float8Tensor
-        from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
+        from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
+        from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import (
+            Float8TensorStorage,
+        )

-        if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance(
+        if isinstance(mixed_x_layer, Float8TensorStorage) and not isinstance(
            mixed_x_layer, Float8Tensor
        ):
            return tuple(
-                Float8TensorBase(
+                Float8TensorStorage(
                    fp8_scale_inv=mixed_x_layer._scale_inv,
                    fp8_dtype=mixed_x_layer._fp8_dtype,
                    data=x.squeeze(split_dim) if squeeze else x,
...
...
@@ -274,7 +280,7 @@ class SplitAlongDim(torch.autograd.Function):
        split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
        dims = len(grad_outputs[0].shape)
        split_dim = (ctx.split_dim + dims) % dims
-        from transformer_engine.pytorch.float8_tensor import Float8Tensor
+        from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor

        if isinstance(grad_outputs[0], Float8Tensor):
            noop_ok = True
...
...
@@ -454,14 +460,23 @@ if IS_HIP_EXTENSION:
        import re

        return re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None

+def assert_dim_for_all_gather(
+    tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer
+) -> None:
+    """Assert that tensor dimensions are supported for all-gather"""
+    if with_all_gather:
+        assert quantizer.is_quantizable(tensor), (
+            "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__
+        )

-def is_bf16_compatible() -> None:
+def is_bf16_compatible() -> bool:
    """Replaces torch.cuda.is_bf16_compatible() with an explicit
    check on device compute capability to enforce sm_80 or higher.
    """
    if IS_HIP_EXTENSION:
        # only MI200 and MI300 machines support bf16
-        if get_device_compute_capability() == (9, 4) or is_mi200() or is_K100_AI() or is_BW():
+        if get_device_compute_capability() >= (9, 4) or is_mi200() or is_K100_AI() or is_BW():
            return True
        else:
            return False
...
...
@@ -469,6 +484,29 @@ def is_bf16_compatible() -> None:
    return torch.cuda.get_device_capability()[0] >= 8

def is_bf16_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
    """
    Determine whether bfloat16 (BF16) computation is supported on the current device.

    Parameters
    ----------
    return_reason : bool, optional
        If ``False`` (default), return only a boolean indicating BF16 availability.
        If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides
        a human-readable explanation when BF16 is not available. When BF16 is available,
        the reason will be an empty string.
    """
    available = is_bf16_compatible()
    if not return_reason:
        return available
    reason = (
        "" if available else "BF16 support requires a GPU with compute capability 8.0 or higher."
    )
    return available, reason
@functools.lru_cache(maxsize=None)
def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
    """Checks whether the device supports
...
@@ -486,6 +524,8 @@ def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
@
functools
.
lru_cache
(
maxsize
=
None
)
def
get_cudnn_version
()
->
Tuple
[
int
,
int
,
int
]:
"""Runtime cuDNN version (major, minor, patch)"""
import
transformer_engine.pytorch.cpp_extensions
as
ext
# ROCm fused attn does not use cudnn, return high numbers to avoid tests filtering out
if
IS_HIP_EXTENSION
:
return
(
99
,
0
,
0
)
...
...
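A short usage sketch for the new is_bf16_available helper in utils.py (the fallback choice here is illustrative):

import torch
from transformer_engine.pytorch.utils import is_bf16_available

ok, why = is_bf16_available(return_reason=True)
dtype = torch.bfloat16 if ok else torch.float16  # illustrative fallback
if not ok:
    print(f"BF16 unavailable: {why}")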