Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
640
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
793 additions
and
453 deletions
+793
-453
transformer_engine/pytorch/optimizers/multi_tensor_apply.py
transformer_engine/pytorch/optimizers/multi_tensor_apply.py
+1
-1
transformer_engine/pytorch/permutation.py
transformer_engine/pytorch/permutation.py
+115
-37
transformer_engine/pytorch/pyproject.toml
transformer_engine/pytorch/pyproject.toml
+10
-0
transformer_engine/pytorch/quantization.py
transformer_engine/pytorch/quantization.py
+21
-12
transformer_engine/pytorch/quantized_tensor.py
transformer_engine/pytorch/quantized_tensor.py
+21
-25
transformer_engine/pytorch/router.py
transformer_engine/pytorch/router.py
+23
-23
transformer_engine/pytorch/setup.py
transformer_engine/pytorch/setup.py
+18
-6
transformer_engine/pytorch/tensor/__init__.py
transformer_engine/pytorch/tensor/__init__.py
+1
-1
transformer_engine/pytorch/tensor/_quantization_helpers.py
transformer_engine/pytorch/tensor/_quantization_helpers.py
+1
-1
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+95
-144
transformer_engine/pytorch/tensor/float8_tensor.py
transformer_engine/pytorch/tensor/float8_tensor.py
+93
-29
transformer_engine/pytorch/tensor/mxfp8_tensor.py
transformer_engine/pytorch/tensor/mxfp8_tensor.py
+125
-82
transformer_engine/pytorch/tensor/nvfp4_tensor.py
transformer_engine/pytorch/tensor/nvfp4_tensor.py
+60
-25
transformer_engine/pytorch/tensor/storage/__init__.py
transformer_engine/pytorch/tensor/storage/__init__.py
+1
-1
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py
...pytorch/tensor/storage/float8_blockwise_tensor_storage.py
+11
-51
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
...er_engine/pytorch/tensor/storage/float8_tensor_storage.py
+1
-1
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
...mer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
+20
-6
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
...mer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
+21
-3
transformer_engine/pytorch/tensor/utils.py
transformer_engine/pytorch/tensor/utils.py
+140
-5
transformer_engine/pytorch/torch_version.py
transformer_engine/pytorch/torch_version.py
+15
-0
No files found.
Too many changes to show.
To preserve performance only
640 of 640+
files are displayed.
Plain diff
Email patch
transformer_engine/pytorch/optimizers/multi_tensor_apply.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/pytorch/permutation.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MoE Permutaion API"""
"""MoE Permuta
t
ion API"""
import
warnings
from
typing
import
Optional
,
Tuple
import
torch
...
...
@@ -191,6 +191,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
routing_map
:
torch
.
Tensor
,
num_out_tokens
:
int
,
probs
:
torch
.
Tensor
,
pad_offsets
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# pylint: disable=missing-function-docstring
if
not
inp
.
numel
():
...
...
@@ -201,6 +202,8 @@ class _moe_permute_mask_map(torch.autograd.Function):
assert
routing_map
.
is_cuda
,
"TransformerEngine needs CUDA."
if
probs
is
not
None
:
assert
probs
.
is_cuda
,
"TransformerEngine needs CUDA."
if
pad_offsets
is
not
None
:
assert
pad_offsets
.
is_cuda
,
"TransformerEngine needs CUDA."
assert
inp
.
size
(
0
)
==
routing_map
.
size
(
0
),
"Permute not possible"
num_tokens
,
hidden_size
=
inp
.
size
()
...
...
@@ -250,6 +253,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map
,
probs
,
fp8_scale
,
pad_offsets
,
num_tokens
,
num_experts
,
num_out_tokens
,
...
...
@@ -290,9 +294,10 @@ class _moe_permute_mask_map(torch.autograd.Function):
columnwise_scale_inv
=
None
,
quantizer
=
None
,
requires_grad
=
output
.
requires_grad
,
with_gemm_swizzled_scales
=
False
,
)
ctx
.
save_for_backward
(
row_id_map
)
ctx
.
save_for_backward
(
row_id_map
,
pad_offsets
)
ctx
.
num_experts
=
num_experts
ctx
.
num_tokens
=
num_tokens
ctx
.
hidden_size
=
hidden_size
...
...
@@ -307,12 +312,12 @@ class _moe_permute_mask_map(torch.autograd.Function):
)
->
Tuple
[
torch
.
Tensor
,
...]:
# pylint: disable=missing-function-docstring
if
not
permuted_act_grad
.
numel
():
return
permuted_act_grad
,
None
,
None
,
ctx
.
probs
return
permuted_act_grad
,
None
,
None
,
ctx
.
probs
,
None
act_grad
=
None
probs_grad
=
None
if
ctx
.
needs_input_grad
[
0
]:
(
row_id_map
,
)
=
ctx
.
saved_tensors
row_id_map
,
pad_offsets
=
ctx
.
saved_tensors
assert
not
isinstance
(
permuted_act_grad
,
QuantizedTensor
),
"The backward of moe_permute does not support FP8."
...
...
@@ -321,13 +326,14 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map
,
None
,
permuted_probs_grad
,
pad_offsets
,
ctx
.
num_tokens
,
ctx
.
num_experts
,
ctx
.
hidden_size
,
)
if
not
ctx
.
needs_input_grad
[
3
]:
probs_grad
=
None
return
act_grad
,
None
,
None
,
probs_grad
return
act_grad
,
None
,
None
,
probs_grad
,
None
class
_moe_unpermute_mask_map
(
torch
.
autograd
.
Function
):
...
...
@@ -340,6 +346,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map
:
torch
.
Tensor
,
merging_probs
:
Optional
[
torch
.
Tensor
],
restore_shape
:
Optional
[
torch
.
Size
],
pad_offsets
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# pylint: disable=missing-function-docstring
if
not
inp
.
numel
():
...
...
@@ -358,6 +365,8 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
# Device check
assert
inp
.
is_cuda
,
"TransformerEngine needs CUDA."
assert
row_id_map
.
is_cuda
,
"TransformerEngine needs CUDA."
if
pad_offsets
is
not
None
:
assert
pad_offsets
.
is_cuda
,
"TransformerEngine needs CUDA."
assert
not
isinstance
(
inp
,
QuantizedTensor
...
...
@@ -367,15 +376,16 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map
,
merging_probs
,
None
,
pad_offsets
,
num_tokens
,
num_experts
,
hidden_size
,
)
if
with_probs
:
ctx
.
save_for_backward
(
inp
,
row_id_map
,
merging_probs
)
ctx
.
save_for_backward
(
inp
,
row_id_map
,
merging_probs
,
pad_offsets
)
else
:
ctx
.
save_for_backward
(
row_id_map
)
ctx
.
save_for_backward
(
row_id_map
,
pad_offsets
)
ctx
.
num_experts
=
num_experts
ctx
.
num_tokens
=
num_tokens
ctx
.
num_permuted_tokens
=
inp
.
size
(
0
)
...
...
@@ -387,15 +397,15 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
def
backward
(
ctx
,
unpermuted_act_grad
):
# pylint: disable=missing-function-docstring
if
not
unpermuted_act_grad
.
numel
():
return
unpermuted_act_grad
,
None
,
ctx
.
merging_probs
,
None
return
unpermuted_act_grad
,
None
,
ctx
.
merging_probs
,
None
,
None
act_grad
=
None
probs_grad
=
None
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
with_probs
:
fwd_input
,
row_id_map
,
merging_probs
=
ctx
.
saved_tensors
fwd_input
,
row_id_map
,
merging_probs
,
pad_offsets
=
ctx
.
saved_tensors
else
:
(
row_id_map
,
)
=
ctx
.
saved_tensors
row_id_map
,
pad_offsets
=
ctx
.
saved_tensors
fp8
=
isinstance
(
unpermuted_act_grad
,
QuantizedTensor
)
per_tensor_recipe
=
isinstance
(
unpermuted_act_grad
,
Float8Tensor
)
...
...
@@ -441,6 +451,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map
,
fwd_input
,
merging_probs
,
pad_offsets
,
ctx
.
num_tokens
,
ctx
.
num_experts
,
ctx
.
num_permuted_tokens
,
...
...
@@ -453,6 +464,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map
,
None
,
fp8_scale
,
pad_offsets
,
ctx
.
num_tokens
,
ctx
.
num_experts
,
ctx
.
num_permuted_tokens
,
...
...
@@ -493,11 +505,12 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
columnwise_scale_inv
=
None
,
quantizer
=
None
,
requires_grad
=
act_grad
.
requires_grad
,
with_gemm_swizzled_scales
=
False
,
)
if
not
ctx
.
needs_input_grad
[
2
]:
probs_grad
=
None
return
act_grad
,
None
,
probs_grad
,
None
return
act_grad
,
None
,
probs_grad
,
None
,
None
def
moe_permute
(
...
...
@@ -514,22 +527,22 @@ def moe_permute(
Parameters
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
routing_map: torch.Tensor
routing_map
: torch.Tensor
The token to expert mapping tensor.
If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
The values in it: 1 means the token is routed to this expert and 0 means not.
If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
The values in it are the routed expert indices.
num_out_tokens: int, default = -1
num_out_tokens
: int, default = -1
The effective output token count, representing the number of tokens not dropped.
By default, set to '-1', meaning no tokens are dropped.
max_token_num: int, default = -1
max_token_num
: int, default = -1
The maximum number of tokens, used for workspace allocation.
By default, set to '-1', meaning the calculation of the size of workspace is
automatically taken over by the operator.
map_type: str, default = 'mask'
map_type
: str, default = 'mask'
Type of the routing map tensor.
Options are: 'mask', 'index'.
Refer to `routing_map` for more details.
...
...
@@ -537,7 +550,9 @@ def moe_permute(
if
map_type
==
"index"
:
return
_moe_permute_index_map
.
apply
(
inp
,
routing_map
,
num_out_tokens
,
max_token_num
)
if
map_type
==
"mask"
:
output
,
row_id_map
,
_
=
_moe_permute_mask_map
.
apply
(
inp
,
routing_map
,
num_out_tokens
,
None
)
output
,
row_id_map
,
_
=
_moe_permute_mask_map
.
apply
(
inp
,
routing_map
,
num_out_tokens
,
None
,
None
)
return
output
,
row_id_map
raise
ValueError
(
"map_type should be one of 'mask' or 'index'"
)
...
...
@@ -556,25 +571,81 @@ def moe_permute_with_probs(
Parameters
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
probs: torch.Tensor
probs
: torch.Tensor
The tensor of probabilities corresponding to the permuted tokens and is
of shape [num_tokens, num_experts]. It will be permuted with the tokens
according to the routing_map.
routing_map: torch.Tensor
routing_map
: torch.Tensor
The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
The values in it: 1 means the token is routed to this expert and 0 means not.
num_out_tokens: int, default = -1
num_out_tokens
: int, default = -1
The effective output token count, representing the number of tokens not dropped.
By default, set to '-1', meaning no tokens are dropped.
"""
output
,
row_id_map
,
permuted_probs
=
_moe_permute_mask_map
.
apply
(
inp
,
routing_map
,
num_out_tokens
,
probs
inp
,
routing_map
,
num_out_tokens
,
probs
,
None
)
return
output
,
permuted_probs
,
row_id_map
def
moe_permute_and_pad_with_probs
(
inp
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
routing_map
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
,
align_size
:
int
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
]:
"""
Permute the tokens and probs based on the routing_map.
Token with the same index will be grouped together.
Tokens with the same designated expert will be grouped together.
The routing_map indicates which experts were selected by each token.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
probs: torch.Tensor
The tensor of probabilities corresponding to the permuted tokens and is
of shape [num_tokens, num_experts]. It will be permuted with the tokens
according to the routing_map.
routing_map: torch.Tensor
The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
The values in it: 1 means the token is routed to this expert and 0 means not.
tokens_per_expert : torch.Tensor
Tensor of shape `[num_experts]` containing actual token counts per expert.
align_size : int
the alignment size for the input tensor.
"""
assert
(
tokens_per_expert
is
not
None
),
"tokens_per_expert must be provided to the fused permute padding function."
assert
align_size
>
0
,
f
"align_size must be positive, got
{
align_size
}
"
# Ensure tokens_per_expert is on the same device as input to avoid device transfers
if
tokens_per_expert
.
device
!=
inp
.
device
:
tokens_per_expert
=
tokens_per_expert
.
to
(
inp
.
device
)
# Calculate aligned token counts per expert
target_tokens_per_expert
=
(
torch
.
ceil
(
tokens_per_expert
/
align_size
)
*
align_size
).
long
()
if
torch
.
equal
(
tokens_per_expert
,
target_tokens_per_expert
):
pad_offsets
=
None
else
:
pad_lengths
=
target_tokens_per_expert
-
tokens_per_expert
cum_pad
=
torch
.
cumsum
(
pad_lengths
,
dim
=
0
)
pad_offsets
=
torch
.
cat
(
[
torch
.
zeros
(
1
,
dtype
=
cum_pad
.
dtype
,
device
=
inp
.
device
),
cum_pad
[:
-
1
]]
)
output
,
row_id_map
,
permuted_probs
=
_moe_permute_mask_map
.
apply
(
inp
,
routing_map
,
target_tokens_per_expert
.
sum
().
item
(),
probs
,
pad_offsets
)
return
output
,
permuted_probs
,
row_id_map
,
pad_offsets
,
target_tokens_per_expert
def
moe_unpermute
(
inp
:
torch
.
Tensor
,
row_id_map
:
torch
.
Tensor
,
...
...
@@ -582,6 +653,7 @@ def moe_unpermute(
restore_shape
:
Optional
[
torch
.
Size
]
=
None
,
map_type
:
str
=
"mask"
,
probs
:
Optional
[
torch
.
Tensor
]
=
None
,
pad_offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
...
...
@@ -589,22 +661,26 @@ def moe_unpermute(
Parameters
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
row_id_map: torch.Tensor
row_id_map
: torch.Tensor
The tensor of a mapping table for sorted indices used to unpermute the tokens,
which is the second output tensor of `Permute`.
merging_probs: torch.Tensor, default = None
merging_probs
: torch.Tensor, default = None
The tensor of probabilities corresponding to the permuted tokens. If provided,
the unpermuted tokens will be merged with their respective probabilities.
By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
restore_shape: torch.Size, default = None
restore_shape
: torch.Size, default = None
The output shape after the unpermute operation.
map_type: str, default = 'mask'
map_type
: str, default = 'mask'
Type of the routing map tensor. Should be the same as the value passed to moe_permute.
Options are: 'mask', 'index'.
probs: torch.Tensor, default = None
probs
: torch.Tensor, default = None
Renamed to merging_probs. Keep for backward compatibility.
pad_offsets : torch.Tensor, default = None
Tensor of per-expert cumulative padding offsets used to remove padding added
during permutation. This is the fourth output of `moe_permute_and_pad_with_probs`
and is required when unpermuting padded outputs.
"""
if
probs
is
not
None
:
if
merging_probs
is
not
None
:
...
...
@@ -616,7 +692,9 @@ def moe_unpermute(
if
map_type
==
"index"
:
return
_moe_unpermute_index_map
.
apply
(
inp
,
row_id_map
,
merging_probs
)
if
map_type
==
"mask"
:
return
_moe_unpermute_mask_map
.
apply
(
inp
,
row_id_map
,
merging_probs
,
restore_shape
)
return
_moe_unpermute_mask_map
.
apply
(
inp
,
row_id_map
,
merging_probs
,
restore_shape
,
pad_offsets
)
raise
ValueError
(
"map_type should be one of 'mask' or 'index'"
)
...
...
@@ -733,11 +811,11 @@ def moe_sort_chunks_by_index(
Parameters
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
split_sizes: torch.Tensor
split_sizes
: torch.Tensor
Chunk sizes of the inp tensor along the 0-th dimension.
sorted_indices: torch.Tensor
sorted_indices
: torch.Tensor
Chunk indices used to permute the chunks.
"""
output
,
_
=
_moe_chunk_sort
.
apply
(
inp
,
split_sizes
,
sorted_index
,
None
)
...
...
@@ -757,15 +835,15 @@ def moe_sort_chunks_by_index_with_probs(
Parameters
----------
inp: torch.Tensor
inp
: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
probs: torch.Tensor
probs
: torch.Tensor
The tensor of probabilities corresponding to the permuted tokens and is
of shape [num_tokens]. It will be permuted with the tokens according to
the split_sizes and sorted_indices.
split_sizes: torch.Tensor
split_sizes
: torch.Tensor
Chunk sizes of the inp tensor along the 0-th dimension.
sorted_indices: torch.Tensor
sorted_indices
: torch.Tensor
Chunk indices used to permute the chunks.
"""
output
,
permuted_probs
=
_moe_chunk_sort
.
apply
(
inp
,
split_sizes
,
sorted_index
,
probs
)
...
...
transformer_engine/pytorch/pyproject.toml
0 → 100755
View file @
0d874a4e
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires
=
[
"setuptools>=61.0"
,
"pip"
,
"torch>=2.1"
]
# Use legacy backend to import local packages in setup.py
build-backend
=
"setuptools.build_meta:__legacy__"
transformer_engine/pytorch/quantization.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -26,7 +26,6 @@ from transformer_engine.common.recipe import (
NVFP4BlockScaling
,
CustomRecipe
,
)
from
.constants
import
dist_group_type
from
.utils
import
(
get_device_compute_capability
,
is_gfx928
,
is_gfx936
,
is_gfx938
)
from
.jit
import
jit_fuser
...
...
@@ -43,6 +42,7 @@ __all__ = [
"is_fp8_block_scaling_available"
,
"is_nvfp4_available"
,
"get_default_recipe"
,
"get_align_size_for_quantization"
,
]
@
functools
.
lru_cache
(
maxsize
=
None
)
...
...
@@ -131,6 +131,15 @@ def get_default_recipe() -> Recipe:
return
get_default_fp8_recipe
()
def
get_align_size_for_quantization
(
recipe
:
Recipe
)
->
int
:
"""Get the alignment size for quantization."""
if
recipe
.
mxfp8
():
return
32
if
recipe
.
nvfp4
():
return
128
return
16
def
get_fp8_torch_dtype
(
fp8_recipe
:
Recipe
,
fprop_tensor
:
bool
=
True
)
->
torch
.
dtype
:
"""Get fp8 data type according to recipe and tensor"""
if
fp8_recipe
.
fp8_format
==
Format
.
E4M3
or
(
...
...
@@ -685,7 +694,7 @@ def fp8_model_init(
.. warning::
fp8_model_init is deprecated and will be removed in a future release. Use
quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...) instead.
``
quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...)
``
instead.
"""
...
...
@@ -730,7 +739,7 @@ def quantized_model_init(
Parameters
----------
enabled: bool, default =
`
True
`
enabled
: bool, default = True
when enabled, Transformer Engine modules created inside this `quantized_model_init`
region will hold only quantized copies of its parameters, as opposed to the default
behavior where both higher precision and quantized copies are present. Setting this
...
...
@@ -741,9 +750,9 @@ def quantized_model_init(
precision copies of weights are already present in the optimizer.
* inference, where only the quantized copies of the parameters are used.
* LoRA-like fine-tuning, where the main parameters of the model do not change.
recipe: transformer_engine.common.recipe.Recipe, default =
`
None
`
recipe
: transformer_engine.common.recipe.Recipe, default = None
Recipe used to create the parameters. If left to None, it uses the default recipe.
preserve_high_precision_init_val: bool, default =
`
False
`
preserve_high_precision_init_val
: bool, default = False
when enabled, store the high precision tensor used to initialize quantized parameters
in CPU memory, and add two function attributes named `get_high_precision_init_val()`
and `clear_high_precision_init_val()` to quantized parameters to get/clear this high
...
...
@@ -780,8 +789,8 @@ def fp8_autocast(
"""
.. warning::
fp8_autocast is deprecated and will be removed in a future release.
Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.
``
fp8_autocast
``
is deprecated and will be removed in a future release.
Use
``
autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...)
``
instead.
"""
...
...
@@ -835,16 +844,16 @@ def autocast(
Parameters
----------
enabled: bool, default =
`
True
`
enabled
: bool, default = True
whether or not to enable low precision quantization (FP8/FP4).
calibrating: bool, default =
`
False
`
calibrating
: bool, default = False
calibration mode allows collecting statistics such as amax and scale
data of quantized tensors even when executing without quantization enabled.
This is useful for saving an inference ready checkpoint while training
using a higher precision.
recipe: recipe.Recipe, default =
`
None
`
recipe
: recipe.Recipe, default = None
recipe used for low precision quantization.
amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default =
`
None
`
amax_reduction_group
: torch._C._distributed_c10d.ProcessGroup, default = None
distributed group over which amaxes for the quantized tensors
are reduced at the end of each training step.
"""
...
...
transformer_engine/pytorch/quantized_tensor.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -7,7 +7,6 @@
from
__future__
import
annotations
from
typing
import
Optional
,
Tuple
,
Iterable
,
Any
,
Dict
,
Union
import
abc
import
copy
import
warnings
import
math
...
...
@@ -21,14 +20,9 @@ from transformer_engine.pytorch.tensor._quantization_helpers import (
_stride_from_shape
,
)
_quantized_tensor_cpu_supported_ops
=
(
torch
.
ops
.
aten
.
empty_like
.
default
,
torch
.
ops
.
aten
.
copy_
.
default
,
)
class
QuantizedTensorStorage
:
r
"""Base class for all
*
TensorStorage classes.
r
"""Base class for all TensorStorage classes.
This class (and its subclasses) are optimization for when
the full QuantizedTensor is not needed (when it is fully
...
...
@@ -55,11 +49,11 @@ class QuantizedTensorStorage:
Parameters
----------
rowwise_usage : Optional[bool[, default =
`
None
`
rowwise_usage : Optional[bool[, default = None
Whether to create or keep the data needed for using the tensor
in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it as `None`
preserves the original value in the tensor.
columnwise_usage : Optional[bool], default =
`
None
`
columnwise_usage : Optional[bool], default = None
Whether to create or keep the data needed for using the tensor
in columnwise fashion (e.g. as A argument in TN GEMM). Leaving it as
`None` preserves the original value in the tensor.
...
...
@@ -129,7 +123,7 @@ def prepare_for_saving(
]:
"""Prepare tensors for saving. Needed because save_for_backward accepts only
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
the internal
*
TensorStorage types too."""
the internal TensorStorage types too."""
tensor_list
,
tensor_objects_list
=
[],
[]
for
tensor
in
tensors
:
...
...
@@ -205,10 +199,21 @@ class Quantizer(abc.ABC):
"""
internal
:
bool
"""Whether to solely optimize for matrix multiplication
The resulting quantized tensors are not guaranteed to support any
operation other than matrix multiplication. Use with care since
this is likely to break communication, checkpointing, and many
other features.
"""
optimize_for_gemm
:
bool
def
__init__
(
self
,
*
,
rowwise
:
bool
,
columnwise
:
bool
)
->
None
:
self
.
rowwise_usage
=
rowwise
self
.
columnwise_usage
=
columnwise
self
.
internal
=
False
self
.
optimize_for_gemm
=
False
def
__repr__
(
self
):
return
(
...
...
@@ -297,10 +302,6 @@ class Quantizer(abc.ABC):
if
columnwise
is
not
None
:
self
.
columnwise_usage
=
columnwise
def
copy
(
self
)
->
Quantizer
:
"""Create shallow copy"""
return
copy
.
copy
(
self
)
def
onnx_quantize
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
"""Symbolic function for ONNX export"""
raise
NotImplementedError
(
...
...
@@ -324,7 +325,11 @@ class Quantizer(abc.ABC):
return
False
def
is_quantizable
(
self
,
inp
:
torch
.
Tensor
)
->
bool
:
# pylint: disable=unused-argument
"""Returns whether or not given tensor can be quantized"""
"""Whether tensor supports quantized all-gather
Consider a less misleading function name.
"""
return
True
def
get_usages
(
self
)
->
Dict
[
str
,
bool
]:
...
...
@@ -544,15 +549,6 @@ class QuantizedTensor(torch.Tensor):
if
kwargs
is
None
:
kwargs
=
{}
def
check_if_cpu
(
arg
):
if
isinstance
(
cls
,
QuantizedTensor
)
and
arg
.
device
.
type
==
"cpu"
:
assert
(
func
in
_quantized_tensor_cpu_supported_ops
),
f
"QuantizedTensor on CPU does not support this operation:
{
func
}
"
return
arg
args
=
tree_map
(
check_if_cpu
,
args
)
# Do not force the QuantizedTensor type on the returned tensor
return
torch
.
_C
.
_disabled_torch_function_impl
(
func
,
types
,
args
,
kwargs
)
...
...
transformer_engine/pytorch/router.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
...
...
@@ -92,24 +92,24 @@ def fused_topk_with_score_function(
Fused topk with score function router.
Parameters
----------
logits: torch.Tensor
topk: int
use_pre_softmax: bool
logits
: torch.Tensor
topk
: int
use_pre_softmax
: bool
if enabled, the computation order: softmax -> topk
num_groups: int
num_groups
: int
used in the group topk
group_topk: int
group_topk
: int
used in the group topk
scaling_factor: float
score_function: str
scaling_factor
: float
score_function
: str
currently only support softmax and sigmoid
expert_bias: torch.Tensor
expert_bias
: torch.Tensor
could be used in the sigmoid
Returns
-------
probs: torch.Tensor
routing_map: torch.Tensor
probs
: torch.Tensor
routing_map
: torch.Tensor
"""
if
logits
.
dtype
==
torch
.
float64
:
raise
ValueError
(
"Current TE does not support float64 router type"
)
...
...
@@ -186,15 +186,15 @@ def fused_compute_score_for_moe_aux_loss(
Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function.
Parameters
----------
logits: torch.Tensor
topk: int
score_function: str
logits
: torch.Tensor
topk
: int
score_function
: str
currently only support softmax and sigmoid
Returns
-------
routing_map: torch.Tensor
scores: torch.Tensor
routing_map
: torch.Tensor
scores
: torch.Tensor
"""
return
FusedComputeScoresForMoEAuxLoss
.
apply
(
logits
,
topk
,
score_function
)
...
...
@@ -258,18 +258,18 @@ def fused_moe_aux_loss(
Fused MoE aux loss.
Parameters
----------
probs: torch.Tensor
tokens_per_expert: torch.Tensor
probs
: torch.Tensor
tokens_per_expert
: torch.Tensor
the number of tokens per expert
total_num_tokens: int
total_num_tokens
: int
the total number of tokens, involved in the aux loss calculation
num_experts: int
topk: int
coeff: float
num_experts
: int
topk
: int
coeff
: float
the coefficient of the aux loss
Returns
-------
aux_loss: torch.scalar
aux_loss
: torch.scalar
"""
return
FusedAuxLoss
.
apply
(
probs
,
tokens_per_expert
,
total_num_tokens
,
num_experts
,
topk
,
coeff
)
transformer_engine/pytorch/setup.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -75,21 +75,29 @@ def get_platform():
def
get_wheel_url
():
"""Construct the wheel URL for the current platform."""
torch_version_raw
=
parse
(
torch
.
__version__
)
python_version
=
f
"cp
{
sys
.
version_info
.
major
}{
sys
.
version_info
.
minor
}
"
platform_name
=
get_platform
()
nvte_version
=
te_version
()
torch_version
=
f
"
{
torch_version_raw
.
major
}
.
{
torch_version_raw
.
minor
}
"
cxx11_abi
=
str
(
torch
.
_C
.
_GLIBCXX_USE_CXX11_ABI
).
upper
()
# Determine the version numbers that will be used to determine the correct wheel
# We're using the CUDA version used to build torch, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_cuda_version
=
parse
(
torch
.
version
.
cuda
)
# For CUDA
11, we only compile for CUDA 11.8, and for CUDA
12 we only compile for CUDA 12.3
# For CUDA 12 we only compile for CUDA 12.3
# to save CI time. Minor versions should be compatible.
torch_cuda_version
=
parse
(
"11.8"
)
if
torch_cuda_version
.
major
==
11
else
parse
(
"12.3"
)
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
if
torch_cuda_version
.
major
==
12
:
torch_cuda_version
=
parse
(
"12.3"
)
elif
torch_cuda_version
.
major
==
13
:
torch_cuda_version
=
parse
(
"13.0"
)
else
:
raise
ValueError
(
f
"CUDA version
{
torch_cuda_version
}
not supported"
)
if
os
.
environ
.
get
(
"NVIDIA_PRODUCT_NAME"
,
""
)
==
"PyTorch"
:
torch_version
=
str
(
os
.
environ
.
get
(
"NVIDIA_PYTORCH_VERSION"
))
else
:
torch_version
=
f
"
{
torch
.
__version__
}
"
cuda_version
=
f
"
{
torch_cuda_version
.
major
}
"
# Determine wheel URL based on CUDA version, torch version, python version and OS
...
...
@@ -109,8 +117,10 @@ class CachedWheelsCommand(_bdist_wheel):
"""
def
run
(
self
):
"""Acts a proxy before _bdist_wheel.run() and downloads a prebuilt wheel if available."""
if
FORCE_BUILD
:
super
().
run
()
return
wheel_url
,
wheel_filename
=
get_wheel_url
()
print
(
"Guessing wheel URL: "
,
wheel_url
)
...
...
@@ -129,10 +139,12 @@ class CachedWheelsCommand(_bdist_wheel):
wheel_path
=
os
.
path
.
join
(
self
.
dist_dir
,
archive_basename
+
".whl"
)
print
(
"Raw wheel path"
,
wheel_path
)
os
.
rename
(
wheel_filename
,
wheel_path
)
return
except
(
urllib
.
error
.
HTTPError
,
urllib
.
error
.
URLError
):
print
(
"Precompiled wheel not found. Building from source..."
)
# If the wheel could not be downloaded, build from source
super
().
run
()
return
if
__name__
==
"__main__"
:
...
...
transformer_engine/pytorch/tensor/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/pytorch/tensor/_quantization_helpers.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor class with FP8 data quantized with NxN tiles"""
from
__future__
import
annotations
from
typing
import
Optional
,
Tuple
,
Iterable
,
Union
from
collections.abc
import
Iterable
import
math
from
typing
import
Any
,
Optional
,
Tuple
,
Union
import
torch
import
transformer_engine_torch
as
tex
import
os
from
transformer_engine_torch
import
DType
as
TE_DType
from
transformer_engine_torch
import
Float8BlockScaleTensorFormat
from
transformer_engine.common.recipe
import
Float8BlockScaling
,
Recipe
from
.storage.float8_blockwise_tensor_storage
import
Float8BlockwiseQTensorStorage
from
..quantized_tensor
import
QuantizedTensor
,
Quantizer
...
...
@@ -38,8 +38,6 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon
:
float
force_pow_2_scales
:
bool
block_scaling_dim
:
int
# Whether to produce tensors that will be used in all-gather
all_gather_usage
:
bool
def
__init__
(
self
,
...
...
@@ -50,7 +48,6 @@ class Float8BlockQuantizer(Quantizer):
amax_epsilon
:
float
=
0.0
,
force_pow_2_scales
:
bool
=
True
,
block_scaling_dim
:
int
=
2
,
all_gather_usage
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
rowwise
=
rowwise
,
columnwise
=
columnwise
)
self
.
dtype
=
tex
.
DType
.
kInt8
if
int8_simulation_fp8
else
fp8_dtype
...
...
@@ -58,7 +55,22 @@ class Float8BlockQuantizer(Quantizer):
self
.
force_pow_2_scales
=
force_pow_2_scales
self
.
amax_epsilon
=
amax_epsilon
self
.
block_scaling_dim
=
block_scaling_dim
self
.
all_gather_usage
=
all_gather_usage
def copy(self) -> Float8BlockQuantizer:
    """Return a shallow copy of this quantizer.

    The copy is constructed with the same dtype, usage flags and
    block-scaling configuration; ``internal`` and ``optimize_for_gemm``
    are carried over explicitly since they are not constructor arguments.
    """
    # NOTE(review): ``all_gather_usage`` is not propagated here — confirm
    # whether that attribute is still part of this quantizer's state.
    ctor_kwargs = dict(
        fp8_dtype=self.dtype,
        rowwise=self.rowwise_usage,
        columnwise=self.columnwise_usage,
        block_scaling_dim=self.block_scaling_dim,
        amax_epsilon=self.amax_epsilon,
        force_pow_2_scales=self.force_pow_2_scales,
    )
    clone = Float8BlockQuantizer(**ctor_kwargs)
    clone.internal = self.internal
    clone.optimize_for_gemm = self.optimize_for_gemm
    return clone
def
update_quantized
(
self
,
...
...
@@ -110,103 +122,86 @@ class Float8BlockQuantizer(Quantizer):
return
tex
.
quantize
(
tensor
,
self
)
def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
    """Scaling tensor shape.

    This method determines the shape of the scaling tensor based
    on the quantizer configuration. The scales are padded to
    multiples of 4 for compatibility with GEMM.

    Parameters
    ----------
    shape : Iterable[int]
        Logical tensor shape.
    columnwise : bool
        Whether the data is scaled column-wise (True) or row-wise (False).

    Returns
    -------
    Tuple[int, int]
        Scaling tensor shape.
    """
    # Flatten tensor to 2D: all leading dims collapse into dim0.
    dim0 = math.prod(shape[:-1])
    dim1 = shape[-1] if shape else 1

    # Check block dims
    if self.block_scaling_dim not in (1, 2):
        raise RuntimeError(
            "Only 1D or 2D blocks are supported, "
            f"but got block_scaling_dim={self.block_scaling_dim}"
        )

    # 128x128 block scaling: one scale per block in each dimension.
    # cuBLAS requires the 128x128 scaling factors to be padded
    # (inner dim rounded up to a multiple of 4).
    if self.block_scaling_dim == 2:
        scale_dim0 = (dim0 + self.block_len - 1) // self.block_len
        scale_dim1 = (dim1 + self.block_len - 1) // self.block_len
        if columnwise:
            return (scale_dim1, round_up_to_nearest_multiple(scale_dim0, 4))
        return (scale_dim0, round_up_to_nearest_multiple(scale_dim1, 4))

    # 1x128 block scaling: cuBLAS requires the scaling factors to be
    # padded and transposed relative to the data layout.
    if columnwise:
        return (
            (dim0 + self.block_len - 1) // self.block_len,
            round_up_to_nearest_multiple(dim1, 4),
        )
    return (
        (dim1 + self.block_len - 1) // self.block_len,
        round_up_to_nearest_multiple(dim0, 4),
    )
def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
    """Column-wise data shape.

    GEMMs expect that the column-wise data is transposed relative
    to the logical tensor shape.

    Parameters
    ----------
    shape : Iterable[int]
        Logical tensor shape.

    Returns
    -------
    Tuple[int, ...]
        Column-wise data shape: for (d1, ..., dn) returns
        (dn, d1, ..., dn-1). An empty input shape yields an empty tuple.
    """
    # Move the last dim to the front, preserving the order of the rest.
    colwise_shape = []
    if shape:
        colwise_shape.append(shape[-1])
    colwise_shape.extend(shape[:-1])
    return tuple(colwise_shape)
def is_quantizable(self, inp: torch.Tensor) -> bool:
    """Returns whether or not given inp can be quantized"""
    shape = inp.size()
    # Need at least a 2D tensor to form scaling blocks.
    if len(shape) < 2:
        return False
    # Inner dim must be divisible by the block length.
    if shape[-1] % self.block_len != 0:
        return False
    # Flattened outer dims must also be divisible by the block length.
    if math.prod(shape[:-1]) % self.block_len != 0:
        return False
    return True
...
...
@@ -220,44 +215,36 @@ class Float8BlockQuantizer(Quantizer):
pin_memory
:
bool
=
False
,
)
->
Float8BlockwiseQTensor
:
"""Construct quantized tensor with uninitialized data"""
if
device
is
None
:
device
=
torch
.
device
(
"cuda"
)
data_format
=
(
tex
.
Float8BlockScaleTensorFormat
.
COMPACT
if
self
.
all_gather_usage
else
tex
.
Float8BlockScaleTensorFormat
.
GEMM_READY
)
tensor_kwargs
=
{
"device"
:
torch
.
device
(
"cuda"
)
if
device
is
None
else
device
,
"pin_memory"
:
pin_memory
,
}
# Allocate
FP8
data
data
=
None
scale_inv
=
None
# Allocate
buffers for row-scaled
data
rowwise_
data
=
None
rowwise_
scale_inv
=
None
if
self
.
rowwise_usage
:
data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
)
scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
False
)
scale_inv
=
torch
.
empty
(
scale_shape
,
rowwise_data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
**
tensor_kwargs
)
rowwise_scale_inv
=
torch
.
empty
(
self
.
get_scale_shape
(
shape
,
columnwise
=
False
),
dtype
=
torch
.
float32
,
device
=
device
,
pin_memory
=
pin_memory
,
**
tensor_kwargs
,
)
# Allocate
FP8 data transpose if needed
# Allocate
buffers for column-scaled data
columnwise_data
=
None
columnwise_scale_inv
=
None
if
self
.
columnwise_usage
:
columnwise_data
=
torch
.
empty
(
self
.
get_columnwise_shape
(
shape
),
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
,
**
tensor_kwargs
,
)
columnwise_scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
True
)
columnwise_scale_inv
=
torch
.
empty
(
columnwise
_scale_shape
,
self
.
get
_scale_shape
(
shape
,
columnwise
=
True
)
,
dtype
=
torch
.
float32
,
device
=
device
,
pin_memory
=
pin_memory
,
**
tensor_kwargs
,
)
# Construct FP8 tensor
...
...
@@ -265,13 +252,12 @@ class Float8BlockQuantizer(Quantizer):
shape
=
shape
,
dtype
=
dtype
,
fp8_dtype
=
self
.
dtype
,
rowwise_data
=
data
,
rowwise_scale_inv
=
scale_inv
,
rowwise_data
=
rowwise_
data
,
rowwise_scale_inv
=
rowwise_
scale_inv
,
columnwise_data
=
columnwise_data
,
columnwise_scale_inv
=
columnwise_scale_inv
,
quantizer
=
self
,
is_2D_scaled
=
self
.
block_scaling_dim
==
2
,
data_format
=
data_format
,
requires_grad
=
requires_grad
,
)
...
...
@@ -294,18 +280,18 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
Parameters
----------
rowwise_data: torch.Tensor
rowwise_data
: torch.Tensor
FP8 data in a uint8 tensor matching shape of dequantized tensor.
rowwise_scale_inv: torch.Tensor
rowwise_scale_inv
: torch.Tensor
FP32 dequantization scales in GEMM format for dequantizing rowwise_data.
columnwise_data: Optional[torch.Tensor]
columnwise_data
: Optional[torch.Tensor]
FP8 data in a uint8 tensor matching shape of dequantized tensor transpose.
columnwise_scale_inv: Optional[torch.Tensor]
columnwise_scale_inv
: Optional[torch.Tensor]
FP32 dequantization scales in GEMM format for dequantizing columnwise_data.
fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
fp8_dtype
: transformer_engine_torch.DType, default = kFloat8E4M3
FP8 format.
quantizer: Quantizer - the Float8BlockQuantizer that quantized this tensor and
quantizer
: Quantizer - the Float8BlockQuantizer that quantized this tensor and
holds configuration about quantization and dequantization modes.
"""
...
...
@@ -321,7 +307,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
fp8_dtype
:
TE_DType
,
quantizer
:
Quantizer
,
is_2D_scaled
:
bool
,
data_format
:
tex
.
Float8BlockScaleTensorFormat
=
Float8BlockScaleTensorFormat
.
GEMM_READY
,
**
kwargs
,
):
instance
=
super
().
__new__
(
...
...
@@ -333,7 +318,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
fp8_dtype
,
quantizer
,
is_2D_scaled
,
data_format
,
*
args
,
**
kwargs
,
)
...
...
@@ -344,8 +328,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
return
(
f
"Float8BlockwiseQTensor(fp8_dtype=
{
self
.
_fp8_dtype
}
,"
f
" is_2D_scaled=
{
self
.
_is_2D_scaled
}
,"
f
" data=
{
self
.
dequantize
(
dtype
=
self
.
dtype
)
}
),"
f
" data_format=
{
self
.
_data_format
}
"
f
" data=
{
self
.
dequantize
(
dtype
=
self
.
dtype
)
}
)"
)
def
quantize_
(
...
...
@@ -496,7 +479,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dtype
:
torch
.
dtype
,
quantizer
:
Quantizer
,
is_2D_scaled
:
bool
,
data_format
:
tex
.
Float8BlockScaleTensorFormat
,
data_format
:
Any
=
None
,
# pylint: disable=unused-argument
)
->
Float8BlockwiseQTensor
:
"""Build Float8BlockwiseQTensor, for use in __reduce__
...
...
@@ -514,7 +497,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dtype
=
dtype
,
quantizer
=
quantizer
,
is_2D_scaled
=
is_2D_scaled
,
data_format
=
data_format
,
)
def
__reduce_ex__
(
self
,
protocol
:
int
)
->
tuple
:
...
...
@@ -531,7 +513,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
self
.
dtype
,
self
.
_quantizer
,
self
.
_is_2D_scaled
,
self
.
_
data_format
,
None
,
#
data_format
),
)
...
...
@@ -557,7 +539,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
dst
.
_fp8_dtype
=
src
.
_fp8_dtype
dst
.
_rowwise_scale_inv
=
src
.
_rowwise_scale_inv
dst
.
_columnwise_scale_inv
=
src
.
_columnwise_scale_inv
dst
.
_data_format
=
src
.
_data_format
# Check that tensor dimensions match
if
(
...
...
@@ -605,13 +586,6 @@ class _ViewFunc(torch.autograd.Function):
)
->
Float8BlockwiseQTensor
:
# pylint: disable=missing-function-docstring
# Check for invalid configurations
if
not
tensor
.
_is_gemm_ready_format
():
raise
NotImplementedError
(
"View is only supported with GEMM_READY data format, "
f
"but found data_format=
{
tensor
.
_data_format
}
"
)
# Return input tensor if shape is not provided
ctx
.
shape
=
tensor
.
shape
if
shape
is
None
:
...
...
@@ -680,14 +654,6 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if
isinstance
(
grad
,
Float8BlockwiseQTensor
):
# Check for invalid configurations
if
not
grad
.
_is_gemm_ready_format
():
raise
NotImplementedError
(
"View is only supported with GEMM_READY data format, "
f
"but found data_format=
{
grad
.
_data_format
}
"
)
new_data
=
(
grad
.
_rowwise_data
.
view
(
*
ctx
.
shape
)
if
grad
.
_rowwise_data
is
not
None
else
None
)
...
...
@@ -727,13 +693,6 @@ class _ReshapeFunc(torch.autograd.Function):
)
->
Float8BlockwiseQTensor
:
# pylint: disable=missing-function-docstring
# Check for invalid configurations
if
not
tensor
.
_is_gemm_ready_format
():
raise
NotImplementedError
(
"Reshape is only supported with GEMM_READY data format, "
f
"but found data_format=
{
tensor
.
_data_format
}
"
)
# Return input tensor if shape is not provided
ctx
.
shape
=
tensor
.
shape
if
shape
is
None
:
...
...
@@ -801,14 +760,6 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if
isinstance
(
grad
,
Float8BlockwiseQTensor
):
# Check for invalid configurations
if
not
grad
.
_is_gemm_ready_format
():
raise
NotImplementedError
(
"Reshape is only supported with GEMM_READY data format, "
f
"but found data_format=
{
grad
.
_data_format
}
"
)
new_rowwise_data
=
None
new_columnwise_data
=
None
if
grad
.
_rowwise_data
is
not
None
:
...
...
transformer_engine/pytorch/tensor/float8_tensor.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -67,6 +67,20 @@ class Float8Quantizer(Quantizer):
self
.
amax
=
amax
self
.
dtype
=
fp8_dtype
def copy(self) -> Float8Quantizer:
    """Create shallow copy"""
    quantizer = Float8Quantizer(
        scale=self.scale,
        amax=self.amax,
        fp8_dtype=self.dtype,
        rowwise=self.rowwise_usage,
        columnwise=self.columnwise_usage,
    )
    quantizer.internal = self.internal
    # Propagate the GEMM-layout preference, matching the copy()
    # implementations of the other quantizer classes in this package.
    quantizer.optimize_for_gemm = self.optimize_for_gemm
    return quantizer
def
update_quantized
(
self
,
src
:
torch
.
Tensor
,
...
...
@@ -246,10 +260,16 @@ class Float8CurrentScalingQuantizer(Quantizer):
amax_reduction_group
:
Optional
[
dist_group_type
]
=
None
,
force_pow_2_scales
:
bool
=
False
,
amax_epsilon
:
float
=
0.0
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
amax
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
super
().
__init__
(
rowwise
=
rowwise
,
columnwise
=
columnwise
)
self
.
scale
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
self
.
amax
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
if
scale
is
None
:
scale
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
if
amax
is
None
:
amax
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
self
.
scale
=
scale
self
.
amax
=
amax
self
.
dtype
=
tex
.
DType
.
kInt8
if
int8_simulation_fp8_tensorwise
else
fp8_dtype
self
.
use_existing_amax
=
use_existing_amax
self
.
with_amax_reduction
=
with_amax_reduction
...
...
@@ -257,6 +277,27 @@ class Float8CurrentScalingQuantizer(Quantizer):
self
.
force_pow_2_scales
=
force_pow_2_scales
self
.
amax_epsilon
=
amax_epsilon
def copy(self) -> Float8CurrentScalingQuantizer:
    """Create shallow copy"""
    quantizer = Float8CurrentScalingQuantizer(
        fp8_dtype=self.dtype,
        # Follow the device of the existing state instead of hard-coding
        # device index 0. The device argument is only used to allocate
        # scale/amax when they are not supplied, and both are supplied here.
        device=self.scale.device,
        rowwise=self.rowwise_usage,
        columnwise=self.columnwise_usage,
        with_amax_reduction=self.with_amax_reduction,
        amax_reduction_group=self.amax_reduction_group,
        use_existing_amax=self.use_existing_amax,
        force_pow_2_scales=self.force_pow_2_scales,
        amax_epsilon=self.amax_epsilon,
        scale=self.scale,
        amax=self.amax,
    )
    quantizer.internal = self.internal
    quantizer.optimize_for_gemm = self.optimize_for_gemm
    return quantizer
def
update_quantized
(
self
,
src
:
torch
.
Tensor
,
...
...
@@ -414,23 +455,23 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
Parameters
----------
shape: int or iterable of int
shape
: int or iterable of int
Tensor dimensions.
dtype: torch.dtype
dtype
: torch.dtype
Nominal tensor datatype.
requires_grad: bool, optional = False
requires_grad
: bool, optional = False
Whether to compute gradients for this tensor.
data: torch.Tensor
data
: torch.Tensor
Raw FP8 data in a uint8 tensor
fp8_scale_inv: torch.Tensor
fp8_scale_inv
: torch.Tensor
Reciprocal of the scaling factor applied when casting to FP8,
i.e. the scaling factor that must be applied when casting from
FP8 to higher precision.
fp8_dtype: transformer_engine_torch.DType
fp8_dtype
: transformer_engine_torch.DType
FP8 format.
data_transpose: torch.Tensor, optional
data_transpose
: torch.Tensor, optional
FP8 transpose data in a uint8 tensor
quantizer: Float8Quantizer, Float8CurrentScalingQuantizer, optional
quantizer
: Float8Quantizer, Float8CurrentScalingQuantizer, optional
Builder class for FP8 tensors
"""
...
...
@@ -454,10 +495,10 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
# Convert PyTorch dtype to TE dtype
if
dtype
is
None
:
dtype
=
self
.
dtype
tensor
=
self
.
contiguous
()
if
torch
.
is_grad_enabled
():
return
_FromFloat8Func
.
apply
(
self
,
dtype
)
return
_FromFloat8Func
.
forward
(
None
,
self
,
dtype
)
return
_FromFloat8Func
.
apply
(
tensor
,
dtype
)
return
_FromFloat8Func
.
forward
(
None
,
tensor
,
dtype
)
def
quantize_
(
self
,
...
...
@@ -512,18 +553,31 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
)
->
Float8Tensor
:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
Returns
`
`self`
`
if data is already in correct memory format.
"""
if
self
.
_data
is
not
None
and
self
.
_data
.
is_contiguous
(
memory_format
=
memory_format
):
return
self
if
self
.
_transpose
is
not
None
and
self
.
_transpose
.
is_contiguous
(
# Check if tensor already has correct memory format
if
self
.
_data
is
not
None
and
not
self
.
_data
.
is_contiguous
(
memory_format
=
memory_format
):
pass
elif
self
.
_transpose
is
not
None
and
not
self
.
_transpose
.
is_contiguous
(
memory_format
=
memory_format
):
pass
else
:
# Tensor has correct memory format, so return immediately
return
self
return
Float8Tensor
.
make_like
(
tensor
=
self
,
data
=
self
.
_data
.
contiguous
())
# raise ValueError("Float8Tensor does not support different memory formats!")
# Construct tensor with correct data format
data
,
data_transpose
=
None
,
None
if
self
.
_data
is
not
None
:
data
=
self
.
_data
.
contiguous
(
memory_format
=
memory_format
)
if
self
.
_transpose
is
not
None
and
not
self
.
_transpose_invalid
:
data_transpose
=
self
.
_transpose
.
contiguous
(
memory_format
=
memory_format
)
return
_IdentityFunc
.
apply
(
self
,
{
"data"
:
data
,
"data_transpose"
:
data_transpose
},
)
def
_reset_caches
(
self
)
->
None
:
"""
...
...
@@ -674,9 +728,8 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
[
transpose
,
t_shape
]
+
list
(
args
[
2
:]),
kwargs
,
)
# deep copy the scale inverse tensor and quantizer as well.
scale_inv
=
tensor
.
_scale_inv
.
detach
().
clone
()
quantizer
=
tensor
.
_quantizer
.
copy
()
quantizer
=
tensor
.
_quantizer
# Deep-copied in constructor
out_tensor
=
Float8Tensor
(
data
=
func_out
,
shape
=
func_out
.
shape
,
...
...
@@ -781,7 +834,7 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
# sure that updated Quantized weight tensor have same scale inverse across all shards.
self
.
_quantizer
.
amax_reduction_group
=
mesh
.
get_group
()
self
.
_quantizer
.
with_amax_reduction
=
True
quantizer
=
self
.
_quantizer
.
copy
()
# quantizer to be used for allgathered weights
fsdp_state
=
_get_module_fsdp_state
(
module
)
reshard_after_forward
=
fsdp_state
.
_fsdp_param_group
.
_reshard_after_forward
# If weights are resharded after forward pass, then its enough to set the quantizer usages
...
...
@@ -794,9 +847,13 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
is_backward_pass
=
training_state
==
TrainingState
.
PRE_BACKWARD
# In case of hopper/L40, only one of data/transpose is needed
# based on forward or backward pass. So setting the quantizer usages appropriately.
quantizer
.
set_usage
(
rowwise
=
not
is_backward_pass
,
columnwise
=
is_backward_pass
)
rowwise_usage
=
not
is_backward_pass
columnwise_usage
=
is_backward_pass
else
:
rowwise_usage
=
True
columnwise_usage
=
self
.
_quantizer
.
columnwise_usage
sharded_tensors
=
(
self
.
_data
,)
metadata
=
(
self
.
_scale_inv
,
self
.
_fp8_dtype
,
quantizer
)
metadata
=
(
self
.
_scale_inv
,
rowwise_usage
,
columnwise_usage
,
self
.
_fp8_dtype
)
return
sharded_tensors
,
metadata
def
fsdp_post_all_gather
(
...
...
@@ -822,7 +879,7 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
"""
(
data
,)
=
all_gather_outputs
(
fp8_scale_inv
,
fp8_dtype
,
quantizer
)
=
metadata
(
fp8_scale_inv
,
rowwise_usage
,
columnwise_usage
,
fp8_dtype
)
=
metadata
orig_shape
=
data
.
size
()
# Quantizer has only columnwise usage set for backward pass
# In Blackwell+ architectures, transpose is not needed at all,
...
...
@@ -831,20 +888,27 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
if
out
is
not
None
:
out
.
_data
=
data
else
:
# We ll be here when post all gather is called the first time.
# Float8Tensor constructor makes a copy of the quantizer to
# save as its own quantizer. For the consequent iterations,
# the same quantizer is used. Copy is needed in the first iteration,
# since we need different quantizers for sharded and allgathered tensors.
# and self._quantizer belongs to the sharded parameter.
fp8_args
=
{
"shape"
:
orig_shape
,
"dtype"
:
param_dtype
,
"fp8_scale_inv"
:
fp8_scale_inv
,
"fp8_dtype"
:
fp8_dtype
,
"quantizer"
:
quantizer
,
"quantizer"
:
self
.
_
quantizer
,
"requires_grad"
:
False
,
"data"
:
data
,
}
out
=
Float8Tensor
(
**
fp8_args
)
out
.
_quantizer
.
set_usage
(
rowwise
=
rowwise_usage
,
columnwise
=
columnwise_usage
)
out
.
update_usage
(
rowwise_usage
=
quantizer
.
rowwise_usage
,
columnwise_usage
=
quantizer
.
columnwise_usage
,
rowwise_usage
=
rowwise_usage
,
columnwise_usage
=
columnwise_usage
,
)
return
out
,
all_gather_outputs
...
...
transformer_engine/pytorch/tensor/mxfp8_tensor.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -45,6 +45,19 @@ class MXFP8Quantizer(Quantizer):
super
().
__init__
(
rowwise
=
rowwise
,
columnwise
=
columnwise
)
self
.
dtype
=
fp8_dtype
def copy(self) -> MXFP8Quantizer:
    """Return a shallow copy of this quantizer."""
    clone = MXFP8Quantizer(
        fp8_dtype=self.dtype,
        rowwise=self.rowwise_usage,
        columnwise=self.columnwise_usage,
    )
    # Carry over attributes that are not constructor arguments.
    clone.internal = self.internal
    clone.optimize_for_gemm = self.optimize_for_gemm
    return clone
def
update_quantized
(
self
,
src
:
torch
.
Tensor
,
...
...
@@ -122,7 +135,9 @@ class MXFP8Quantizer(Quantizer):
columnwise_data
=
None
columnwise_scale_inv
=
None
if
self
.
columnwise_usage
:
columnwise_data
=
torch
.
empty_like
(
data
,
pin_memory
=
pin_memory
)
columnwise_data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
)
columnwise_scale_inv
=
torch
.
empty
(
round_up_to_nearest_multiple
(
math
.
prod
(
shape
[:
-
1
])
//
MXFP8_BLOCK_SCALING_SIZE
,
4
),
round_up_to_nearest_multiple
(
shape
[
-
1
],
128
),
...
...
@@ -142,6 +157,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv
=
columnwise_scale_inv
,
quantizer
=
self
,
requires_grad
=
requires_grad
,
with_gemm_swizzled_scales
=
self
.
optimize_for_gemm
,
)
def
calibrate
(
self
,
tensor
:
torch
.
Tensor
)
->
None
:
...
...
@@ -165,6 +181,7 @@ class MXFP8Quantizer(Quantizer):
columnwise_scale_inv
=
None
,
fp8_dtype
=
fp8_dtype
,
quantizer
=
self
,
with_gemm_swizzled_scales
=
False
,
)
def
onnx_quantize
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
...
...
@@ -174,6 +191,10 @@ class MXFP8Quantizer(Quantizer):
return
self
.
create_tensor_from_data
(
data
,
scale_inv
,
fake_dtype
=
torch
.
float32
)
def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor:
    """Dequantize via the ONNX-exportable op.

    Only the compact (non-swizzled) scale layout is supported; tensors
    whose scales are in the GEMM-swizzled layout are rejected.
    """
    if tensor._with_gemm_swizzled_scales:
        raise NotImplementedError(
            "ONNX MXFP8 dequantization is only supported with scales in compact format."
        )
    data, scales = tensor._rowwise_data, tensor._rowwise_scale_inv
    return torch.ops.tex.mxfp8_dequantize(data, scales)
def
_get_compatible_recipe
(
self
)
->
Union
[
type
[
Recipe
],
None
]:
...
...
@@ -190,16 +211,16 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
Parameters
----------
data: torch.Tensor
data
: torch.Tensor
Raw FP8 data in a uint8 tensor
fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
fp8_dtype
: transformer_engine_torch.DType, default = kFloat8E4M3
FP8 format.
fp8_scale_inv: torch.Tensor
fp8_scale_inv
: torch.Tensor
Reciprocal of the scaling factor applied when
casting to FP8, i.e. the scaling factor that must
be applied when casting from FP8 to higher
precision.
dtype: torch.dtype, default = torch.float32
dtype
: torch.dtype, default = torch.float32
Nominal tensor datatype.
"""
...
...
@@ -215,9 +236,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
columnwise_scale_inv
:
Optional
[
torch
.
Tensor
],
fp8_dtype
:
TE_DType
,
quantizer
:
Optional
[
Quantizer
],
with_gemm_swizzled_scales
:
bool
,
**
kwargs
,
):
instance
=
super
().
__new__
(
return
super
().
__new__
(
cls
,
rowwise_data
,
rowwise_scale_inv
,
...
...
@@ -225,10 +247,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
columnwise_scale_inv
,
fp8_dtype
,
quantizer
,
with_gemm_swizzled_scales
,
*
args
,
**
kwargs
,
)
return
instance
def
__repr__
(
self
,
*
,
tensor_contents
=
None
):
return
f
"MXFP8Tensor(fp8_dtype=
{
self
.
_fp8_dtype
}
, data=
{
self
.
dequantize
(
dtype
=
self
.
dtype
)
}
)"
...
...
@@ -320,39 +342,44 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
@
classmethod
def
__torch_dispatch__
(
cls
,
func
,
types
,
args
,
kwargs
=
None
):
# View op
if
func
==
aten
.
view
.
default
:
tensor
=
args
[
0
]
data
=
tensor
.
_rowwise_data
out_data
=
data
.
__torch_dispatch__
(
func
,
types
,
[
data
]
+
list
(
args
[
1
:]),
kwargs
,
)
out_shape
=
out_data
.
size
()
shape
=
args
[
1
]
if
len
(
shape
)
<
2
or
shape
[
-
1
]
!=
tensor
.
size
(
-
1
):
raise
ValueError
(
f
"Attempted to make view with size=
{
tuple
(
shape
)
}
"
f
"from MXFP8 tensor with shape=
{
tuple
(
tensor
.
size
())
}
."
)
rowwise_data_view
=
None
columnwise_data_view
=
None
if
tensor
.
_rowwise_data
is
not
None
:
rowwise_data_view
=
tensor
.
_rowwise_data
.
view
(
shape
)
if
tensor
.
_columnwise_data
is
not
None
:
columnwise_data_view
=
tensor
.
_columnwise_data
.
view
(
shape
)
return
MXFP8Tensor
(
shape
=
out_
shape
,
shape
=
shape
,
dtype
=
tensor
.
dtype
,
rowwise_data
=
out_data
,
rowwise_data
=
rowwise_data_view
,
rowwise_scale_inv
=
tensor
.
_rowwise_scale_inv
,
columnwise_data
=
tensor
.
_
columnwise_data
,
columnwise_data
=
columnwise_data
_view
,
columnwise_scale_inv
=
tensor
.
_columnwise_scale_inv
,
quantizer
=
tensor
.
_quantizer
,
requires_grad
=
False
,
fp8_dtype
=
tensor
.
_fp8_dtype
,
with_gemm_swizzled_scales
=
tensor
.
_with_gemm_swizzled_scales
,
)
if
func
==
torch
.
ops
.
aten
.
copy_
.
default
:
dst
,
src
=
args
[
0
],
args
[
1
]
if
isinstance
(
src
,
MXFP8Tensor
)
and
isinstance
(
dst
,
MXFP8Tensor
):
# Booleans to check if src has all the usages that dst needs to respect dst quantizer usages.
# If not, default to base class behavior.
rowwise_matches
=
src
.
_rowwise_data
is
not
None
or
dst
.
_rowwise_data
is
None
columnwise_matches
=
(
src
.
_columnwise_data
is
not
None
or
dst
.
_columnwise_data
is
None
)
if
rowwise_matches
and
columnwise_matches
:
if
src
.
_rowwise_data
is
None
and
dst
.
_rowwise_data
is
not
None
:
pass
elif
src
.
_columnwise_data
is
None
and
dst
.
_columnwise_data
is
not
None
:
pass
elif
src
.
_with_gemm_swizzled_scales
!=
dst
.
_with_gemm_swizzled_scales
:
pass
else
:
# src and dst match, so we can directly copy data
if
dst
.
_rowwise_data
is
not
None
:
dst
.
_rowwise_data
.
copy_
(
src
.
_rowwise_data
.
detach
(),
*
args
[
2
:],
**
kwargs
)
dst
.
_rowwise_scale_inv
.
copy_
(
...
...
@@ -367,26 +394,25 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
)
return
dst
# FSDP2 related functions.
if
func
==
aten
.
split
.
Tensor
:
# This is called if entire model is initialized on CUDA device and
# then splitted. Finally the shard needed by the process is used
# and other splitted shards are discarded.
# With FSDP2, this is called if entire model is
# initialized on CUDA device and then splitted. Finally
# the shard needed by the process is used and other
# splitted shards are discarded.
tensor
=
args
[
0
]
split_size
=
args
[
1
]
if
"dim"
in
kwargs
:
dim_to_split
=
kwargs
[
"dim"
]
else
:
dim_to_split
=
args
[
2
]
if
len
(
args
)
>
2
else
0
tensor
=
args
[
0
]
split_size
=
args
[
1
]
dim0_size
=
tensor
.
size
(
0
)
dimlast_size
=
math
.
prod
(
tensor
.
shape
[
1
:])
# Fall back to high-precision if split is non-trivial
if
(
dim
0_size
%
split_size
!=
0
or
dim_to_split
!=
0
dim
_to_split
!=
0
or
tensor
.
size
(
0
)
%
split_size
!=
0
or
split_size
%
MXFP8_BLOCK_SCALING_SIZE
!=
0
or
dimlast
_siz
e
%
MXFP8_BLOCK_SCALING_SIZE
!=
0
or
tensor
.
_with_gemm
_s
w
iz
zled_scales
):
# Handle splitting by dequantizing and splitting the hp tensor
return
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
out_data
=
[]
...
...
@@ -420,13 +446,16 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
if
scale_inv
is
not
None
else
None
)
scale_inv_out
=
list
(
scale_inv_out
)
if
scale_inv_out
is
not
None
else
None
# Pad scale_inv_out to be a multiple of pad_multiple
if
scale_inv_out
is
not
None
:
current_shape
=
scale_inv_out
.
shape
pad_dim0
=
(
pad_multiple
-
current_shape
[
0
]
%
pad_multiple
)
%
pad_multiple
if
pad_dim0
>
0
:
scale_inv_out
=
torch
.
nn
.
functional
.
pad
(
scale_inv_out
,
(
0
,
0
,
0
,
pad_dim0
))
for
idx
,
split_scale_inv_out
in
enumerate
(
scale_inv_out
):
current_shape
=
split_scale_inv_out
.
shape
pad_dim0
=
(
pad_multiple
-
current_shape
[
0
]
%
pad_multiple
)
%
pad_multiple
if
pad_dim0
>
0
:
scale_inv_out
[
idx
]
=
torch
.
nn
.
functional
.
pad
(
split_scale_inv_out
,
(
0
,
0
,
0
,
pad_dim0
)
)
out_data
.
append
(
scale_inv_out
)
return
[
MXFP8Tensor
(
...
...
@@ -443,28 +472,26 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
quantizer
=
tensor
.
_quantizer
,
requires_grad
=
False
,
fp8_dtype
=
tensor
.
_fp8_dtype
,
with_gemm_swizzled_scales
=
False
,
)
for
splitted_tensor_data
in
zip
(
*
out_data
)
]
if
func
==
torch
.
ops
.
aten
.
as_strided
.
default
:
# Applied on unsharded param in FSDP2. In our case, this should be a no-op
# This is needed for the case where some MXFP8 shards need padding i.e dimension 0
# of the unsharded param is not a multiple of the world size. If that is the case,
# we down the dequantization route and weights are allgathered in high precision.
# If weight doesnt need padding, this is just a no-op.
tensor
=
args
[
0
]
shape
=
args
[
1
]
strides
=
args
[
2
]
tensor
=
args
[
0
]
if
(
len
(
shape
)
!=
2
or
len
(
strides
)
!=
2
or
strides
[
1
]
!=
1
or
shape
[
0
]
!=
tensor
.
shape
[
0
]
or
shape
[
1
]
!=
tensor
.
shape
[
1
]
len
(
shape
)
==
len
(
strides
)
==
2
and
tuple
(
strides
)
==
(
shape
[
-
1
],
1
)
and
tuple
(
shape
)
==
tuple
(
tensor
.
size
())
):
return
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
return
MXFP8Tensor
.
make_like
(
tensor
)
return
MXFP8Tensor
.
make_like
(
tensor
)
if
func
==
aten
.
slice
.
Tensor
:
# FSDP2 needed function.
...
...
@@ -472,19 +499,12 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
# of the unsharded param is not a multiple of the world size. If that is the case,
# we down the dequantization route and weights are allgathered in high precision instead.
# If sharded weight doesnt have padding, this is just a no-op.
tensor
=
args
[
0
]
dim
=
args
[
1
]
start
=
args
[
2
]
length
=
args
[
3
]
tensor
=
args
[
0
]
if
(
dim
!=
0
or
length
!=
tensor
.
shape
[
0
]
or
start
!=
0
or
length
%
MXFP8_BLOCK_SCALING_SIZE
!=
0
or
start
%
MXFP8_BLOCK_SCALING_SIZE
!=
0
):
return
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
return
MXFP8Tensor
.
make_like
(
tensor
)
if
start
==
0
and
length
==
tensor
.
size
(
dim
):
return
MXFP8Tensor
.
make_like
(
tensor
)
if
func
==
aten
.
new_zeros
.
default
:
rowwise_data
=
None
...
...
@@ -538,10 +558,12 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
rowwise_scale_inv
=
rowwise_scale_inv
,
columnwise_data
=
columnwise_data
,
columnwise_scale_inv
=
columnwise_scale_inv
,
quantizer
=
tensor
.
_quantizer
.
copy
()
,
quantizer
=
tensor
.
_quantizer
,
requires_grad
=
False
,
fp8_dtype
=
tensor
.
_fp8_dtype
,
with_gemm_swizzled_scales
=
tensor
.
_with_gemm_swizzled_scales
,
)
# Default case
return
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
...
...
@@ -567,29 +589,32 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
# pylint: disable=unused-argument
from
transformer_engine.pytorch.distributed
import
_get_module_fsdp_state
# Get FSDP state
fsdp_state
=
_get_module_fsdp_state
(
module
)
reshard_after_forward
=
fsdp_state
.
_fsdp_param_group
.
_reshard_after_forward
quantizer
=
self
.
_quantizer
.
copy
()
# Remove padding from scale inverses before allgather
# Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128]
rowwise_scale_inv
=
self
.
_rowwise_scale_inv
columnwise_scale_inv
=
self
.
_columnwise_scale_inv
shape
=
self
.
shape
if
self
.
_with_gemm_swizzled_scales
:
raise
NotImplementedError
(
"FSDP2 is only supported for MXFP8Tensors with compact scales"
)
if
rowwise_scale_inv
is
not
None
:
# Remove padding from rowwise scale_inv
flattened_in_shape0
=
math
.
prod
(
shape
[:
-
1
])
if
rowwise_scale_inv
.
size
(
0
)
!=
flattened_in_shape0
:
rowwise_scale_inv
=
rowwise_scale_inv
[:
flattened_in_shape0
]
if
columnwise_scale_inv
is
not
None
:
# Remove padding from columnwise scale_inv
flattened_in_shape0
=
math
.
prod
(
shape
[:
-
1
])
//
MXFP8_BLOCK_SCALING_SIZE
if
columnwise_scale_inv
.
size
(
0
)
!=
flattened_in_shape0
:
columnwise_scale_inv
=
columnwise_scale_inv
[:
flattened_in_shape0
]
sharded_tensors
=
(
self
.
_rowwise_data
,
rowwise_scale_inv
)
# If weights are resharded after forward pass, then its enough to set the quantizer usages
# based on whether its forward or backward pass for the allgathered weights.
# If weights are resharded after forward pass, then its enough to send one row/col
# usage based on whether its forward or backward pass for the allgathered weights.
# If not resharded after forward pass, the same weights allgathered in forward
# are used again in backward. And hence if we need the columnwise data/scale_inv,
# we need to send them as well for allgather in forward pass itself.
...
...
@@ -597,18 +622,24 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
training_state
=
fsdp_state
.
_fsdp_param_group
.
_training_state
is_backward_pass
=
training_state
==
TrainingState
.
PRE_BACKWARD
# Allgather only the necessary tensors based on forward/backward pass
quantizer
.
set_usage
(
rowwise
=
not
is_backward_pass
,
columnwise
=
is_backward_pass
)
rowwise_usage
=
not
is_backward_pass
columnwise_usage
=
is_backward_pass
sharded_tensors
=
(
(
self
.
_columnwise_data
,
columnwise_scale_inv
)
if
is_backward_pass
else
sharded_tensors
else
(
self
.
_rowwise_data
,
rowwise_scale_inv
)
)
else
:
if
quantizer
.
columnwise_usage
:
# rowwise usage is always needed for forward pass.
rowwise_usage
=
True
sharded_tensors
=
(
self
.
_rowwise_data
,
rowwise_scale_inv
)
columnwise_usage
=
self
.
_quantizer
.
columnwise_usage
if
columnwise_usage
:
# If weights are not resharded after forward, then both
# rowwise and columnwise data/scale_inv need to be allgathered.
sharded_tensors
+=
(
self
.
_columnwise_data
,
columnwise_scale_inv
)
metadata
=
(
self
.
_fp8_dtype
,
quantizer
)
metadata
=
(
self
.
_fp8_dtype
,
rowwise_usage
,
columnwise_usage
)
return
sharded_tensors
,
metadata
def
fsdp_post_all_gather
(
...
...
@@ -631,12 +662,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: Allgathered MXFP8Tensor and tuple of internal tensors
used by the MXFP8Tensor that was being computed after allgather.
"""
fp8_dtype
,
quantizer
=
metadata
rowwise_data
,
rowwise_scale_inv
=
(
all_gather_outputs
[:
2
]
if
quantizer
.
rowwise_usage
else
(
None
,
None
)
)
fp8_dtype
,
rowwise_usage
,
columnwise_usage
=
metadata
rowwise_data
,
rowwise_scale_inv
=
all_gather_outputs
[:
2
]
if
rowwise_usage
else
(
None
,
None
)
columnwise_data
,
columnwise_scale_inv
=
(
all_gather_outputs
[
-
2
:]
if
quantizer
.
columnwise_usage
else
(
None
,
None
)
all_gather_outputs
[
-
2
:]
if
columnwise_usage
else
(
None
,
None
)
)
# Add padding to scale_inv tensors to be multiples of [128, 4]for rowwise and [4, 128] for columnwise
...
...
@@ -661,8 +690,13 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
out
.
_rowwise_scale_inv
=
rowwise_scale_inv
out
.
_columnwise_data
=
columnwise_data
out
.
_columnwise_scale_inv
=
columnwise_scale_inv
out
.
_quantizer
=
quantizer
else
:
# We'll be here when post all gather is called the first time.
# MXFP8Tensor constructor makes a copy of the quantizer to
# save as its own quantizer. For the consequent iterations,
# the same quantizer is used. Copy is needed in the first iteration,
# since we need different quantizers for sharded and allgathered tensors.
# and self._quantizer belongs to the sharded parameter.
out
=
MXFP8Tensor
(
rowwise_data
=
rowwise_data
,
rowwise_scale_inv
=
rowwise_scale_inv
,
...
...
@@ -671,9 +705,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
fp8_dtype
=
fp8_dtype
,
dtype
=
param_dtype
,
shape
=
rowwise_data
.
shape
if
rowwise_data
is
not
None
else
columnwise_data
.
shape
,
quantizer
=
quantizer
,
quantizer
=
self
.
_quantizer
,
with_gemm_swizzled_scales
=
False
,
)
out
.
_quantizer
.
set_usage
(
rowwise
=
rowwise_usage
,
columnwise
=
columnwise_usage
)
return
out
,
all_gather_outputs
@
classmethod
...
...
@@ -687,6 +722,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
dtype
:
torch
.
dtype
,
shape
:
torch
.
shape
,
quantizer
:
Optional
[
Quantizer
]
=
None
,
with_gemm_swizzled_scales
:
bool
=
False
,
)
->
MXFP8Tensor
:
"""Build MXFP8Tensor, for use in __reduce__
...
...
@@ -703,6 +739,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
dtype
=
dtype
,
shape
=
shape
,
quantizer
=
quantizer
,
with_gemm_swizzled_scales
=
with_gemm_swizzled_scales
,
)
def
__reduce_ex__
(
self
,
protocol
:
int
)
->
tuple
:
...
...
@@ -718,6 +755,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
self
.
dtype
,
self
.
shape
,
self
.
_quantizer
,
self
.
_with_gemm_swizzled_scales
,
),
)
...
...
@@ -739,7 +777,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
if
not
devices_match
(
new_device
,
tensor
.
device
):
tensor
=
tensor
.
to
(
device
=
new_device
)
# Just copy
FP8
data if other tensor is MXFP8Tensor
# Just copy data if other tensor is MXFP8Tensor
if
isinstance
(
tensor
,
MXFP8Tensor
):
if
(
# pylint: disable=too-many-boolean-expressions
self
.
size
()
!=
tensor
.
size
()
...
...
@@ -767,6 +805,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
self
.
_fp8_dtype
=
tensor
.
_fp8_dtype
self
.
_rowwise_scale_inv
=
tensor
.
_rowwise_scale_inv
self
.
_columnwise_scale_inv
=
tensor
.
_columnwise_scale_inv
self
.
_with_gemm_swizzled_scales
=
tensor
.
_with_gemm_swizzled_scales
return
# Quantize to FP8
...
...
@@ -838,6 +877,7 @@ class _ViewFunc(torch.autograd.Function):
columnwise_scale_inv
=
tensor
.
_columnwise_scale_inv
,
fp8_dtype
=
tensor
.
_fp8_dtype
,
quantizer
=
tensor
.
_quantizer
,
with_gemm_swizzled_scales
=
tensor
.
_with_gemm_swizzled_scales
,
)
@
staticmethod
...
...
@@ -864,6 +904,7 @@ class _ViewFunc(torch.autograd.Function):
columnwise_scale_inv
=
grad
.
_columnwise_scale_inv
,
fp8_dtype
=
grad
.
_fp8_dtype
,
quantizer
=
grad
.
_quantizer
,
with_gemm_swizzled_scales
=
grad
.
_with_gemm_swizzled_scales
,
)
return
dgrad
,
None
return
grad
.
view
(
ctx
.
shape
),
None
...
...
@@ -924,6 +965,7 @@ class _ReshapeFunc(torch.autograd.Function):
columnwise_scale_inv
=
tensor
.
_columnwise_scale_inv
,
fp8_dtype
=
tensor
.
_fp8_dtype
,
quantizer
=
tensor
.
_quantizer
,
with_gemm_swizzled_scales
=
tensor
.
_with_gemm_swizzled_scales
,
)
@
staticmethod
...
...
@@ -949,6 +991,7 @@ class _ReshapeFunc(torch.autograd.Function):
columnwise_scale_inv
=
grad
.
_columnwise_scale_inv
,
fp8_dtype
=
grad
.
_fp8_dtype
,
quantizer
=
grad
.
_quantizer
,
with_gemm_swizzled_scales
=
grad
.
_with_gemm_swizzled_scales
,
)
return
dgrad
,
None
return
grad
.
view
(
ctx
.
shape
),
None
transformer_engine/pytorch/tensor/nvfp4_tensor.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -28,9 +28,9 @@ from ._quantization_helpers import _IdentityFunc
aten
=
torch
.
ops
.
aten
def
get_no_random_sign_vector
()
->
torch
.
Tensor
:
def
get_no_random_sign_vector
(
device
:
int
)
->
torch
.
Tensor
:
"""Non-random sign vector for Hadamard transform."""
return
torch
.
tensor
([
1
],
dtype
=
torch
.
float32
,
device
=
"cuda"
)
return
torch
.
tensor
([
1
],
dtype
=
torch
.
float32
,
device
=
device
)
def
get_sign_from_vector
(
vector
:
torch
.
Tensor
)
->
int
:
...
...
@@ -45,7 +45,7 @@ def get_sign_from_vector(vector: torch.Tensor) -> int:
return
mask
.
item
()
def
get_wgrad_sign_vector
()
->
torch
.
Tensor
:
def
get_wgrad_sign_vector
(
device
:
int
)
->
torch
.
Tensor
:
"""Hard-coded random signs for Hadamard transform.
https://xkcd.com/221/
...
...
@@ -54,11 +54,11 @@ def get_wgrad_sign_vector() -> torch.Tensor:
return
torch
.
tensor
(
[
1
,
1
,
1
,
-
1
,
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
-
1
],
dtype
=
torch
.
float32
,
device
=
"cuda"
,
device
=
device
,
)
def
get_hadamard_matrix
(
hadamard_dimension
:
int
)
->
torch
.
Tensor
:
def
get_hadamard_matrix
(
hadamard_dimension
:
int
,
device
:
int
)
->
torch
.
Tensor
:
"""Construct a 16x16 Hadamard matrix."""
assert
hadamard_dimension
==
16
,
"Only hadamard dimension 16 is supported."
hadamard_scale
=
1
/
math
.
sqrt
(
hadamard_dimension
)
...
...
@@ -83,30 +83,30 @@ def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
[
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
1
,
-
1
,
-
1
,
1
],
],
dtype
=
torch
.
float32
,
device
=
"cuda"
,
device
=
device
,
)
*
hadamard_scale
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
get_rht_matrix
(
with_random_sign_mask
:
bool
)
->
torch
.
Tensor
:
def
get_rht_matrix
(
with_random_sign_mask
:
bool
,
device
:
int
)
->
torch
.
Tensor
:
"""Construct matrix used in random Hadamard transform."""
hadamard_dimension
=
16
if
with_random_sign_mask
:
signs
=
get_wgrad_sign_vector
()
signs
=
get_wgrad_sign_vector
(
device
=
device
)
else
:
signs
=
get_no_random_sign_vector
()
sign_matrix
=
signs
*
torch
.
eye
(
hadamard_dimension
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
rht_matrix
=
sign_matrix
@
get_hadamard_matrix
(
hadamard_dimension
)
signs
=
get_no_random_sign_vector
(
device
=
device
)
sign_matrix
=
signs
*
torch
.
eye
(
hadamard_dimension
,
dtype
=
torch
.
float32
,
device
=
device
)
rht_matrix
=
sign_matrix
@
get_hadamard_matrix
(
hadamard_dimension
,
device
=
device
)
return
rht_matrix
.
to
(
dtype
=
torch
.
bfloat16
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
get_random_sign_mask_for_rht
(
with_random_sign_mask
:
bool
)
->
int
:
def
get_random_sign_mask_for_rht
(
with_random_sign_mask
:
bool
,
device
:
int
)
->
int
:
"""Sign mask for random Hadamard transform."""
if
with_random_sign_mask
:
return
get_sign_from_vector
(
get_wgrad_sign_vector
())
return
get_sign_from_vector
(
get_wgrad_sign_vector
(
device
=
device
))
return
0
...
...
@@ -152,8 +152,10 @@ class NVFP4Quantizer(Quantizer):
self
.
amax_reduction_group
=
amax_reduction_group
self
.
with_2d_quantization
=
with_2d_quantization
self
.
stochastic_rounding
=
stochastic_rounding
self
.
rht_matrix_random_sign_mask_t
=
get_random_sign_mask_for_rht
(
with_random_sign_mask
)
self
.
rht_matrix
=
get_rht_matrix
(
with_random_sign_mask
)
self
.
rht_matrix_random_sign_mask_t
=
get_random_sign_mask_for_rht
(
with_random_sign_mask
,
torch
.
cuda
.
current_device
()
)
self
.
rht_matrix
=
get_rht_matrix
(
with_random_sign_mask
,
torch
.
cuda
.
current_device
())
def
update_quantized
(
self
,
...
...
@@ -176,6 +178,27 @@ class NVFP4Quantizer(Quantizer):
return
dst
def
copy
(
self
)
->
NVFP4Quantizer
:
"""Create shallow copy"""
quantizer
=
NVFP4Quantizer
(
fp4_dtype
=
self
.
dtype
,
rowwise
=
self
.
rowwise_usage
,
columnwise
=
self
.
columnwise_usage
,
with_amax_reduction
=
self
.
with_amax_reduction
,
amax_reduction_group
=
self
.
amax_reduction_group
,
with_rht
=
self
.
with_rht
,
with_post_rht_amax
=
self
.
with_post_rht_amax
,
with_2d_quantization
=
self
.
with_2d_quantization
,
stochastic_rounding
=
self
.
stochastic_rounding
,
)
quantizer
.
internal
=
self
.
internal
quantizer
.
optimize_for_gemm
=
self
.
optimize_for_gemm
quantizer
.
rht_matrix
=
self
.
rht_matrix
quantizer
.
rht_matrix_random_sign_mask_t
=
self
.
rht_matrix_random_sign_mask_t
return
quantizer
def
quantize_impl
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
"""Quantize tensor implementation"""
return
tex
.
quantize
(
tensor
,
self
)
...
...
@@ -337,6 +360,7 @@ class NVFP4Quantizer(Quantizer):
fp4_dtype
=
self
.
dtype
,
quantizer
=
self
,
requires_grad
=
requires_grad
,
with_gemm_swizzled_scales
=
False
,
)
def
calibrate
(
self
,
tensor
:
torch
.
Tensor
)
->
None
:
...
...
@@ -360,26 +384,26 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
Parameters
----------
rowwise_data: torch.Tensor
rowwise_data
: torch.Tensor
Raw FP4 data in a uint8 tensor (rowwise layout).
rowwise_scale_inv: torch.Tensor
rowwise_scale_inv
: torch.Tensor
Reciprocal of the scaling factor applied when
casting to FP4, i.e. the scaling factor that must
be applied when casting from FP4 to higher
precision (rowwise).
columnwise_data: torch.Tensor, optional
columnwise_data
: torch.Tensor, optional
Raw FP4 data in a uint8 tensor (columnwise layout).
columnwise_scale_inv: torch.Tensor, optional
columnwise_scale_inv
: torch.Tensor, optional
Reciprocal of the scaling factor for columnwise FP4 data.
amax_rowwise: torch.Tensor, optional
amax_rowwise
: torch.Tensor, optional
Rowwise amax tracking tensor.
amax_columnwise: torch.Tensor, optional
amax_columnwise
: torch.Tensor, optional
Columnwise amax tracking tensor.
fp4_dtype: TE_DType
fp4_dtype
: TE_DType
The FP4 data type used for quantization.
quantizer: Quantizer
quantizer
: Quantizer
The quantizer instance used for this tensor.
dtype: torch.dtype, default = torch.float32
dtype
: torch.dtype, default = torch.float32
Nominal tensor datatype, used in dequantize.
"""
...
...
@@ -396,6 +420,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise
:
Optional
[
torch
.
Tensor
],
fp4_dtype
:
TE_DType
,
quantizer
:
Quantizer
,
with_gemm_swizzled_scales
:
bool
,
**
kwargs
,
):
instance
=
super
().
__new__
(
...
...
@@ -408,6 +433,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise
,
fp4_dtype
,
quantizer
,
with_gemm_swizzled_scales
,
*
args
,
**
kwargs
,
)
...
...
@@ -570,6 +596,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise
=
amax_columnwise
,
quantizer
=
tensor
.
_quantizer
,
requires_grad
=
tensor
.
requires_grad
,
with_gemm_swizzled_scales
=
tensor
.
_with_gemm_swizzled_scales
,
)
# Default case
...
...
@@ -588,6 +615,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
fp4_dtype
:
TE_DType
,
dtype
:
torch
.
dtype
,
quantizer
:
Quantizer
,
with_gemm_swizzled_scales
:
bool
=
False
,
)
->
NVFP4Tensor
:
"""Build NVFP4Tensor, for use in __reduce__
...
...
@@ -607,6 +635,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
amax_columnwise
=
amax_columnwise
,
quantizer
=
quantizer
,
requires_grad
=
False
,
with_gemm_swizzled_scales
=
with_gemm_swizzled_scales
,
)
def
__reduce_ex__
(
self
,
protocol
:
int
)
->
tuple
:
...
...
@@ -624,6 +653,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
self
.
_fp4_dtype
,
self
.
dtype
,
self
.
_quantizer
,
self
.
_with_gemm_swizzled_scales
,
),
)
...
...
@@ -674,6 +704,7 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
self
.
_columnwise_scale_inv
=
tensor
.
_columnwise_scale_inv
self
.
_amax_rowwise
=
tensor
.
_amax_rowwise
self
.
_amax_columnwise
=
tensor
.
_amax_columnwise
self
.
_with_gemm_swizzled_scales
=
tensor
.
_with_gemm_swizzled_scales
return
# Quantize to FP8
...
...
@@ -760,6 +791,7 @@ class _ViewFunc(torch.autograd.Function):
quantizer
=
tensor
.
_quantizer
,
fp4_dtype
=
tensor
.
_fp4_dtype
,
requires_grad
=
tensor
.
requires_grad
,
with_gemm_swizzled_scales
=
tensor
.
_with_gemm_swizzled_scales
,
)
@
staticmethod
...
...
@@ -801,6 +833,7 @@ class _ViewFunc(torch.autograd.Function):
quantizer
=
grad
.
_quantizer
,
fp4_dtype
=
grad
.
_fp4_dtype
,
requires_grad
=
grad
.
requires_grad
,
with_gemm_swizzled_scales
=
grad
.
_with_gemm_swizzled_scales
,
)
return
dgrad
,
None
return
grad
.
view
(
ctx
.
shape
),
None
...
...
@@ -880,6 +913,7 @@ class _ReshapeFunc(torch.autograd.Function):
quantizer
=
tensor
.
_quantizer
,
fp4_dtype
=
tensor
.
_fp4_dtype
,
requires_grad
=
tensor
.
requires_grad
,
with_gemm_swizzled_scales
=
tensor
.
_with_gemm_swizzled_scales
,
)
@
staticmethod
...
...
@@ -921,6 +955,7 @@ class _ReshapeFunc(torch.autograd.Function):
quantizer
=
grad
.
_quantizer
,
fp4_dtype
=
grad
.
_fp4_dtype
,
requires_grad
=
grad
.
requires_grad
,
with_gemm_swizzled_scales
=
grad
.
_with_gemm_swizzled_scales
,
)
return
dgrad
,
None
return
grad
.
view
(
ctx
.
shape
),
None
transformer_engine/pytorch/tensor/storage/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Storage for quantized tensors."""
...
...
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -11,7 +11,6 @@ import torch
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
DType
as
TE_DType
from
transformer_engine_torch
import
Float8BlockScaleTensorFormat
from
transformer_engine.pytorch.fp8
import
blockwise_fp8_block_len
from
...quantized_tensor
import
QuantizedTensorStorage
,
Quantizer
...
...
@@ -37,7 +36,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
_rowwise_scale_inv
:
Optional
[
torch
.
Tensor
]
_columnwise_scale_inv
:
Optional
[
torch
.
Tensor
]
_is_2D_scaled
:
bool
_data_format
:
Float8BlockScaleTensorFormat
def
__new__
(
cls
,
...
...
@@ -48,7 +46,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
fp8_dtype
:
TE_DType
,
quantizer
:
Quantizer
,
is_2D_scaled
:
bool
,
data_format
:
Float8BlockScaleTensorFormat
,
*
args
,
**
kwargs
,
):
...
...
@@ -63,7 +60,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
instance
.
_rowwise_scale_inv
=
rowwise_scale_inv
instance
.
_columnwise_scale_inv
=
columnwise_scale_inv
instance
.
_is_2D_scaled
=
is_2D_scaled
instance
.
_data_format
=
data_format
return
instance
...
...
@@ -88,13 +84,8 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
"fp8_dtype"
:
self
.
_fp8_dtype
,
"quantizer"
:
self
.
_quantizer
,
"is_2D_scaled"
:
self
.
_is_2D_scaled
,
"data_format"
:
self
.
_data_format
,
}
def
_is_gemm_ready_format
(
self
)
->
bool
:
"""Whether data is in GEMM_READY format"""
return
self
.
_data_format
==
Float8BlockScaleTensorFormat
.
GEMM_READY
def
prepare_for_saving
(
self
,
)
->
Tuple
[
list
[
Optional
[
torch
.
Tensor
]],
Float8BlockwiseQTensorStorage
]:
...
...
@@ -154,36 +145,18 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
for
i
in
range
(
len
(
q
.
shape
)
-
1
):
q_M
*=
q
.
shape
[
i
]
inner_q_dimension_tiled
=
True
if
self
.
_is_gemm_ready_format
():
scales_tiled_dim
,
scales_untiled_dim
=
scale_inv
.
shape
inner_scale_dimension_tiled
=
False
scales_are_compact
=
False
else
:
scales_untiled_dim
,
scales_tiled_dim
=
scale_inv
.
shape
inner_scale_dimension_tiled
=
True
scales_are_compact
=
True
scales_tiled_dim
,
scales_untiled_dim
=
scale_inv
.
shape
else
:
assert
self
.
_columnwise_data
is
not
None
,
"No data to dequantize"
q
=
self
.
_columnwise_data
scale_inv
=
self
.
_columnwise_scale_inv
scales_tiled_dim
,
scales_untiled_dim
=
scale_inv
.
shape
inner_scale_dimension_tiled
=
False
if
self
.
_is_gemm_ready_format
():
inner_q_dimension_tiled
=
True
transpose_output
=
True
if
len
(
q
.
shape
)
>=
1
:
q_M
=
q
.
shape
[
0
]
for
i
in
range
(
1
,
len
(
q
.
shape
)):
q_K
*=
q
.
shape
[
i
]
scales_are_compact
=
False
else
:
inner_q_dimension_tiled
=
False
transpose_output
=
False
if
len
(
q
.
shape
)
>=
1
:
q_K
=
q
.
shape
[
-
1
]
for
i
in
range
(
len
(
q
.
shape
)
-
1
):
q_M
*=
q
.
shape
[
i
]
scales_are_compact
=
True
inner_q_dimension_tiled
=
True
transpose_output
=
True
if
len
(
q
.
shape
)
>=
1
:
q_M
=
q
.
shape
[
0
]
for
i
in
range
(
1
,
len
(
q
.
shape
)):
q_K
*=
q
.
shape
[
i
]
orig_shape
=
q
.
shape
q
=
q
.
reshape
(
q_M
,
q_K
)
...
...
@@ -203,15 +176,10 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
).
contiguous
()
padded_M
,
padded_K
=
q
.
shape
q_tiled
=
q
.
reshape
(
scales_tiled_dim
,
block_len
,
q_K
)
if
not
scales_are_compact
and
scales_untiled_dim
>
q_M
:
if
scales_untiled_dim
>
q_M
:
# untiled scale dimension is 4 element aligned.
scale_inv
=
scale_inv
[:,
:
q_M
].
contiguous
()
if
scales_are_compact
and
inner_scale_dimension_tiled
:
dq_scale
=
scale_inv
.
contiguous
().
reshape
(
q_M
,
scales_tiled_dim
,
1
)
elif
scales_are_compact
and
not
inner_scale_dimension_tiled
:
dq_scale
=
scale_inv
.
contiguous
().
reshape
(
scales_tiled_dim
,
1
,
q_K
)
else
:
dq_scale
=
scale_inv
.
transpose
(
-
2
,
-
1
).
contiguous
().
reshape
(
q_M
,
scales_tiled_dim
,
1
)
dq_scale
=
scale_inv
.
transpose
(
-
2
,
-
1
).
contiguous
().
reshape
(
q_M
,
scales_tiled_dim
,
1
)
torch_q_dtype
=
TE_DType_To_Torch
[
self
.
_fp8_dtype
]
result
=
q_tiled
.
view
(
torch_q_dtype
).
to
(
torch
.
float32
)
*
dq_scale
if
padded_M
!=
q_M
or
padded_K
!=
q_K
:
...
...
@@ -234,12 +202,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
if
not
self
.
_is_2D_scaled
:
return
self
.
_dequantize_vectorwise
(
dtype
=
dtype
)
if
not
self
.
_is_gemm_ready_format
():
raise
NotImplementedError
(
"Dequantize is only supported with GEMM_READY data format, "
f
"but found _data_format=
{
self
.
_data_format
}
"
)
def
format_scale_as_logical_shape
(
q_K
,
scales
,
block_len
):
# The GEMM for 2D blocks required padding in the scales.
derived_scale_k_shape
=
math
.
ceil
(
q_K
/
block_len
)
...
...
@@ -305,8 +267,6 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
if
self
.
_rowwise_data
is
not
None
:
return
self
.
_rowwise_data
.
size
(
*
args
,
**
kwargs
)
dims
=
list
(
self
.
_columnwise_data
.
size
(
*
args
,
**
kwargs
))
if
not
self
.
_is_gemm_ready_format
():
# compact format
return
torch
.
Size
(
dims
)
reordered
=
[]
for
i
in
range
(
1
,
len
(
dims
)):
reordered
.
append
(
dims
[
i
])
...
...
@@ -367,7 +327,7 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
return
(
"Float8BlockwiseQTensorStorage("
f
"fp8_dtype=
{
self
.
_fp8_dtype
}
, "
f
"
{
descriptor
}
_scaled_data=
{
data
}
"
f
"
{
descriptor
}
_scaled_data=
{
data
}
)
"
)
def
update_usage
(
...
...
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -57,13 +57,23 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
"""
# Row-scaled FP8 data
_rowwise_data
:
Optional
[
torch
.
Tensor
]
# Column-scaled FP8 data
_columnwise_data
:
Optional
[
torch
.
Tensor
]
_quantizer
:
Optional
[
Quantizer
]
_fp8_dtype
:
TE_DType
# Scaling factors for row-scaled FP8 data
_rowwise_scale_inv
:
torch
.
Tensor
# Scaling factors for column-scaled FP8 data
_columnwise_scale_inv
:
torch
.
Tensor
# Builder class for casting to MXFP8
_quantizer
:
Optional
[
Quantizer
]
# FP8 data type
_fp8_dtype
:
TE_DType
# Whether scaling factors are in the swizzled format expected by
# GEMM
_with_gemm_swizzled_scales
:
bool
def
__new__
(
cls
,
rowwise_data
:
Optional
[
torch
.
Tensor
],
...
...
@@ -72,6 +82,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
columnwise_scale_inv
:
Optional
[
torch
.
Tensor
],
fp8_dtype
:
TE_DType
,
quantizer
:
Optional
[
Quantizer
],
with_gemm_swizzled_scales
:
bool
,
*
args
,
**
kwargs
,
):
...
...
@@ -81,10 +92,11 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
instance
=
super
().
__new__
(
cls
,
*
args
,
**
kwargs
)
instance
.
_rowwise_data
=
rowwise_data
instance
.
_columnwise_data
=
columnwise_data
instance
.
_quantizer
=
quantizer
.
copy
()
if
quantizer
is
not
None
else
None
instance
.
_fp8_dtype
=
fp8_dtype
instance
.
_rowwise_scale_inv
=
rowwise_scale_inv
instance
.
_columnwise_scale_inv
=
columnwise_scale_inv
instance
.
_quantizer
=
quantizer
.
copy
()
if
quantizer
is
not
None
else
None
instance
.
_fp8_dtype
=
fp8_dtype
instance
.
_with_gemm_swizzled_scales
=
with_gemm_swizzled_scales
return
instance
...
...
@@ -108,6 +120,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
"columnwise_scale_inv"
:
self
.
_columnwise_scale_inv
,
"fp8_dtype"
:
self
.
_fp8_dtype
,
"quantizer"
:
self
.
_quantizer
,
"with_gemm_swizzled_scales"
:
self
.
_with_gemm_swizzled_scales
,
}
def
prepare_for_saving
(
self
)
->
Tuple
[
list
[
Optional
[
torch
.
Tensor
]],
MXFP8TensorStorage
]:
...
...
@@ -197,6 +210,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
columnwise_scale_inv
=
self
.
_columnwise_scale_inv
,
fp8_dtype
=
self
.
_fp8_dtype
,
quantizer
=
self
.
_quantizer
,
with_gemm_swizzled_scales
=
self
.
_with_gemm_swizzled_scales
,
)
def
__repr__
(
self
):
...
...
@@ -255,7 +269,7 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
self
.
_columnwise_data
=
None
self
.
_columnwise_scale_inv
=
None
def
get_usages
(
self
)
->
Tuple
[
bool
,
bool
]:
def
get_usages
(
self
)
->
Dict
[
str
,
bool
]:
"""Get the usage of the tensor"""
return
{
"rowwise"
:
self
.
_rowwise_data
is
not
None
,
...
...
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -71,15 +71,29 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
"""
# Row-scaled FP4 data
_rowwise_data
:
Optional
[
torch
.
Tensor
]
# Column-scaled FP4 data
_columnwise_data
:
Optional
[
torch
.
Tensor
]
_quantizer
:
Optional
[
Quantizer
]
# Block scaling factors for row-scaled FP4 data
_rowwise_scale_inv
:
torch
.
Tensor
# Block scaling factors for column-scaled FP4 data
_columnwise_scale_inv
:
torch
.
Tensor
_fp4_dtype
:
TE_DType
# Input absolute maximum value (used to compute tensor scale for
# row-scaled FP4 data)
_amax_rowwise
:
torch
.
Tensor
# Input absolute maximum value (used to compute tensor scale for
# column-scaled FP4 data)
_amax_columnwise
:
torch
.
Tensor
# Builder class for casting to MXFP8
_quantizer
:
Optional
[
Quantizer
]
# FP4 data type
_fp4_dtype
:
TE_DType
# Whether scaling factors are in the swizzled format expected by
# GEMM
_with_gemm_swizzled_scales
:
bool
def
__new__
(
cls
,
rowwise_data
:
Optional
[
torch
.
Tensor
],
...
...
@@ -90,6 +104,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
amax_columnwise
:
torch
.
Tensor
,
fp4_dtype
:
TE_DType
,
quantizer
:
Optional
[
Quantizer
],
with_gemm_swizzled_scales
:
bool
,
*
args
,
**
kwargs
,
):
...
...
@@ -104,6 +119,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
instance
.
_columnwise_scale_inv
=
columnwise_scale_inv
instance
.
_amax_rowwise
=
amax_rowwise
instance
.
_amax_columnwise
=
amax_columnwise
instance
.
_with_gemm_swizzled_scales
=
with_gemm_swizzled_scales
return
instance
...
...
@@ -131,6 +147,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
"amax_columnwise"
:
self
.
_amax_columnwise
,
"fp4_dtype"
:
self
.
_fp4_dtype
,
"quantizer"
:
self
.
_quantizer
,
"with_gemm_swizzled_scales"
:
self
.
_with_gemm_swizzled_scales
,
}
def
prepare_for_saving
(
self
)
->
Tuple
[
list
[
Optional
[
torch
.
Tensor
]],
NVFP4TensorStorage
]:
...
...
@@ -248,6 +265,7 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
amax_columnwise
=
self
.
_amax_columnwise
,
quantizer
=
self
.
_quantizer
,
fp4_dtype
=
self
.
_fp4_dtype
,
with_gemm_swizzled_scales
=
self
.
_with_gemm_swizzled_scales
,
)
def
__repr__
(
self
):
...
...
transformer_engine/pytorch/tensor/utils.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -8,7 +8,11 @@ from typing import Optional, Union, List
import
torch
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
multi_tensor_scale
,
multi_tensor_compute_scale_and_scale_inv
from
transformer_engine_torch
import
(
multi_tensor_scale
,
multi_tensor_compute_scale_and_scale_inv
,
multi_tensor_compute_scale_inv_e8m0
,
)
from
..quantized_tensor
import
QuantizedTensor
,
Quantizer
,
QuantizedTensorStorage
from
.float8_tensor
import
Float8Tensor
,
Float8Quantizer
,
Float8CurrentScalingQuantizer
...
...
@@ -74,7 +78,7 @@ def cast_master_weights_to_fp8(
fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are
not sharded. Otherwise, it means that the model weights are sharded and we get
target model weights data storage using the FSDP shard model weights.
manual_post_all_gather_processing: bool, default = `False`.
manual_post_all_gather_processing
: bool, default = `False`.
If False, post processing will be automatically triggered during next forward.
If True, the timing of calling post_all_gather_processing is left to the user.
Note that users must call `post_all_gather_processing` if it's set to True,
...
...
@@ -85,6 +89,7 @@ def cast_master_weights_to_fp8(
delayed_scaling_params
=
[]
current_scaling_params
=
[]
blockwise_scaling_params
=
[]
mxfp8_scaling_params
=
[]
if
fsdp_shard_model_weights
is
None
:
use_fsdp_shard_model_weights
=
False
...
...
@@ -131,8 +136,8 @@ def cast_master_weights_to_fp8(
(
model_weight
,
master_weight
,
start_offset
,
fsdp_shard_model_weight
)
)
elif
isinstance
(
quantizer
,
MXFP8Quantizer
):
raise
NotImplementedError
(
"cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet"
mxfp8_scaling_params
.
append
(
(
model_weight
,
master_weight
,
start_offset
,
fsdp_shard_model_weight
)
)
else
:
raise
ValueError
(
...
...
@@ -146,6 +151,8 @@ def cast_master_weights_to_fp8(
_cast_master_weights_to_fp8_current_scaling
(
current_scaling_params
,
*
extra_args
)
if
len
(
blockwise_scaling_params
)
>
0
:
_cast_master_weights_to_fp8_blockwise_scaling
(
blockwise_scaling_params
,
*
extra_args
)
if
len
(
mxfp8_scaling_params
)
>
0
:
_cast_master_weights_to_fp8_mxfp8_scaling
(
mxfp8_scaling_params
,
*
extra_args
)
def
_cast_master_weights_to_fp8_delayed_scaling
(
...
...
@@ -471,6 +478,131 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
)
def _cast_master_weights_to_fp8_mxfp8_scaling(
    params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
):
    # pylint: disable=unused-argument
    r"""Helper function to cast master weights to FP8 primary weights for mxfp8 scaling.

    Parameters
    ----------
    params : List of tuple, each tuple contains a model weight, a master weight, an offset
             indicating the starting index of the master weight in the model weight, and an
             optional FSDP shard model weight fragment.
    group : The distributed group to do amax reduction. Typically it's the data parallel
            group.
    use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded
                                   and the (rowwise, colwise) data fragments are taken from the
                                   fourth element of each params tuple.
    manual_post_all_gather_processing : bool, unused here; MXFP8 weights need no post
                                        all-gather processing (kept for a uniform signature
                                        with the other scaling helpers).
    """

    # Parameter attributes.
    device = params[0][0].device
    # Fall back to float32 when this rank owns no master-weight shard (all master weights
    # are None under ZeRO/FSDP sharding). Such ranks still must allocate packed_amaxes and
    # participate in the amax all-reduce below, so leaving the dtype unbound would raise
    # an UnboundLocalError.
    master_weight_dtype = torch.float32
    for _, master_weight, _, _ in params:
        if master_weight is not None:
            master_weight_dtype = master_weight.dtype
            break

    # Get the total number of amax elements in all the model weights.
    # cu_* hold cumulative (prefix-sum) element counts so each weight's amax slice can be
    # located inside one packed buffer.
    cu_rowwise_amax_sizes = [0]
    cu_colwise_amax_sizes = [0]
    for model_weight, _, _, _ in params:
        rowwise_shape = model_weight._rowwise_scale_inv.shape
        assert len(rowwise_shape) == 2
        colwise_shape = model_weight._columnwise_scale_inv.shape
        assert len(colwise_shape) == 2
        cu_rowwise_amax_sizes.append(
            cu_rowwise_amax_sizes[-1] + rowwise_shape[0] * rowwise_shape[1]
        )
        cu_colwise_amax_sizes.append(
            cu_colwise_amax_sizes[-1] + colwise_shape[0] * colwise_shape[1]
        )

    # Create a contiguous buffer to store amaxes temporarily, so we can perform all all-reduce
    # NCCL kernels at once. Layout: all rowwise amaxes first, then all colwise amaxes.
    packed_amaxes = torch.zeros(
        cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[-1],
        dtype=master_weight_dtype,
        device=device,
    )

    # ---------------------------------------------------------------------------------------------
    # Step 1: Iterate through all the none empty master weights and compute amax of them. Store the
    #         amaxes in a contiguous buffer. If a block of a master weight is empty, the
    #         corresponding amax will be set to 0.
    # ---------------------------------------------------------------------------------------------
    amaxes_rowwise, scale_invs_rowwise = [], []
    amaxes_colwise, scale_invs_colwise = [], []
    for i, (model_weight, master_weight, start_offset, _) in enumerate(params):
        rowwise_shape = model_weight._rowwise_scale_inv.shape
        colwise_shape = model_weight._columnwise_scale_inv.shape
        rowwise_start = cu_rowwise_amax_sizes[i]
        rowwise_end = cu_rowwise_amax_sizes[i + 1]
        # Colwise slices live after the whole rowwise region, hence the extra offset.
        colwise_start = cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[i]
        colwise_end = cu_rowwise_amax_sizes[-1] + cu_colwise_amax_sizes[i + 1]
        amax_rowwise = packed_amaxes[rowwise_start:rowwise_end].reshape(rowwise_shape)
        amax_colwise = packed_amaxes[colwise_start:colwise_end].reshape(colwise_shape)
        amaxes_rowwise.append(amax_rowwise)
        amaxes_colwise.append(amax_colwise)
        scale_invs_rowwise.append(model_weight._rowwise_scale_inv)
        scale_invs_colwise.append(model_weight._columnwise_scale_inv)
        # Compute amax of the master weight and store it in packed_amaxes.
        if master_weight is not None:
            assert len(model_weight.shape) == 2
            h, w = model_weight.shape
            tex.mxfp8_scaling_compute_partial_amax(
                master_weight, amax_rowwise, amax_colwise, h, w, start_offset
            )

    # ---------------------------------------------------------------------------------------------
    # Step 2: Perform all-reduce on packed_amaxes to get the global amax.
    # ---------------------------------------------------------------------------------------------
    torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group)

    # ---------------------------------------------------------------------------------------------
    # Step 3: Update scales and scale_invs.
    # ---------------------------------------------------------------------------------------------
    multi_tensor_applier(
        multi_tensor_compute_scale_inv_e8m0,
        None,  # dummy_overflow_buf
        [
            amaxes_rowwise + amaxes_colwise,
            scale_invs_rowwise + scale_invs_colwise,
        ],
    )

    # ---------------------------------------------------------------------------------------------
    # Step 4: Cast master weights to FP8.
    # ---------------------------------------------------------------------------------------------
    for (
        (model_weight, master_weight, start_offset, model_weight_fragment),
        scale_inv_rowwise,
        scale_inv_colwise,
    ) in zip(params, scale_invs_rowwise, scale_invs_colwise):
        # If master weight is None, it means that the master weight of the current model weight
        # is in other DP ranks.
        if master_weight is None:
            continue

        # Cast master weight to FP8
        end_offset = start_offset + master_weight.numel()
        if use_fsdp_shard_model_weights:
            rowwise_fragment = model_weight_fragment[0]
            colwise_fragment = model_weight_fragment[1]
        else:
            rowwise_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset]
            colwise_fragment = model_weight._columnwise_data.reshape(-1)[start_offset:end_offset]
        assert len(model_weight.shape) == 2
        h, w = model_weight.shape
        tex.mxfp8_scaling_partial_cast(
            master_weight,
            rowwise_fragment,
            colwise_fragment,
            scale_inv_rowwise,
            scale_inv_colwise,
            h,
            w,
            start_offset,
        )
def
post_all_gather_processing
(
model_weights
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]):
"""
Post-processing after all-gather for weights in distributed optimizer.
...
...
@@ -489,6 +621,9 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten
elif
isinstance
(
model_weight
,
Float8BlockwiseQTensor
):
# Blockwise scaling: create column-wise storage.
model_weight
.
_create_columnwise
()
elif
isinstance
(
model_weight
,
MXFP8Tensor
):
# MXFP8 scaling: no need to do anything.
pass
elif
isinstance
(
model_weight
,
QuantizedTensor
):
raise
ValueError
(
f
"post_processing for
{
type
(
model_weight
)
}
is not supported"
)
...
...
transformer_engine/pytorch/torch_version.py
0 → 100644
View file @
0d874a4e
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""PyTorch version utilities"""
from
__future__
import
annotations
import
functools
import
torch
from
packaging.version
import
Version
as
PkgVersion
@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
    """Return the installed PyTorch release as a tuple of ints, e.g. ``(2, 4, 1)``.

    The result is memoized since the interpreter's torch version never changes
    within a process.
    """
    raw_version = str(torch.__version__)
    return PkgVersion(raw_version).release
Prev
1
…
28
29
30
31
32
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment