Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
c1a1c04e
Commit
c1a1c04e
authored
Dec 27, 2025
by
wenjh
Browse files
Merge nv_main(2.10) to main
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parents
e698a0a7
66aed3ae
Changes
208
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1027 additions
and
267 deletions
+1027
-267
transformer_engine/pytorch/module/layernorm_linear.py
transformer_engine/pytorch/module/layernorm_linear.py
+36
-49
transformer_engine/pytorch/module/layernorm_mlp.py
transformer_engine/pytorch/module/layernorm_mlp.py
+83
-28
transformer_engine/pytorch/module/linear.py
transformer_engine/pytorch/module/linear.py
+24
-23
transformer_engine/pytorch/ops/_common.py
transformer_engine/pytorch/ops/_common.py
+1
-1
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
...r_engine/pytorch/ops/fused/userbuffers_backward_linear.py
+1
-1
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
...er_engine/pytorch/ops/fused/userbuffers_forward_linear.py
+1
-1
transformer_engine/pytorch/ops/fuser.py
transformer_engine/pytorch/ops/fuser.py
+1
-1
transformer_engine/pytorch/optimizers/fused_adam.py
transformer_engine/pytorch/optimizers/fused_adam.py
+2
-2
transformer_engine/pytorch/permutation.py
transformer_engine/pytorch/permutation.py
+1
-1
transformer_engine/pytorch/quantized_tensor.py
transformer_engine/pytorch/quantized_tensor.py
+105
-77
transformer_engine/pytorch/setup.py
transformer_engine/pytorch/setup.py
+12
-2
transformer_engine/pytorch/tensor/__init__.py
transformer_engine/pytorch/tensor/__init__.py
+1
-1
transformer_engine/pytorch/tensor/_quantization_helpers.py
transformer_engine/pytorch/tensor/_quantization_helpers.py
+84
-0
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+10
-7
transformer_engine/pytorch/tensor/float8_tensor.py
transformer_engine/pytorch/tensor/float8_tensor.py
+243
-28
transformer_engine/pytorch/tensor/mxfp8_tensor.py
transformer_engine/pytorch/tensor/mxfp8_tensor.py
+355
-19
transformer_engine/pytorch/tensor/nvfp4_tensor.py
transformer_engine/pytorch/tensor/nvfp4_tensor.py
+41
-17
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py
...pytorch/tensor/storage/float8_blockwise_tensor_storage.py
+8
-3
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
...er_engine/pytorch/tensor/storage/float8_tensor_storage.py
+10
-3
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
...mer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
+8
-3
No files found.
transformer_engine/pytorch/module/layernorm_linear.py
View file @
c1a1c04e
...
...
@@ -16,7 +16,7 @@ import transformer_engine_torch as tex
from
transformer_engine.common.recipe
import
Recipe
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch.tensor.utils
import
is_
experimental
from
transformer_engine.pytorch.tensor.utils
import
is_
custom
from
.base
import
(
fill_userbuffers_buffer_for_all_gather
,
get_workspace
,
...
...
@@ -57,7 +57,7 @@ from ..constants import GemmParallelModes, dist_group_type
from
..jit
import
no_torch_dynamo
from
..graph
import
is_graph_capturing
from
._common
import
apply_normalization
,
noop_cat
,
WeightGradStore
from
..
tensor.
quantized_tensor
import
(
from
..quantized_tensor
import
(
QuantizedTensor
,
QuantizedTensorStorage
,
Quantizer
,
...
...
@@ -67,10 +67,15 @@ from ..tensor.quantized_tensor import (
from
...debug.pytorch.debug_state
import
TEDebugState
from
..tensor.float8_blockwise_tensor
import
Float8BlockQuantizer
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
..cpu_offload
import
(
is_cpu_offload_enabled
,
start_offload
,
mark_not_offload
,
mark_activation_offload
,
)
from
..tensor.storage.float8_blockwise_tensor_storage
import
Float8BlockwiseQTensorStorage
from
..tensor.storage.mxfp8_tensor_storage
import
MXFP8TensorStorage
from
..export
import
is_in_onnx_export_mode
,
assert_warmed_up
from
..cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
..cpp_extensions
import
(
general_gemm
,
...
...
@@ -167,6 +172,9 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias
=
cast_if_needed
(
ln_bias
,
activation_dtype
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.norm_input_cast"
)
if
is_cpu_offload_enabled
():
start_offload
(
inputmat
)
tp_world_size
=
get_distributed_world_size
(
tp_group
)
weight_requires_grad
=
weight
.
requires_grad
...
...
@@ -203,13 +211,13 @@ class _LayerNormLinear(torch.autograd.Function):
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
experimental
=
is_experimental
(
input_quantizer
)
custom
=
is_custom
(
input_quantizer
)
with_quantized_norm
=
(
fp8
and
not
debug
and
not
return_layernorm_output
and
not
return_layernorm_output_gathered
and
not
experimental
# TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
and
not
custom
# TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
)
# Apply normalization
...
...
@@ -255,8 +263,8 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer
=
None
if
fp8
or
debug
:
quantizer
=
input_quantizer
#
experimental
recipe doesn't need to support quantized AG
if
not
with_quantized_norm
and
not
experimental
:
#
custom
recipe doesn't need to support quantized AG
if
not
with_quantized_norm
and
not
custom
:
ln_out
=
quantizer
(
ln_out
)
quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
if
ub_overlap_ag_fprop
:
# Initialize Userbuffers all-gather
...
...
@@ -285,12 +293,15 @@ class _LayerNormLinear(torch.autograd.Function):
# Prepare weight tensor
# ------------------------------------------------------
weightmat
=
weight
quantized
_weight
=
False
is_weight_param_
quantized
=
False
if
fp8
or
debug
:
quantized_weight
=
not
isinstance
(
weight
,
QuantizedTensorStorage
)
is_weight_param_quantized
=
isinstance
(
weight
,
QuantizedTensorStorage
)
# Configure quantizer
if
weight_quantizer
is
not
None
:
# If weight is already quantized, no need to set quantizer states
if
is_weight_param_quantized
:
weight_quantizer
=
weight
.
_quantizer
elif
weight_quantizer
is
not
None
:
weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
is_grad_enabled
)
# Get quantized weight
...
...
@@ -422,10 +433,6 @@ class _LayerNormLinear(torch.autograd.Function):
):
ln_out
.
update_usage
(
rowwise_usage
=
False
)
# Weight with column-wise usage is needed for dgrad GEMM.
if
isinstance
(
weightmat
,
QuantizedTensorStorage
):
weightmat
.
update_usage
(
columnwise_usage
=
True
)
if
cpu_offloading
:
mark_activation_offload
(
inputmat
,
mu
,
rsigma
,
ln_out
)
...
...
@@ -438,42 +445,21 @@ class _LayerNormLinear(torch.autograd.Function):
fsdp_group
,
mu
,
rsigma
,
weightmat
if
quantized
_weight
else
None
,
weightmat
if
fp8
and
not
is_weight_param_
quantized
else
None
,
ln_out
if
weight
.
requires_grad
else
None
,
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_scatter"
)
# Do not offload weights and biases
weight
.
offloading_activation
=
False
weightmat
.
offloading_activation
=
False
if
bias
is
not
None
:
bias
.
offloading_activation
=
False
ln_weight
.
offloading_activation
=
False
ctx
.
fine_grained_activation_offloading
=
fine_grained_activation_offloading
if
fine_grained_activation_offloading
and
cpu_offloading
:
raise
ValueError
(
f
"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
if
cpu_offloading
:
mark_not_offload
(
weightmat
,
weight
,
bias
,
ln_weight
,
ln_bias
,
)
if
(
fine_grained_activation_offloading
and
weight
.
requires_grad
and
fuse_wgrad_accumulation
):
if
hasattr
(
weight
,
"grad_added_to_main_grad"
):
ctx
.
has_grad_added_to_main_grad
=
True
ctx
.
grad_added_to_main_grad
=
weight
.
grad_added_to_main_grad
weight
.
grad_added_to_main_grad
=
True
ctx
.
weight_object
=
weight
else
:
ctx
.
has_grad_added_to_main_grad
=
False
if
cpu_offloading
or
int
(
os
.
getenv
(
"NVTE_SWAP_OVERLAP_GRAD"
,
"0"
)):
ctx
.
has_grad_added_to_main_grad
=
hasattr
(
weight
,
"grad_added_to_main_grad"
)
if
ctx
.
has_grad_added_to_main_grad
:
ctx
.
grad_added_to_main_grad
=
hasattr
(
weight
,
"grad_added_to_main_grad"
)
if
ctx
.
grad_added_to_main_grad
:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to have all the attributes user
...
...
@@ -495,7 +481,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx
.
tensor_objects
=
tensor_objects
ctx
.
requires_dgrad
=
inp_requires_grad
ctx
.
requires_wgrad
=
weight
.
requires_grad
ctx
.
quantized_weight
=
quantized
_weight
ctx
.
is_weight_param_
quantized
=
is
_weight
_param_
quantized
if
fuse_wgrad_accumulation
and
weight
.
requires_grad
:
# This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
...
...
@@ -579,6 +565,7 @@ class _LayerNormLinear(torch.autograd.Function):
mu
,
rsigma
,
)
=
restore_from_saved
(
ctx
.
tensor_objects
,
saved_tensors
)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx
.
tensor_objects
=
None
...
...
@@ -599,7 +586,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx
.
fsdp_shapes
,
mu
,
rsigma
,
weight
if
ctx
.
fp8
and
ctx
.
quantized
_weight
else
None
,
weight
if
ctx
.
fp8
and
not
ctx
.
is_weight_param_
quantized
else
None
,
ln_out
,
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.fsdp_gather"
)
...
...
transformer_engine/pytorch/module/layernorm_mlp.py
View file @
c1a1c04e
...
...
@@ -18,7 +18,7 @@ import transformer_engine_torch as tex
from
transformer_engine.common.recipe
import
Recipe
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch.tensor.utils
import
is_
experimental
from
transformer_engine.pytorch.tensor.utils
import
is_
custom
from
.base
import
(
fill_userbuffers_buffer_for_all_gather
,
get_workspace
,
...
...
@@ -70,8 +70,13 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from
..tensor.nvfp4_tensor
import
NVFP4Quantizer
from
..tensor.float8_blockwise_tensor
import
Float8BlockQuantizer
from
._common
import
apply_normalization
,
WeightGradStore
from
..cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
..tensor.quantized_tensor
import
(
from
..cpu_offload
import
(
is_cpu_offload_enabled
,
start_offload
,
mark_not_offload
,
mark_activation_offload
,
)
from
..quantized_tensor
import
(
QuantizedTensorStorage
,
Quantizer
,
prepare_for_saving
,
...
...
@@ -106,6 +111,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"sreglu"
:
(
tex
.
sreglu
,
tex
.
dsreglu
,
None
),
"silu"
:
(
tex
.
silu
,
tex
.
dsilu
,
None
),
"swiglu"
:
(
tex
.
swiglu
,
tex
.
dswiglu
,
None
),
"clamped_swiglu"
:
(
tex
.
clamped_swiglu
,
tex
.
clamped_dswiglu
,
None
),
}
if
recipe
.
delayed
()
or
recipe
.
mxfp8
():
# Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
...
...
@@ -121,6 +127,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"sreglu"
:
(
tex
.
sreglu
,
tex
.
dsreglu
,
None
),
"silu"
:
(
tex
.
silu
,
tex
.
dsilu
,
tex
.
dbias_dsilu
),
"swiglu"
:
(
tex
.
swiglu
,
tex
.
dswiglu
,
None
),
"clamped_swiglu"
:
(
tex
.
clamped_swiglu
,
tex
.
clamped_dswiglu
,
None
),
}
# no activation fusion written yet
# Per-tensor current scaling or fp8 blockwise scaling or custom quantization: []
...
...
@@ -142,6 +149,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"sreglu"
:
(
tex
.
sreglu
,
tex
.
dsreglu
,
None
),
"silu"
:
(
tex
.
silu
,
tex
.
dsilu
,
None
),
"swiglu"
:
(
tex
.
swiglu
,
tex
.
dswiglu
,
None
),
"clamped_swiglu"
:
(
tex
.
clamped_swiglu
,
tex
.
clamped_dswiglu
,
None
),
}
raise
NotImplementedError
(
f
"Unhandled recipe type
{
recipe
}
"
)
...
...
@@ -206,6 +214,7 @@ class _LayerNormMLP(torch.autograd.Function):
bwd_ln_sm_margin
:
int
,
zero_centered_gamma
:
bool
,
activation
:
str
,
activation_params
:
Optional
[
dict
],
normalization
:
str
,
ub_overlap_ag
:
bool
,
ub_overlap_rs
:
bool
,
...
...
@@ -238,6 +247,8 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight
=
cast_if_needed
(
ln_weight
,
activation_dtype
)
if
ln_bias
is
not
None
:
ln_bias
=
cast_if_needed
(
ln_bias
,
activation_dtype
)
if
is_cpu_offload_enabled
():
start_offload
(
inputmat
)
tp_world_size
=
get_distributed_world_size
(
tp_group
)
backwards_needs_fc1_input
=
is_grad_enabled
and
fc1_weight
.
requires_grad
...
...
@@ -275,13 +286,13 @@ class _LayerNormMLP(torch.autograd.Function):
# high precision layernorm output and output of the linear are returned
# for debug: : layernorm output = High precision to enable processing of this norm
experimental
=
is_experimental
(
fc1_input_quantizer
)
custom
=
is_custom
(
fc1_input_quantizer
)
with_quantized_norm
=
(
fp8
and
not
debug
and
not
return_layernorm_output
and
not
return_layernorm_output_gathered
and
not
experimental
and
not
custom
)
# Apply normalization
...
...
@@ -321,8 +332,8 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer
=
None
if
fp8
or
debug
:
quantizer
=
fc1_input_quantizer
#
experimental
recipe doesn't need to support quantized AG
if
not
with_quantized_norm
and
not
experimental
:
#
custom
recipe doesn't need to support quantized AG
if
not
with_quantized_norm
and
not
custom
:
ln_out
=
fc1_input_quantizer
(
ln_out
)
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
if
ub_overlap_ag
:
...
...
@@ -354,8 +365,17 @@ class _LayerNormMLP(torch.autograd.Function):
# which handles weight caching etc.
# FP8 cast to workspace buffer
update_workspace
=
is_first_microbatch
is
None
or
is_first_microbatch
# No need to set the quantizer states if weights are already quantized
if
isinstance
(
fc1_weight
,
QuantizedTensorStorage
):
fc1_weight_quantizer
=
fc1_weight
.
_quantizer
elif
fc1_weight_quantizer
is
not
None
:
fc1_weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
is_grad_enabled
)
if
isinstance
(
fc2_weight
,
QuantizedTensorStorage
):
fc2_weight_quantizer
=
fc2_weight
.
_quantizer
elif
fc2_weight_quantizer
is
not
None
:
fc2_weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
is_grad_enabled
)
fc1_weight_final
=
module
.
get_weight_workspace
(
tensor
=
fc1_weight
,
quantizer
=
fc1_weight_quantizer
,
...
...
@@ -447,6 +467,7 @@ class _LayerNormMLP(torch.autograd.Function):
# ACTIVATION - sometimes activation is fused with the GEMM above.
fc1_out_without_bias
=
None
act_params
=
activation_params
or
{}
if
bias_gelu_fusion
:
fc1_out
=
None
...
...
@@ -456,7 +477,7 @@ class _LayerNormMLP(torch.autograd.Function):
act_out
,
_
,
fc1_out
,
_
=
fc1_outputs
elif
debug
:
fc1_out
,
*
_
=
fc1_outputs
act_out
=
activation_func
(
fc1_out
,
None
)
act_out
=
activation_func
(
fc1_out
,
None
,
**
act_params
)
act_out
=
fc2_input_quantizer
(
act_out
)
else
:
fc1_out
,
*
_
=
fc1_outputs
...
...
@@ -464,19 +485,19 @@ class _LayerNormMLP(torch.autograd.Function):
recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
if
recipe
.
float8_block_scaling
():
# tex.quantize does not support GELU fusion for blockwise
act_out
=
activation_func
(
fc1_out
,
None
)
act_out
=
activation_func
(
fc1_out
,
None
,
**
act_params
)
act_out
=
tex
.
quantize
(
act_out
,
fc2_input_quantizer
)
elif
recipe
.
custom
():
# tex.quantize does not support custom quantizers
act_out
=
activation_func
(
fc1_out
,
None
)
act_out
=
activation_func
(
fc1_out
,
None
,
**
act_params
)
act_out
=
fc2_input_quantizer
(
act_out
)
else
:
act_out
=
activation_func
(
fc1_out
,
fc2_input_quantizer
)
act_out
=
activation_func
(
fc1_out
,
fc2_input_quantizer
,
**
act_params
)
else
:
if
fp8_calibration
:
act_out
=
activation_func
(
fc1_out
,
None
)
act_out
=
activation_func
(
fc1_out
,
None
,
**
act_params
)
else
:
act_out
=
activation_func
(
fc1_out
,
fc2_input_quantizer
)
act_out
=
activation_func
(
fc1_out
,
fc2_input_quantizer
,
**
act_params
)
if
not
is_grad_enabled
:
clear_tensor_data
(
fc1_out
)
...
...
@@ -540,13 +561,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Cache state for backward pass
if
is_grad_enabled
:
# Weight with column-wise usage is needed for dgrad GEMM.
if
isinstance
(
fc1_weight_final
,
QuantizedTensorStorage
):
fc1_weight_final
.
update_usage
(
columnwise_usage
=
True
)
if
isinstance
(
fc2_weight_final
,
QuantizedTensorStorage
):
fc2_weight_final
.
update_usage
(
columnwise_usage
=
True
)
if
cpu_offloading
:
mark_activation_offload
(
inputmat
,
mu
,
rsigma
,
ln_out
,
fc1_out
,
fc1_out_without_bias
,
act_out
...
...
@@ -577,6 +591,18 @@ class _LayerNormMLP(torch.autograd.Function):
clear_tensor_data
(
act_out
)
act_out
=
None
if
cpu_offloading
:
mark_not_offload
(
ln_weight
,
ln_bias
,
fc1_weight_final
,
fc1_weight
,
fc1_bias
,
fc2_weight_final
,
fc2_weight
,
fc2_bias
,
)
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
inputmat
,
ln_weight
,
...
...
@@ -631,6 +657,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx
.
device
=
device
ctx
.
activation_dtype
=
activation_dtype
ctx
.
activation
=
activation
ctx
.
activation_params
=
activation_params
ctx
.
fp8
=
fp8
ctx
.
fp8_recipe
=
FP8GlobalStateManager
.
get_fp8_recipe
()
if
fp8
else
None
ctx
.
fuse_wgrad_accumulation
=
fuse_wgrad_accumulation
...
...
@@ -1017,6 +1044,7 @@ class _LayerNormMLP(torch.autograd.Function):
# --------------------------------------------------
# bias computation
act_params
=
ctx
.
activation_params
or
{}
fc1_bias_grad
=
None
fuse_gemm_and_bias_fc1_wgrad
=
False
if
ctx
.
fc1_grad_output_quantizer
is
not
None
:
...
...
@@ -1030,7 +1058,7 @@ class _LayerNormMLP(torch.autograd.Function):
dact
=
ctx
.
fc1_grad_output_quantizer
(
dact
)
elif
ctx
.
debug
:
dact_func
=
_act_func
(
ctx
.
activation
)[
1
]
dact
=
dact_func
(
fc2_dgrad
,
fc1_out
.
to
(
ctx
.
activation_dtype
),
None
)
dact
=
dact_func
(
fc2_dgrad
,
fc1_out
.
to
(
ctx
.
activation_dtype
),
None
,
**
act_params
)
fc1_bias_grad
=
dact
.
sum
(
dim
=
0
)
dact
=
ctx
.
fc1_grad_output_quantizer
(
dact
)
elif
(
...
...
@@ -1042,7 +1070,10 @@ class _LayerNormMLP(torch.autograd.Function):
ctx
.
activation
,
ctx
.
fp8_recipe
if
ctx
.
fp8
else
None
)[
2
]
fc1_bias_grad
,
dact
=
dbias_dact_quantize_func
(
fc2_dgrad
,
fc1_out
.
to
(
ctx
.
activation_dtype
),
ctx
.
fc1_grad_output_quantizer
fc2_dgrad
,
fc1_out
.
to
(
ctx
.
activation_dtype
),
ctx
.
fc1_grad_output_quantizer
,
**
act_params
,
)
# quantize bgrad gelu fused
else
:
# Fusion: gemm + gelu,
...
...
@@ -1051,7 +1082,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx
.
activation
,
ctx
.
fp8_recipe
if
ctx
.
fp8
else
None
)[
1
]
dact
=
activation_func_bwd
(
fc2_dgrad
,
fc1_out
.
to
(
ctx
.
activation_dtype
),
None
fc2_dgrad
,
fc1_out
.
to
(
ctx
.
activation_dtype
),
None
,
**
act_params
)
# activation in high precision
if
ctx
.
fp8
:
...
...
@@ -1429,6 +1460,7 @@ class _LayerNormMLP(torch.autograd.Function):
None
,
# bwd_ln_sm_margin
None
,
# zero_centered_gamma
None
,
# activation
None
,
# activation_params
None
,
# normalization
None
,
# ub_overlap_ag
None
,
# ub_overlap_rs
...
...
@@ -1464,7 +1496,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
activation : str, default = 'gelu'
activation function used.
Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
'silu', and 'swiglu'.
'silu', 'swiglu', and 'clamped_swiglu'.
activation_params : dict, default = `None`
Additional parameters for the activation function.
At the moment, only used for 'clamped_swiglu' activation which
supports 'limit' and 'alpha' parameters.
init_method : Callable, default = `None`
used for initializing FC1 weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
...
...
@@ -1565,6 +1601,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
bias
:
bool
=
True
,
normalization
:
str
=
"LayerNorm"
,
activation
:
str
=
"gelu"
,
activation_params
:
Optional
[
dict
]
=
None
,
output_layer_init_method
:
Optional
[
Callable
]
=
None
,
fuse_wgrad_accumulation
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
...
...
@@ -1592,6 +1629,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
assert
normalization
in
[
"LayerNorm"
,
"RMSNorm"
],
"Unsupported normalization type!"
self
.
use_bias
=
bias
self
.
activation
=
activation
self
.
activation_params
=
activation_params
self
.
return_bias
=
return_bias
self
.
apply_bias
=
bias
and
not
return_bias
self
.
return_layernorm_output
=
return_layernorm_output
...
...
@@ -1671,7 +1709,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self
.
layer_norm_bias
=
None
# FC1 init
if
self
.
activation
in
[
"geglu"
,
"qgeglu"
,
"reglu"
,
"sreglu"
,
"swiglu"
]:
if
self
.
activation
in
[
"geglu"
,
"qgeglu"
,
"reglu"
,
"sreglu"
,
"swiglu"
,
"clamped_swiglu"
]:
fc1_output_features
=
2
*
self
.
size_per_partition
else
:
fc1_output_features
=
self
.
size_per_partition
...
...
@@ -1926,6 +1964,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self
.
bwd_ln_sm_margin
,
self
.
zero_centered_gamma
,
self
.
activation
,
self
.
activation_params
,
self
.
normalization
,
self
.
ub_overlap_ag
,
self
.
ub_overlap_rs
,
...
...
@@ -2055,6 +2094,19 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc1_out
=
onnx_gemm
(
fc1_weight
,
ln_out
,
fc1_bias
)
fc1_out
=
fc1_out
.
to
(
torch
.
float32
)
# activation is computed in fp32
act_params
=
self
.
activation_params
or
{}
# Default params for clamped_swiglu in Transformer Engine
clamped_swiglu_limit
,
clamped_swiglu_alpha
=
act_params
.
get
(
"limit"
,
7.0
),
act_params
.
get
(
"alpha"
,
1.702
)
def
_clamped_swiglu
(
x
,
limit
,
alpha
):
x_glu
,
x_linear
=
x
.
chunk
(
2
,
dim
=-
1
)
x_glu
=
x_glu
.
clamp
(
min
=
None
,
max
=
limit
)
x_linear
=
x_linear
.
clamp
(
min
=-
limit
,
max
=
limit
)
out_glu
=
x_glu
*
torch
.
sigmoid
(
alpha
*
x_glu
)
y
=
out_glu
*
(
x_linear
+
1
)
return
y
activation_map
=
{
"gelu"
:
lambda
x
:
torch
.
nn
.
functional
.
gelu
(
x
,
approximate
=
"tanh"
),
...
...
@@ -2069,6 +2121,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
*
x
.
chunk
(
2
,
-
1
)[
1
],
"silu"
:
torch
.
nn
.
functional
.
silu
,
"swiglu"
:
lambda
x
:
torch
.
nn
.
functional
.
silu
(
x
.
chunk
(
2
,
-
1
)[
0
])
*
x
.
chunk
(
2
,
-
1
)[
1
],
"clamped_swiglu"
:
lambda
x
:
_clamped_swiglu
(
x
,
clamped_swiglu_limit
,
clamped_swiglu_alpha
),
}
if
self
.
activation
not
in
activation_map
:
raise
ValueError
(
f
"Unsupported activation in onnx export:
{
self
.
activation
}
"
)
...
...
@@ -2240,7 +2295,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if
self
.
wgrad_store
is
None
or
not
self
.
wgrad_store
.
delay_wgrad_compute
():
if
not
self
.
need_backward_dw
():
return
with
torch
.
cuda
.
nvtx
.
range
(
"_LayerNormMLP_wgrad"
):
(
fc2_wgrad
,
fc2_bias_grad_
,
*
_
),
tensor_list_fc2
=
self
.
wgrad_store
.
pop
()
...
...
transformer_engine/pytorch/module/linear.py
View file @
c1a1c04e
...
...
@@ -59,7 +59,7 @@ from ..cpp_extensions import (
from
..constants
import
GemmParallelModes
,
dist_group_type
from
..jit
import
no_torch_dynamo
from
..graph
import
is_graph_capturing
from
..
tensor.
quantized_tensor
import
(
from
..quantized_tensor
import
(
QuantizedTensor
,
QuantizedTensorStorage
,
Quantizer
,
...
...
@@ -68,9 +68,14 @@ from ..tensor.quantized_tensor import (
)
from
..tensor.float8_tensor
import
Float8CurrentScalingQuantizer
,
Float8Quantizer
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
..tensor.utils
import
is_
experimental
from
..tensor.utils
import
is_
custom
from
..export
import
is_in_onnx_export_mode
,
assert_warmed_up
from
..cpu_offload
import
is_cpu_offload_enabled
,
mark_activation_offload
from
..cpu_offload
import
(
is_cpu_offload_enabled
,
start_offload
,
mark_not_offload
,
mark_activation_offload
,
)
from
...debug.pytorch.debug_state
import
TEDebugState
__all__
=
[
"Linear"
]
...
...
@@ -156,8 +161,8 @@ class _Linear(torch.autograd.Function):
ub_obj
=
get_ub
(
ub_name
+
"_fprop"
,
fp8
)
ub_type
=
tex
.
CommOverlapType
.
AG
#
experimental
recipe check
experimental
=
is_experimental
(
input_quantizer
)
or
is_
experimental
(
weight_quantizer
)
#
custom
recipe check
custom
=
is_custom
(
input_quantizer
)
or
is_
custom
(
weight_quantizer
)
# ------------------------------------------------------
# Prepare input tensor
...
...
@@ -181,7 +186,7 @@ class _Linear(torch.autograd.Function):
if
fp8
or
debug
:
if
input_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
if
not
isinstance
(
inputmat
,
QuantizedTensorStorage
)
and
not
experimental
:
if
not
isinstance
(
inputmat
,
QuantizedTensorStorage
)
and
not
custom
:
own_quantized_input
=
True
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
backward_needs_input
)
if
isinstance
(
...
...
@@ -232,6 +237,9 @@ class _Linear(torch.autograd.Function):
else
:
inputmat
=
cast_if_needed
(
inp
,
activation_dtype
)
# Cast for AMP
inputmat_total
=
inputmat
if
is_cpu_offload_enabled
():
start_offload
(
inputmat
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.input_cast_comm"
)
# ------------------------------------------------------
# Input tensor is ready for GEMM...
...
...
@@ -243,7 +251,8 @@ class _Linear(torch.autograd.Function):
weightmat
=
weight
if
fp8
or
debug
:
# Configure quantizer
if
weight_quantizer
is
not
None
:
# No need to set the quantizer states if weight is already quantized
if
weight_quantizer
is
not
None
and
not
isinstance
(
weight
,
QuantizedTensor
):
columnwise_usage
=
is_grad_enabled
and
inp
.
requires_grad
if
not
columnwise_usage
:
columnwise_usage
=
(
...
...
@@ -251,7 +260,9 @@ class _Linear(torch.autograd.Function):
and
not
in_fp8_activation_recompute_phase
()
)
weight_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
columnwise_usage
)
elif
isinstance
(
weight
,
QuantizedTensor
):
# If weight is already quantized, no need to set quantizer states
weight_quantizer
=
weight
.
_quantizer
# Get quantized weight
update_workspace
=
is_first_microbatch
is
None
or
is_first_microbatch
weightmat
=
module
.
get_weight_workspace
(
...
...
@@ -392,11 +403,6 @@ class _Linear(torch.autograd.Function):
if
backward_needs_input
:
saved_inputmat
=
inputmat
# Weight with column-wise usage is needed for dgrad GEMM.
if
inp
.
requires_grad
:
if
isinstance
(
weightmat
,
QuantizedTensorStorage
):
weightmat
.
update_usage
(
columnwise_usage
=
True
)
if
cpu_offloading
and
saved_inputmat
is
not
None
:
mark_activation_offload
(
saved_inputmat
)
...
...
@@ -442,12 +448,7 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module
ctx
.
weight_object
=
weight
# Do not offload weights and biases
weight
.
offloading_activation
=
False
weightmat
.
offloading_activation
=
False
if
bias
is
not
None
:
bias
.
offloading_activation
=
False
mark_not_offload
(
weight
,
weightmat
,
bias
)
# TODO(ksivamani): Check memory usage
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
saved_inputmat
,
...
...
@@ -477,7 +478,7 @@ class _Linear(torch.autograd.Function):
ctx
.
main_grad_func
=
lambda
:
weight
.
main_grad
ctx
.
debug
=
debug
ctx
.
experimental
=
experimental
ctx
.
custom
=
custom
ctx
.
cpu_offloading
=
cpu_offloading
ctx
.
is_first_microbatch
=
is_first_microbatch
ctx
.
use_bias
=
bias
is
not
None
...
...
@@ -647,7 +648,7 @@ class _Linear(torch.autograd.Function):
if
isinstance
(
inputmat
,
QuantizedTensorStorage
):
# Input tensor is already quantized
pass
elif
ctx
.
debug
or
ctx
.
experimental
:
elif
ctx
.
debug
or
ctx
.
custom
:
# Debug quantizer will be applied immediately before wgrad GEMM
pass
else
:
...
...
transformer_engine/pytorch/ops/_common.py
View file @
c1a1c04e
...
...
@@ -13,7 +13,7 @@ from transformer_engine_torch import FP8TensorMeta
from
..
import
torch_version
from
..quantization
import
FP8GlobalStateManager
from
..tensor.float8_tensor
import
Float8Tensor
from
..
tensor.
quantized_tensor
import
QuantizedTensorStorage
from
..quantized_tensor
import
QuantizedTensorStorage
from
..utils
import
canonicalize_dtype
...
...
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
View file @
c1a1c04e
...
...
@@ -21,7 +21,7 @@ from ...module.base import (
get_ub
,
get_workspace
,
)
from
...
tensor.
quantized_tensor
import
Quantizer
from
...quantized_tensor
import
Quantizer
from
...tensor.mxfp8_tensor
import
MXFP8Quantizer
from
...utils
import
canonicalize_device
,
canonicalize_dtype
,
clear_tensor_data
from
..basic
import
BasicLinear
,
Bias
,
ReduceScatter
...
...
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
View file @
c1a1c04e
...
...
@@ -21,7 +21,7 @@ from ...module.base import (
get_workspace
,
_2X_ACC_FPROP
,
)
from
...
tensor.
quantized_tensor
import
Quantizer
from
...quantized_tensor
import
Quantizer
from
...tensor.float8_tensor
import
Float8Quantizer
,
Float8CurrentScalingQuantizer
from
...tensor.storage.float8_tensor_storage
import
Float8TensorStorage
from
.._common
import
maybe_dequantize
,
is_quantized_tensor
...
...
transformer_engine/pytorch/ops/fuser.py
View file @
c1a1c04e
...
...
@@ -28,7 +28,7 @@ from transformer_engine.pytorch.ops.fused import (
fuse_userbuffers_backward_linear
,
fuse_userbuffers_forward_linear
,
)
from
transformer_engine.pytorch.
tensor.
quantized_tensor
import
(
from
transformer_engine.pytorch.quantized_tensor
import
(
prepare_for_saving
,
restore_from_saved
,
)
...
...
transformer_engine/pytorch/optimizers/fused_adam.py
View file @
c1a1c04e
...
...
@@ -372,9 +372,9 @@ class FusedAdam(torch.optim.Optimizer):
"""
dtype
=
self
.
name_to_dtype_map
[
state_name
]
if
store_param_remainders
:
data
=
torch
.
zeros
_like
(
param
,
dtype
=
torch
.
int16
)
data
=
torch
.
zeros
(
param
.
shape
,
dtype
=
torch
.
int16
,
device
=
param
.
device
)
else
:
data
=
torch
.
empty
_like
(
param
,
dtype
=
dtype
)
data
=
torch
.
empty
(
param
.
shape
,
dtype
=
dtype
,
device
=
param
.
device
)
if
zero_buffer
:
data
.
zero_
()
...
...
transformer_engine/pytorch/permutation.py
View file @
c1a1c04e
...
...
@@ -10,7 +10,7 @@ import torch
import
transformer_engine_torch
as
tex
import
transformer_engine.pytorch.triton.permutation
as
triton_permutation
from
transformer_engine.pytorch.constants
import
TE_DType
from
transformer_engine.pytorch.
tensor.
quantized_tensor
import
QuantizedTensor
from
transformer_engine.pytorch.quantized_tensor
import
QuantizedTensor
from
transformer_engine.pytorch.tensor.float8_tensor
import
Float8Tensor
from
transformer_engine.pytorch.tensor.float8_blockwise_tensor
import
Float8BlockwiseQTensor
from
transformer_engine.pytorch.tensor.mxfp8_tensor
import
MXFP8Tensor
...
...
transformer_engine/pytorch/
tensor/
quantized_tensor.py
→
transformer_engine/pytorch/quantized_tensor.py
View file @
c1a1c04e
...
...
@@ -2,18 +2,29 @@
#
# See LICENSE for license information.
"""
Tensor with
quantiz
ed data
"""
"""
Pure Python base classes for
quantiz
ation.
"""
from
__future__
import
annotations
from
typing
import
Callable
,
Optional
,
Tuple
,
Iterable
,
Any
,
Dict
,
Union
from
typing
import
Optional
,
Tuple
,
Iterable
,
Any
,
Dict
,
Union
import
abc
import
copy
import
warnings
import
math
import
torch
from
torch.utils._pytree
import
tree_map
from
transformer_engine.common.recipe
import
Recipe
from
transformer_engine.pytorch.tensor._quantization_helpers
import
(
_QuantizeFunc
,
_IdentityFunc
,
_stride_from_shape
,
)
_quantized_tensor_cpu_supported_ops
=
(
torch
.
ops
.
aten
.
empty_like
.
default
,
torch
.
ops
.
aten
.
copy_
.
default
,
)
class
QuantizedTensorStorage
:
...
...
@@ -30,7 +41,7 @@ class QuantizedTensorStorage:
XTensorStorage should contain all data members needed to
implement the functionality of the tensor, while
XTensor should only implement the functionality needed
to behave like regular torch.Tensor (li
e
k __torch_dispatch__)."""
to behave like regular torch.Tensor (lik
e
__torch_dispatch__)."""
_quantizer
:
Optional
[
Quantizer
]
...
...
@@ -58,6 +69,12 @@ class QuantizedTensorStorage:
f
"
{
self
.
__class__
.
__name__
}
class does not implement update_usage function"
)
def
get_usages
(
self
)
->
Dict
[
str
,
bool
]:
"""Get the usage of the tensor"""
raise
NotImplementedError
(
f
"
{
self
.
__class__
.
__name__
}
class does not implement get_usages function"
)
def
prepare_for_saving
(
self
)
->
Tuple
[
list
[
Optional
[
torch
.
Tensor
]],
QuantizedTensorStorage
]:
"""Prepare the tensor base for saving for backward"""
raise
NotImplementedError
(
...
...
@@ -123,6 +140,7 @@ def prepare_for_saving(
t
,
t_obj
=
tensor
.
prepare_for_saving
()
tensor_list
.
extend
(
t
)
tensor_objects_list
.
append
(
t_obj
)
return
tensor_list
,
tensor_objects_list
...
...
@@ -309,72 +327,12 @@ class Quantizer(abc.ABC):
"""Returns whether or not given tensor can be quantized"""
return
True
class
_QuantizeFunc
(
torch
.
autograd
.
Function
):
"""Quantize tensor"""
@
staticmethod
def
forward
(
_ctx
:
Optional
[
torch
.
autograd
.
function
.
FunctionCtx
],
# unused
tensor
:
torch
.
Tensor
,
quantize_impl
:
Callable
,
)
->
QuantizedTensor
:
# pylint: disable=missing-function-docstring
return
quantize_impl
(
tensor
)
@
staticmethod
def
backward
(
_ctx
:
torch
.
autograd
.
function
.
FunctionCtx
,
# unused
grad
:
torch
.
Tensor
,
)
->
Tuple
[
Optional
[
torch
.
Tensor
],
...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
return
grad
,
None
class
_IdentityFunc
(
torch
.
autograd
.
Function
):
"""Identity function
If constructor keyword-arguments are provided, then construct a
new Float8Tensor using the provided tensor's attributes.
"""
@
staticmethod
def
forward
(
ctx
,
tensor
:
QuantizedTensor
,
init_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
QuantizedTensor
:
# pylint: disable=missing-function-docstring
# Return input tensor if constructor kwargs are not provided
if
init_kwargs
is
None
:
return
tensor
.
detach
()
# Construct new tensor if constructor kwargs are provided
ctx
.
input_dtype
=
tensor
.
dtype
kwargs
=
tensor
.
get_metadata
()
for
key
,
val
in
init_kwargs
.
items
():
kwargs
[
key
]
=
val
return
type
(
tensor
)(
tensor
.
shape
,
tensor
.
dtype
,
**
kwargs
)
@
staticmethod
def
backward
(
ctx
,
grad_output
):
# pylint: disable=missing-function-docstring
grad_input
=
grad_output
if
grad_input
.
dtype
==
ctx
.
input_dtype
:
grad_input
=
grad_input
.
detach
()
else
:
grad_input
=
grad_input
.
to
(
ctx
.
input_dtype
)
return
grad_input
,
None
def
_stride_from_shape
(
shape
:
list
[
int
]):
if
len
(
shape
)
==
0
:
return
[]
rstride
=
[
1
]
for
d
in
reversed
(
shape
[
1
:]):
rstride
.
append
(
rstride
[
-
1
]
*
d
)
return
list
(
reversed
(
rstride
))
def
get_usages
(
self
)
->
Dict
[
str
,
bool
]:
"""Get the usage of the quantizer"""
return
{
"rowwise"
:
self
.
rowwise_usage
,
"columnwise"
:
self
.
columnwise_usage
,
}
class
QuantizedTensor
(
torch
.
Tensor
):
...
...
@@ -387,7 +345,14 @@ class QuantizedTensor(torch.Tensor):
"""
def
__new__
(
cls
,
shape
:
Iterable
[
int
],
dtype
:
torch
.
dtype
,
*
,
requires_grad
:
bool
=
False
):
def
__new__
(
cls
,
shape
:
Iterable
[
int
],
dtype
:
torch
.
dtype
,
*
,
requires_grad
:
bool
=
False
,
device
:
Optional
[
torch
.
device
]
=
None
,
):
# We are assuming only contiguous tensors
stride
=
_stride_from_shape
(
shape
)
instance
=
torch
.
Tensor
.
_make_wrapper_subclass
(
...
...
@@ -398,7 +363,7 @@ class QuantizedTensor(torch.Tensor):
dtype
=
dtype
,
layout
=
torch
.
strided
,
requires_grad
=
requires_grad
,
device
=
torch
.
cuda
.
current_device
(),
device
=
torch
.
cuda
.
current_device
()
if
device
is
None
else
device
,
)
return
instance
...
...
@@ -428,6 +393,9 @@ class QuantizedTensor(torch.Tensor):
def
clear
(
self
):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully"""
raise
NotImplementedError
(
f
"
{
self
.
__class__
.
__name__
}
class does not implement clear function"
)
def
__repr__
(
self
,
*
,
tensor_contents
=
None
)
->
str
:
return
f
"
{
self
.
__class__
.
__name__
}
(data=
{
self
.
dequantize
(
dtype
=
self
.
dtype
)
}
)"
...
...
@@ -469,6 +437,26 @@ class QuantizedTensor(torch.Tensor):
if
func
==
torch
.
ops
.
aten
.
copy_
.
default
:
dst
=
args
[
0
]
src
=
args
[
1
]
if
(
isinstance
(
dst
,
QuantizedTensor
)
and
isinstance
(
src
,
QuantizedTensor
)
and
type
(
dst
.
_quantizer
)
is
type
(
src
.
_quantizer
)
and
set
(
src
.
get_usages
().
keys
())
==
set
(
dst
.
get_usages
().
keys
())
and
all
(
src
.
get_usages
()[
usage
]
==
dst
.
get_usages
()[
usage
]
for
usage
in
src
.
get_usages
().
keys
()
)
):
dst_tensors
,
dst_tensor_obj
=
dst
.
prepare_for_saving
()
src_tensors
,
src_tensor_obj
=
src
.
prepare_for_saving
()
for
dst_tensor
,
src_tensor
in
zip
(
dst_tensors
,
src_tensors
):
if
dst_tensor
is
not
None
:
dst_tensor
.
copy_
(
src_tensor
,
*
args
[
2
:],
**
kwargs
)
dst_tensor_obj
.
restore_from_saved
(
dst_tensors
)
src_tensor_obj
.
restore_from_saved
(
src_tensors
)
return
None
if
isinstance
(
dst
,
QuantizedTensor
):
dst
.
quantize_
(
src
)
else
:
...
...
@@ -481,6 +469,36 @@ class QuantizedTensor(torch.Tensor):
if
func
==
torch
.
ops
.
aten
.
view
.
default
:
raise
NotImplementedError
(
"{cls.__name__} class does not support tensor views"
)
# Empty like op
if
func
==
torch
.
ops
.
aten
.
empty_like
.
default
:
tensor
=
args
[
0
]
device
=
kwargs
.
get
(
"device"
,
tensor
.
device
)
requires_grad
=
kwargs
.
get
(
"requires_grad"
,
tensor
.
requires_grad
)
pin_memory
=
kwargs
.
get
(
"pin_memory"
,
False
)
usage
=
tensor
.
get_usages
()
quantizer_usage
=
tensor
.
_quantizer
.
get_usages
()
tensor
.
_quantizer
.
set_usage
(
**
usage
)
out
=
tensor
.
_quantizer
.
make_empty
(
shape
=
tensor
.
shape
,
dtype
=
tensor
.
dtype
,
device
=
device
,
requires_grad
=
requires_grad
,
pin_memory
=
pin_memory
,
)
tensor
.
_quantizer
.
set_usage
(
**
quantizer_usage
)
return
out
if
func
==
torch
.
ops
.
aten
.
numel
.
default
:
tensor
=
args
[
0
]
return
math
.
prod
(
tensor
.
size
())
if
func
==
torch
.
ops
.
aten
.
is_pinned
.
default
:
tensor
=
args
[
0
]
for
t
in
tensor
.
get_data_tensors
():
if
t
is
not
None
:
return
func
(
t
)
return
False
# Or error out?
def
maybe_unwrap
(
arg
):
if
isinstance
(
arg
,
QuantizedTensor
):
return
arg
.
dequantize
(
dtype
=
arg
.
dtype
)
...
...
@@ -495,6 +513,10 @@ class QuantizedTensor(torch.Tensor):
and
schema_arg
.
alias_info
.
is_write
):
arg
.
quantize_
(
new_arg
)
elif
isinstance
(
arg
,
list
)
and
isinstance
(
new_arg
,
list
):
# Recursively handle update for lists of tensors
for
a
,
na
in
zip
(
arg
,
new_arg
):
maybe_update_inplace
(
a
,
na
,
schema_arg
)
# In-place op: dequantize, perform op, and quantize
if
func
.
_schema
.
is_mutable
:
...
...
@@ -521,6 +543,16 @@ class QuantizedTensor(torch.Tensor):
def
__torch_function__
(
cls
,
func
,
types
,
args
=
(),
kwargs
=
None
):
if
kwargs
is
None
:
kwargs
=
{}
def
check_if_cpu
(
arg
):
if
isinstance
(
cls
,
QuantizedTensor
)
and
arg
.
device
.
type
==
"cpu"
:
assert
(
func
in
_quantized_tensor_cpu_supported_ops
),
f
"QuantizedTensor on CPU does not support this operation:
{
func
}
"
return
arg
args
=
tree_map
(
check_if_cpu
,
args
)
# Do not force the QuantizedTensor type on the returned tensor
return
torch
.
_C
.
_disabled_torch_function_impl
(
func
,
types
,
args
,
kwargs
)
...
...
@@ -551,20 +583,16 @@ class QuantizedTensor(torch.Tensor):
shape
:
Optional
[
Iterable
[
int
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
requires_grad
:
bool
=
False
,
data
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
QuantizedTensor
:
"""Create new quantized tensor
By default, new tensor has the same attributes and underlying
data.
data.
This function is intended to create view of tensors.
"""
if
shape
is
None
:
shape
=
data
.
shape
if
data
is
not
None
else
tensor
.
shape
shape
=
shape
if
shape
is
not
None
else
tensor
.
shape
dtype
=
dtype
if
dtype
is
not
None
else
tensor
.
dtype
kwargs
=
tensor
.
get_metadata
()
if
data
is
not
None
:
kwargs
[
"data"
]
=
data
return
cls
(
shape
=
shape
,
dtype
=
dtype
,
requires_grad
=
requires_grad
,
**
kwargs
)
def
to_dtype
(
self
,
dtype
:
torch
.
dtype
)
->
QuantizedTensor
:
...
...
transformer_engine/pytorch/setup.py
View file @
c1a1c04e
...
...
@@ -145,15 +145,25 @@ if __name__ == "__main__":
)
]
# Setup version and requirements.
# Having the framework extension depend on the core lib allows
# us to detect CUDA version dynamically during compilation and
# choose the correct wheel for te core lib.
__version__
=
te_version
()
cuda_major_version
=
parse
(
torch
.
version
.
cuda
).
major
assert
cuda_major_version
in
(
12
,
13
),
f
"Unsupported cuda version
{
torch
.
version
.
cuda
}
."
te_core
=
f
"transformer_engine_cu
{
cuda_major_version
}
==
{
__version__
}
"
install_requires
=
install_requirements
()
+
[
te_core
]
# Configure package
setuptools
.
setup
(
name
=
PACKAGE_NAME
,
version
=
te
_version
()
,
version
=
_
_version
__
,
description
=
"Transformer acceleration library - Torch Lib"
,
ext_modules
=
ext_modules
,
cmdclass
=
{
"build_ext"
:
CMakeBuildExtension
,
"bdist_wheel"
:
CachedWheelsCommand
},
python_requires
=
f
">=
{
min_python_version_str
()
}
"
,
install_requires
=
install_require
ments
()
,
install_requires
=
install_require
s
,
tests_require
=
test_requirements
(),
)
if
any
(
x
in
sys
.
argv
for
x
in
(
"."
,
"sdist"
,
"bdist_wheel"
)):
...
...
transformer_engine/pytorch/tensor/__init__.py
View file @
c1a1c04e
...
...
@@ -6,7 +6,7 @@
import
torch
from
.quantized_tensor
import
(
from
.
.quantized_tensor
import
(
QuantizedTensorStorage
,
QuantizedTensor
,
Quantizer
,
...
...
transformer_engine/pytorch/tensor/_quantization_helpers.py
0 → 100644
View file @
c1a1c04e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Private helper functions and classes for quantized tensor implementations.
This module contains internal autograd functions and utilities that support
the quantization machinery.
"""
from
__future__
import
annotations
from
typing
import
Callable
,
Optional
,
Tuple
,
Any
,
Dict
,
TYPE_CHECKING
import
torch
if
TYPE_CHECKING
:
from
transformer_engine.pytorch.quantized_tensor
import
QuantizedTensor
class
_QuantizeFunc
(
torch
.
autograd
.
Function
):
"""Quantize tensor"""
@
staticmethod
def
forward
(
_ctx
:
Optional
[
torch
.
autograd
.
function
.
FunctionCtx
],
# unused
tensor
:
torch
.
Tensor
,
quantize_impl
:
Callable
,
)
->
QuantizedTensor
:
# pylint: disable=missing-function-docstring
return
quantize_impl
(
tensor
)
@
staticmethod
def
backward
(
_ctx
:
torch
.
autograd
.
function
.
FunctionCtx
,
# unused
grad
:
torch
.
Tensor
,
)
->
Tuple
[
Optional
[
torch
.
Tensor
],
...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
return
grad
,
None
class
_IdentityFunc
(
torch
.
autograd
.
Function
):
"""Identity function
If constructor keyword-arguments are provided, then construct a
new Float8Tensor using the provided tensor's attributes.
"""
@
staticmethod
def
forward
(
ctx
,
tensor
:
QuantizedTensor
,
init_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
QuantizedTensor
:
# pylint: disable=missing-function-docstring
# Return input tensor if constructor kwargs are not provided
if
init_kwargs
is
None
:
return
tensor
.
detach
()
# Construct new tensor if constructor kwargs are provided
ctx
.
input_dtype
=
tensor
.
dtype
kwargs
=
tensor
.
get_metadata
()
for
key
,
val
in
init_kwargs
.
items
():
kwargs
[
key
]
=
val
return
type
(
tensor
)(
tensor
.
shape
,
tensor
.
dtype
,
**
kwargs
)
@
staticmethod
def
backward
(
ctx
,
grad_output
):
# pylint: disable=missing-function-docstring
grad_input
=
grad_output
if
grad_input
.
dtype
==
ctx
.
input_dtype
:
grad_input
=
grad_input
.
detach
()
else
:
grad_input
=
grad_input
.
to
(
ctx
.
input_dtype
)
return
grad_input
,
None
def
_stride_from_shape
(
shape
:
list
[
int
]):
"""Calculate stride from shape for contiguous tensors"""
if
len
(
shape
)
==
0
:
return
[]
rstride
=
[
1
]
for
d
in
reversed
(
shape
[
1
:]):
rstride
.
append
(
rstride
[
-
1
]
*
d
)
return
list
(
reversed
(
rstride
))
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
View file @
c1a1c04e
...
...
@@ -15,11 +15,8 @@ from transformer_engine_torch import Float8BlockScaleTensorFormat
from
transformer_engine.common.recipe
import
Float8BlockScaling
,
Recipe
from
.storage.float8_blockwise_tensor_storage
import
Float8BlockwiseQTensorStorage
from
.quantized_tensor
import
(
QuantizedTensor
,
Quantizer
,
_IdentityFunc
,
)
from
..quantized_tensor
import
QuantizedTensor
,
Quantizer
from
._quantization_helpers
import
_IdentityFunc
from
..utils
import
devices_match
,
round_up_to_nearest_multiple
from
transformer_engine.pytorch.fp8
import
blockwise_fp8_block_len
...
...
@@ -220,6 +217,7 @@ class Float8BlockQuantizer(Quantizer):
dtype
:
torch
.
dtype
=
torch
.
float32
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
False
,
pin_memory
:
bool
=
False
,
)
->
Float8BlockwiseQTensor
:
"""Construct quantized tensor with uninitialized data"""
if
device
is
None
:
...
...
@@ -235,12 +233,13 @@ class Float8BlockQuantizer(Quantizer):
data
=
None
scale_inv
=
None
if
self
.
rowwise_usage
:
data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
)
data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
)
scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
False
)
scale_inv
=
torch
.
empty
(
scale_shape
,
dtype
=
torch
.
float32
,
device
=
device
,
pin_memory
=
pin_memory
,
)
# Allocate FP8 data transpose if needed
...
...
@@ -248,13 +247,17 @@ class Float8BlockQuantizer(Quantizer):
columnwise_scale_inv
=
None
if
self
.
columnwise_usage
:
columnwise_data
=
torch
.
empty
(
self
.
get_columnwise_shape
(
shape
),
dtype
=
torch
.
uint8
,
device
=
device
self
.
get_columnwise_shape
(
shape
),
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
,
)
columnwise_scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
True
)
columnwise_scale_inv
=
torch
.
empty
(
columnwise_scale_shape
,
dtype
=
torch
.
float32
,
device
=
device
,
pin_memory
=
pin_memory
,
)
# Construct FP8 tensor
...
...
transformer_engine/pytorch/tensor/float8_tensor.py
View file @
c1a1c04e
...
...
@@ -4,21 +4,18 @@
"""Tensor class with FP8 data"""
from
__future__
import
annotations
from
typing
import
Optional
,
Tuple
,
Iterable
,
Union
from
typing
import
Any
,
Optional
,
Tuple
,
Iterable
,
Union
import
warnings
import
torch
from
torch.distributed.fsdp._fully_shard._fsdp_common
import
TrainingState
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
DType
as
TE_DType
from
transformer_engine.common.recipe
import
DelayedScaling
,
Float8CurrentScaling
,
Recipe
from
..utils
import
canonicalize_process_group
,
devices_match
from
.storage.float8_tensor_storage
import
Float8TensorStorage
,
_FromFloat8Func
from
.quantized_tensor
import
(
QuantizedTensor
,
Quantizer
,
_IdentityFunc
,
)
from
..quantized_tensor
import
QuantizedTensor
,
Quantizer
from
._quantization_helpers
import
_IdentityFunc
from
..constants
import
dist_group_type
from
transformer_engine.pytorch.fp8
import
int8_simulation_fp8_tensorwise
...
...
@@ -105,6 +102,7 @@ class Float8Quantizer(Quantizer):
dtype
:
torch
.
dtype
=
torch
.
float32
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
False
,
pin_memory
:
bool
=
False
,
)
->
Float8Tensor
:
# Canonicalize tensor attributes
...
...
@@ -112,16 +110,19 @@ class Float8Quantizer(Quantizer):
device
=
torch
.
device
(
"cuda"
)
# Allocate FP8 data
data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
)
data
=
None
if
self
.
rowwise_usage
:
data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
)
# Allocate FP8 data transpose if needed
data_transpose
=
None
if
self
.
columnwise_usage
:
transpose_shape
=
[
data
.
size
(
-
1
)
]
+
list
(
data
.
shape
[:
-
1
])
transpose_shape
=
[
shape
[
-
1
]
]
+
list
(
shape
[:
-
1
])
data_transpose
=
torch
.
empty
(
transpose_shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
,
)
# Construct FP8 tensor
...
...
@@ -129,7 +130,7 @@ class Float8Quantizer(Quantizer):
shape
=
shape
,
dtype
=
dtype
,
data
=
data
,
fp8_scale_inv
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
),
fp8_scale_inv
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
,
pin_memory
=
pin_memory
),
fp8_dtype
=
self
.
dtype
,
requires_grad
=
requires_grad
,
data_transpose
=
data_transpose
,
...
...
@@ -291,6 +292,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
dtype
:
torch
.
dtype
=
torch
.
float32
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
False
,
pin_memory
:
bool
=
False
,
)
->
Float8Tensor
:
# Canonicalize tensor attributes
...
...
@@ -298,25 +300,26 @@ class Float8CurrentScalingQuantizer(Quantizer):
device
=
torch
.
device
(
"cuda"
)
# Allocate FP8 data
data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
)
data
=
None
if
self
.
rowwise_usage
:
data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
)
# Allocate FP8 data transpose if needed
data_transpose
=
None
if
self
.
columnwise_usage
:
inner_dim
=
data
.
size
(
-
1
)
transpose_shape
=
[
shape
[
-
1
]]
+
list
(
shape
[:
-
1
]
)
data_transpose
=
torch
.
empty
(
inner_dim
,
data
.
numel
()
//
inner_dim
,
transpose_shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
,
)
# Construct FP8 tensor
return
Float8Tensor
(
shape
=
shape
,
dtype
=
dtype
,
data
=
data
,
fp8_scale_inv
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
),
fp8_scale_inv
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
device
,
pin_memory
=
pin_memory
),
fp8_dtype
=
self
.
dtype
,
requires_grad
=
requires_grad
,
data_transpose
=
data_transpose
,
...
...
@@ -538,9 +541,36 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
self
.
_transpose
=
None
@
classmethod
def
__torch_dispatch__
(
cls
,
func
,
types
,
args
,
kwargs
=
None
):
def
make_like
(
cls
,
tensor
:
QuantizedTensor
,
*
,
shape
:
Optional
[
Iterable
[
int
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
requires_grad
:
bool
=
False
,
data
:
Optional
[
torch
.
Tensor
]
=
None
,
data_transpose
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
QuantizedTensor
:
"""Create new quantized tensor
By default, new tensor has the same attributes and underlying
data.
# View op
"""
if
shape
is
None
and
data
is
not
None
:
shape
=
data
.
shape
new_tensor
=
super
().
make_like
(
tensor
,
shape
=
shape
,
dtype
=
dtype
,
requires_grad
=
requires_grad
)
if
data
is
not
None
:
new_tensor
.
_data
=
data
if
data_transpose
is
not
None
:
new_tensor
.
_transpose
=
data_transpose
new_tensor
.
_transpose_invalid
=
False
return
new_tensor
@
classmethod
def
__torch_dispatch__
(
cls
,
func
,
types
,
args
,
kwargs
=
None
):
if
func
==
aten
.
view
.
default
:
tensor
=
args
[
0
]
data
=
tensor
.
_data
...
...
@@ -559,6 +589,9 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
or
out_transpose_shape
[
1
:]
!=
out_shape
[:
-
1
]
):
out_transpose
=
None
else
:
view_shape_for_transpose
=
[
out_shape
[
-
1
]]
+
list
(
out_shape
[:
-
1
])
out_transpose
=
out_transpose
.
view
(
*
view_shape_for_transpose
)
return
Float8Tensor
(
shape
=
out_shape
,
dtype
=
tensor
.
dtype
,
...
...
@@ -591,11 +624,37 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
[
data
]
+
list
(
args
[
1
:]),
kwargs
,
)
return
[
Float8Tensor
.
make_like
(
tensor
,
data
=
split_tensor
,
shape
=
split_tensor
.
shape
)
for
split_tensor
in
func_out
t_func_out
=
[
None
]
*
len
(
func_out
)
# Compute corresponding split of the transpose cache if available
if
tensor
.
_transpose
is
not
None
and
not
tensor
.
_transpose_invalid
:
transpose
=
tensor
.
_transpose
ndim
=
data
.
dim
()
# Figure out the original split dim
if
"dim"
in
kwargs
:
dim_to_split
=
kwargs
[
"dim"
]
else
:
dim_to_split
=
args
[
2
]
if
len
(
args
)
>
2
else
0
# Dimension along which transpose needs to be split
t_dim
=
0
if
dim_to_split
==
ndim
-
1
else
dim_to_split
+
1
t_func_out
=
transpose
.
__torch_dispatch__
(
func
,
types
,
[
transpose
,
args
[
1
],
t_dim
],
kwargs
,
)
outs
=
[
Float8Tensor
.
make_like
(
tensor
,
data
=
split_tensor
,
data_transpose
=
split_transpose_tensor
,
shape
=
split_tensor
.
shape
,
)
for
split_tensor
,
split_transpose_tensor
in
zip
(
func_out
,
t_func_out
)
]
return
outs
if
func
==
aten
.
new_zeros
.
default
:
# create fresh new tensor with zeros.
tensor
=
args
[
0
]
data
=
tensor
.
_data
func_out
=
data
.
__torch_dispatch__
(
...
...
@@ -604,28 +663,82 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
[
data
]
+
list
(
args
[
1
:]),
kwargs
,
)
return
Float8Tensor
.
make_like
(
tensor
,
data
=
func_out
,
shape
=
func_out
.
shape
)
func_transposed_out
=
None
if
tensor
.
_transpose
is
not
None
and
not
tensor
.
_transpose_invalid
:
transpose
=
tensor
.
_transpose
size
=
args
[
1
]
t_shape
=
[
size
[
-
1
]]
+
list
(
size
[:
-
1
])
func_transposed_out
=
transpose
.
__torch_dispatch__
(
func
,
types
,
[
transpose
,
t_shape
]
+
list
(
args
[
2
:]),
kwargs
,
)
# deep copy the scale inverse tensor and quantizer as well.
scale_inv
=
tensor
.
_scale_inv
.
detach
().
clone
()
quantizer
=
tensor
.
_quantizer
.
copy
()
out_tensor
=
Float8Tensor
(
data
=
func_out
,
shape
=
func_out
.
shape
,
dtype
=
tensor
.
dtype
,
fp8_dtype
=
tensor
.
_fp8_dtype
,
fp8_scale_inv
=
scale_inv
,
data_transpose
=
func_transposed_out
,
quantizer
=
quantizer
,
)
return
out_tensor
if
func
==
torch
.
ops
.
aten
.
as_strided
.
default
:
tensor
=
args
[
0
]
data
=
tensor
.
_data
# Apply as_strided to the primary uint8 data
func_out
=
data
.
__torch_dispatch__
(
func
,
types
,
[
data
]
+
list
(
args
[
1
:]),
kwargs
,
)
return
Float8Tensor
.
make_like
(
tensor
,
data
=
func_out
,
shape
=
func_out
.
shape
)
func_transposed_out
=
None
if
tensor
.
_transpose
is
not
None
and
not
tensor
.
_transpose_invalid
:
transpose
=
tensor
.
_transpose
size
=
args
[
1
]
stride
=
args
[
2
]
if
"storage_offset"
in
kwargs
:
storage_offset
=
kwargs
[
"storage_offset"
]
else
:
storage_offset
=
args
[
3
]
if
len
(
args
)
>
3
else
0
# Shape and strided needed for transpose matrix
t_size
=
[
size
[
-
1
]]
+
list
(
size
[:
-
1
])
t_stride
=
[
stride
[
-
1
]]
+
list
(
stride
[:
-
1
])
func_transposed_out
=
transpose
.
__torch_dispatch__
(
func
,
types
,
[
transpose
,
t_size
,
t_stride
,
storage_offset
]
+
list
(
args
[
4
:]),
kwargs
,
)
return
Float8Tensor
.
make_like
(
tensor
,
data
=
func_out
,
data_transpose
=
func_transposed_out
,
shape
=
func_out
.
shape
)
if
func
==
torch
.
ops
.
aten
.
detach
.
default
:
return
cls
.
detach
(
args
[
0
])
if
func
==
torch
.
ops
.
aten
.
clone
.
default
:
return
cls
.
clone
(
args
[
0
])
if
func
==
torch
.
ops
.
aten
.
copy_
.
default
:
dst
,
src
=
args
[
0
],
args
[
1
]
# Just copy FP8 attrs if copying between Float8Tensors
if
isinstance
(
src
,
Float8Tensor
)
and
isinstance
(
dst
,
Float8Tensor
):
dst
.
_data
.
copy_
(
src
.
_data
.
detach
())
dst
.
_scale_inv
.
copy_
(
src
.
_scale_inv
.
view
(
dst
.
_scale_inv
.
size
()))
if
src
.
_transpose
is
not
None
or
dst
.
_transpose
is
not
None
:
if
dst
.
_data
is
not
None
:
dst
.
_data
.
copy_
(
src
.
_data
.
detach
(),
*
args
[
2
:],
**
kwargs
)
if
dst
.
_scale_inv
is
not
None
:
dst
.
_scale_inv
.
copy_
(
src
.
_scale_inv
.
view
(
dst
.
_scale_inv
.
size
()),
*
args
[
2
:],
**
kwargs
)
if
dst
.
_transpose
is
not
None
and
not
dst
.
_transpose_invalid
:
if
not
src
.
_transpose_invalid
:
dst
.
_transpose
.
copy_
(
src
.
_transpose
,
*
args
[
2
:],
**
kwargs
)
else
:
dst
.
_create_transpose
()
return
dst
elif
func
in
_ops_to_preserve_subclass_in_fsdp2
:
...
...
@@ -636,9 +749,105 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
)
else
:
pass
return
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
def
fsdp_pre_all_gather
(
self
,
mesh
,
orig_size
,
contiguous_orig_stride
,
module
,
mp_policy
):
"""Functions FSDP2 calls before all-gather of the
weights for both forward and backward passes.
Args:
mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2
to shard the weights.
orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape)
contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor
(For us same as self.stride())
module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard
that contains this FP8 tensor.
mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2.
Returns:
shareded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors
that need to be all-gathered.(In this case uint8 data tensor)
metadata: Tuple[Any]: Metadata needed for reconstructing the
Float8Tensor after all-gather.
"""
# pylint: disable=unused-argument
# Importing here to avoid circular imports
from
transformer_engine.pytorch.distributed
import
_get_module_fsdp_state
if
isinstance
(
self
.
_quantizer
,
Float8CurrentScalingQuantizer
)
and
mesh
is
not
None
:
# When sharded weight is updated after reduce scattering the gradients in FSDP2,
# we need to do amax reduction across the mesh to make sure all weight shards are
# updated with same scale inverse. Setting the state below in the quantizer will make
# sure that updated Quantized weight tensor have same scale inverse across all shards.
self
.
_quantizer
.
amax_reduction_group
=
mesh
.
get_group
()
self
.
_quantizer
.
with_amax_reduction
=
True
quantizer
=
self
.
_quantizer
.
copy
()
# quantizer to be used for allgathered weights
fsdp_state
=
_get_module_fsdp_state
(
module
)
reshard_after_forward
=
fsdp_state
.
_fsdp_param_group
.
_reshard_after_forward
# If weights are resharded after forward pass, then its enough to set the quantizer usages
# based on whether its forward or backward pass for the allgathered weights.
# If not resharded after forward pass, the same weights allgathered in forward
# are used again in backward and so we dont change the quantizer usages which might need
# both rowwise and columnwise usages.
if
reshard_after_forward
:
training_state
=
fsdp_state
.
_fsdp_param_group
.
_training_state
is_backward_pass
=
training_state
==
TrainingState
.
PRE_BACKWARD
# In case of hopper/L40, only one of data/transpose is needed
# based on forward or backward pass. So setting the quantizer usages appropriately.
quantizer
.
set_usage
(
rowwise
=
not
is_backward_pass
,
columnwise
=
is_backward_pass
)
sharded_tensors
=
(
self
.
_data
,)
metadata
=
(
self
.
_scale_inv
,
self
.
_fp8_dtype
,
quantizer
)
return
sharded_tensors
,
metadata
def
fsdp_post_all_gather
(
self
,
all_gather_outputs
:
Tuple
[
torch
.
Tensor
,
...],
metadata
:
Any
,
param_dtype
:
torch
.
dtype
,
*
,
out
:
Optional
[
Float8Tensor
]
=
None
,
):
"""Functions FSDP2 calls after all-gather of the
weights for both forward and backward passes.
Args:
all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank
are all-gathered and received here as a tuple.
metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the Float8Tensor.
param_dtype (torch.dtype): high precision dtype of the Float8Tensor.
out (Optional[torch.Tensor], optional): _description_. Defaults to None.
Returns:
Tuple[Float8Tensor, Tuple[torch.Tensor, ...]]: Allgathered Float8Tensor and tuple of internal tensors
used by the Float8Tensor that was being computed after allgather.
"""
(
data
,)
=
all_gather_outputs
(
fp8_scale_inv
,
fp8_dtype
,
quantizer
)
=
metadata
orig_shape
=
data
.
size
()
# Quantizer has only columnwise usage set for backward pass
# In Blackwell+ architectures, transpose is not needed at all,
# even if columnwise usage is set. and is going to be handled
# internally in the update_usage method.
if
out
is
not
None
:
out
.
_data
=
data
else
:
fp8_args
=
{
"shape"
:
orig_shape
,
"dtype"
:
param_dtype
,
"fp8_scale_inv"
:
fp8_scale_inv
,
"fp8_dtype"
:
fp8_dtype
,
"quantizer"
:
quantizer
,
"requires_grad"
:
False
,
"data"
:
data
,
}
out
=
Float8Tensor
(
**
fp8_args
)
out
.
update_usage
(
rowwise_usage
=
quantizer
.
rowwise_usage
,
columnwise_usage
=
quantizer
.
columnwise_usage
,
)
return
out
,
all_gather_outputs
@
classmethod
def
_make_in_reduce_ex
(
cls
,
...
...
@@ -756,6 +965,9 @@ class _ViewFunc(torch.autograd.Function):
out_transpose_shape
=
out_transpose
.
size
()
if
out_transpose_shape
[
0
]
!=
out_shape
[
-
1
]
or
out_transpose_shape
[
1
:]
!=
out_shape
[:
-
1
]:
out_transpose
=
None
else
:
view_shape_for_transpose
=
[
shape
[
-
1
]]
+
list
(
shape
[:
-
1
])
out_transpose
=
out_transpose
.
view
(
*
view_shape_for_transpose
)
return
Float8Tensor
(
shape
=
out_shape
,
dtype
=
tensor
.
dtype
,
...
...
@@ -800,6 +1012,9 @@ class _ReshapeFunc(torch.autograd.Function):
out_transpose_shape
=
out_transpose
.
size
()
if
out_transpose_shape
[
0
]
!=
out_shape
[
-
1
]
or
out_transpose_shape
[
1
:]
!=
out_shape
[:
-
1
]:
out_transpose
=
None
else
:
reshape_shape_for_transpose
=
[
shape
[
-
1
]]
+
list
(
shape
[:
-
1
])
out_transpose
=
out_transpose
.
reshape
(
*
reshape_shape_for_transpose
)
return
Float8Tensor
(
shape
=
out_shape
,
dtype
=
tensor
.
dtype
,
...
...
transformer_engine/pytorch/tensor/mxfp8_tensor.py
View file @
c1a1c04e
...
...
@@ -6,22 +6,20 @@
from
__future__
import
annotations
from
collections.abc
import
Iterable
import
math
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
,
Any
import
warnings
import
torch
from
torch.distributed.fsdp._fully_shard._fsdp_common
import
TrainingState
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
DType
as
TE_DType
from
transformer_engine.common.recipe
import
MXFP8BlockScaling
,
Recipe
from
..constants
import
MXFP8_BLOCK_SCALING_SIZE
from
..utils
import
devices_match
,
round_up_to_nearest_multiple
from
.storage.mxfp8_tensor_storage
import
MXFP8TensorStorage
,
_FromMXFP8Func
from
.quantized_tensor
import
(
QuantizedTensor
,
Quantizer
,
_IdentityFunc
,
)
from
..quantized_tensor
import
QuantizedTensor
,
Quantizer
from
._quantization_helpers
import
_IdentityFunc
aten
=
torch
.
ops
.
aten
...
...
@@ -92,6 +90,7 @@ class MXFP8Quantizer(Quantizer):
dtype
:
torch
.
dtype
=
torch
.
float32
,
device
:
Optional
[
torch
.
device
]
=
None
,
requires_grad
:
bool
=
False
,
pin_memory
:
bool
=
False
,
)
->
MXFP8Tensor
:
# Canonicalize tensor attributes
...
...
@@ -107,24 +106,29 @@ class MXFP8Quantizer(Quantizer):
)
# Allocate FP8 data
data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
)
data
=
None
scale_inv
=
None
if
self
.
rowwise_usage
:
data
=
torch
.
empty
(
shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
)
scale_inv
=
torch
.
empty
(
round_up_to_nearest_multiple
(
math
.
prod
(
shape
[:
-
1
]),
128
),
round_up_to_nearest_multiple
(
shape
[
-
1
]
//
MXFP8_BLOCK_SCALING_SIZE
,
4
),
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
,
)
# Allocate FP8 data transpose if needed
columnwise_data
=
None
columnwise_scale_inv
=
None
if
self
.
columnwise_usage
:
columnwise_data
=
torch
.
empty_like
(
data
)
columnwise_data
=
torch
.
empty_like
(
data
,
pin_memory
=
pin_memory
)
columnwise_scale_inv
=
torch
.
empty
(
round_up_to_nearest_multiple
(
math
.
prod
(
shape
[:
-
1
])
//
MXFP8_BLOCK_SCALING_SIZE
,
4
),
round_up_to_nearest_multiple
(
shape
[
-
1
],
128
),
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
,
)
# Construct FP8 tensor
...
...
@@ -301,7 +305,6 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
memory_format
:
torch
.
memory_format
=
torch
.
contiguous_format
,
)
->
MXFP8Tensor
:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
...
...
@@ -317,7 +320,6 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
@
classmethod
def
__torch_dispatch__
(
cls
,
func
,
types
,
args
,
kwargs
=
None
):
# View op
if
func
==
aten
.
view
.
default
:
tensor
=
args
[
0
]
...
...
@@ -341,9 +343,339 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
fp8_dtype
=
tensor
.
_fp8_dtype
,
)
if
func
==
torch
.
ops
.
aten
.
copy_
.
default
:
dst
,
src
=
args
[
0
],
args
[
1
]
if
isinstance
(
src
,
MXFP8Tensor
)
and
isinstance
(
dst
,
MXFP8Tensor
):
# Booleans to check if src has all the usages that dst needs to respect dst quantizer usages.
# If not, default to base class behavior.
rowwise_matches
=
src
.
_rowwise_data
is
not
None
or
dst
.
_rowwise_data
is
None
columnwise_matches
=
(
src
.
_columnwise_data
is
not
None
or
dst
.
_columnwise_data
is
None
)
if
rowwise_matches
and
columnwise_matches
:
if
dst
.
_rowwise_data
is
not
None
:
dst
.
_rowwise_data
.
copy_
(
src
.
_rowwise_data
.
detach
(),
*
args
[
2
:],
**
kwargs
)
dst
.
_rowwise_scale_inv
.
copy_
(
src
.
_rowwise_scale_inv
.
detach
(),
*
args
[
2
:],
**
kwargs
)
if
dst
.
_columnwise_data
is
not
None
:
dst
.
_columnwise_data
.
copy_
(
src
.
_columnwise_data
.
detach
(),
*
args
[
2
:],
**
kwargs
)
dst
.
_columnwise_scale_inv
.
copy_
(
src
.
_columnwise_scale_inv
.
detach
(),
*
args
[
2
:],
**
kwargs
)
return
dst
# FSDP2 related functions.
if
func
==
aten
.
split
.
Tensor
:
# This is called if entire model is initialized on CUDA device and
# then splitted. Finally the shard needed by the process is used
# and other splitted shards are discarded.
if
"dim"
in
kwargs
:
dim_to_split
=
kwargs
[
"dim"
]
else
:
dim_to_split
=
args
[
2
]
if
len
(
args
)
>
2
else
0
tensor
=
args
[
0
]
split_size
=
args
[
1
]
dim0_size
=
tensor
.
size
(
0
)
dimlast_size
=
math
.
prod
(
tensor
.
shape
[
1
:])
if
(
dim0_size
%
split_size
!=
0
or
dim_to_split
!=
0
or
split_size
%
MXFP8_BLOCK_SCALING_SIZE
!=
0
or
dimlast_size
%
MXFP8_BLOCK_SCALING_SIZE
!=
0
):
# Handle splitting by dequantizing and splitting the hp tensor
return
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
out_data
=
[]
for
data
in
[
tensor
.
_rowwise_data
,
tensor
.
_columnwise_data
]:
func_out
=
(
data
.
__torch_dispatch__
(
func
,
types
,
[
data
]
+
list
(
args
[
1
:]),
kwargs
,
)
if
data
is
not
None
else
None
)
out_data
.
append
(
func_out
)
scale_invs
=
[
tensor
.
_rowwise_scale_inv
,
tensor
.
_columnwise_scale_inv
]
split_sizes_for_scale
=
[
split_size
,
split_size
//
MXFP8_BLOCK_SCALING_SIZE
]
# Padding requirements: rowwise dim0 should be divisble by 128, columnwise dim0 should be divisble by 4
padding_multiples
=
[
128
,
4
]
for
scale_inv
,
scale_split_size
,
pad_multiple
in
zip
(
scale_invs
,
split_sizes_for_scale
,
padding_multiples
):
scale_inv_out
=
(
scale_inv
.
__torch_dispatch__
(
func
,
types
,
[
scale_inv
,
scale_split_size
]
+
list
(
args
[
2
:]),
kwargs
,
)
if
scale_inv
is
not
None
else
None
)
# Pad scale_inv_out to be a multiple of pad_multiple
if
scale_inv_out
is
not
None
:
current_shape
=
scale_inv_out
.
shape
pad_dim0
=
(
pad_multiple
-
current_shape
[
0
]
%
pad_multiple
)
%
pad_multiple
if
pad_dim0
>
0
:
scale_inv_out
=
torch
.
nn
.
functional
.
pad
(
scale_inv_out
,
(
0
,
0
,
0
,
pad_dim0
))
out_data
.
append
(
scale_inv_out
)
return
[
MXFP8Tensor
(
shape
=
(
splitted_tensor_data
[
0
].
size
()
if
splitted_tensor_data
[
0
]
is
not
None
else
splitted_tensor_data
[
1
].
size
()
),
dtype
=
tensor
.
dtype
,
rowwise_data
=
splitted_tensor_data
[
0
],
rowwise_scale_inv
=
splitted_tensor_data
[
2
],
columnwise_data
=
splitted_tensor_data
[
1
],
columnwise_scale_inv
=
splitted_tensor_data
[
3
],
quantizer
=
tensor
.
_quantizer
,
requires_grad
=
False
,
fp8_dtype
=
tensor
.
_fp8_dtype
,
)
for
splitted_tensor_data
in
zip
(
*
out_data
)
]
if
func
==
torch
.
ops
.
aten
.
as_strided
.
default
:
# Applied on unsharded param in FSDP2. In our case, this should be a no-op
# This is needed for the case where some MXFP8 shards need padding i.e dimension 0
# of the unsharded param is not a multiple of the world size. If that is the case,
# we down the dequantization route and weights are allgathered in high precision.
# If weight doesnt need padding, this is just a no-op.
shape
=
args
[
1
]
strides
=
args
[
2
]
tensor
=
args
[
0
]
if
(
len
(
shape
)
!=
2
or
len
(
strides
)
!=
2
or
strides
[
1
]
!=
1
or
shape
[
0
]
!=
tensor
.
shape
[
0
]
or
shape
[
1
]
!=
tensor
.
shape
[
1
]
):
return
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
return
MXFP8Tensor
.
make_like
(
tensor
)
if
func
==
aten
.
slice
.
Tensor
:
# FSDP2 needed function.
# We need slicing for the case where some MXFP8 weight shards need padding i.e dimension 0
# of the unsharded param is not a multiple of the world size. If that is the case,
# we down the dequantization route and weights are allgathered in high precision instead.
# If sharded weight doesnt have padding, this is just a no-op.
dim
=
args
[
1
]
start
=
args
[
2
]
length
=
args
[
3
]
tensor
=
args
[
0
]
if
(
dim
!=
0
or
length
!=
tensor
.
shape
[
0
]
or
start
!=
0
or
length
%
MXFP8_BLOCK_SCALING_SIZE
!=
0
or
start
%
MXFP8_BLOCK_SCALING_SIZE
!=
0
):
return
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
return
MXFP8Tensor
.
make_like
(
tensor
)
if
func
==
aten
.
new_zeros
.
default
:
rowwise_data
=
None
columnwise_data
=
None
rowwise_scale_inv
=
None
columnwise_scale_inv
=
None
tensor
=
args
[
0
]
shape
=
args
[
1
]
first_dim
=
math
.
prod
(
shape
[:
-
1
])
last_dim
=
shape
[
-
1
]
if
(
first_dim
%
MXFP8_BLOCK_SCALING_SIZE
!=
0
or
last_dim
%
MXFP8_BLOCK_SCALING_SIZE
!=
0
):
return
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
rowwise_scale_inv_shape
=
[
first_dim
,
last_dim
//
MXFP8_BLOCK_SCALING_SIZE
]
columnwise_scale_inv_shape
=
[
first_dim
//
MXFP8_BLOCK_SCALING_SIZE
,
last_dim
,
]
if
tensor
.
_rowwise_data
is
not
None
:
rowwise_data
=
tensor
.
_rowwise_data
.
__torch_dispatch__
(
func
,
types
,
[
tensor
.
_rowwise_data
]
+
list
(
args
[
1
:]),
kwargs
,
)
rowwise_scale_inv
=
tensor
.
_rowwise_scale_inv
.
__torch_dispatch__
(
func
,
types
,
[
tensor
.
_rowwise_scale_inv
,
rowwise_scale_inv_shape
]
+
list
(
args
[
2
:]),
kwargs
,
)
if
tensor
.
_columnwise_data
is
not
None
:
columnwise_data
=
tensor
.
_columnwise_data
.
__torch_dispatch__
(
func
,
types
,
[
tensor
.
_columnwise_data
]
+
list
(
args
[
1
:]),
kwargs
,
)
columnwise_scale_inv
=
tensor
.
_columnwise_scale_inv
.
__torch_dispatch__
(
func
,
types
,
[
tensor
.
_columnwise_scale_inv
,
columnwise_scale_inv_shape
]
+
list
(
args
[
2
:]),
kwargs
,
)
return
MXFP8Tensor
(
shape
=
args
[
1
],
dtype
=
tensor
.
dtype
,
rowwise_data
=
rowwise_data
,
rowwise_scale_inv
=
rowwise_scale_inv
,
columnwise_data
=
columnwise_data
,
columnwise_scale_inv
=
columnwise_scale_inv
,
quantizer
=
tensor
.
_quantizer
.
copy
(),
requires_grad
=
False
,
fp8_dtype
=
tensor
.
_fp8_dtype
,
)
# Default case
return
super
().
__torch_dispatch__
(
func
,
types
,
args
,
kwargs
)
def
fsdp_pre_all_gather
(
self
,
mesh
,
orig_size
,
contiguous_orig_stride
,
module
,
mp_policy
):
"""Functions FSDP2 calls before all-gather of the
weights for both forward and backward passes.
Args:
mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2
to shard the weights.
orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape)
contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor
(For us same as self.stride()).
module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard
that contains this MXFP8 tensor.
mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2.
Returns:
sharded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors
that need to be all-gathered.
metadata: Tuple[Any]: Metadata needed for reconstructing the
MXFP8Tensor after all-gather.
"""
# pylint: disable=unused-argument
from
transformer_engine.pytorch.distributed
import
_get_module_fsdp_state
fsdp_state
=
_get_module_fsdp_state
(
module
)
reshard_after_forward
=
fsdp_state
.
_fsdp_param_group
.
_reshard_after_forward
quantizer
=
self
.
_quantizer
.
copy
()
# Remove padding from scale inverses before allgather
# Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128]
rowwise_scale_inv
=
self
.
_rowwise_scale_inv
columnwise_scale_inv
=
self
.
_columnwise_scale_inv
shape
=
self
.
shape
if
rowwise_scale_inv
is
not
None
:
# Remove padding from rowwise scale_inv
flattened_in_shape0
=
math
.
prod
(
shape
[:
-
1
])
if
rowwise_scale_inv
.
size
(
0
)
!=
flattened_in_shape0
:
rowwise_scale_inv
=
rowwise_scale_inv
[:
flattened_in_shape0
]
if
columnwise_scale_inv
is
not
None
:
# Remove padding from columnwise scale_inv
flattened_in_shape0
=
math
.
prod
(
shape
[:
-
1
])
//
MXFP8_BLOCK_SCALING_SIZE
if
columnwise_scale_inv
.
size
(
0
)
!=
flattened_in_shape0
:
columnwise_scale_inv
=
columnwise_scale_inv
[:
flattened_in_shape0
]
sharded_tensors
=
(
self
.
_rowwise_data
,
rowwise_scale_inv
)
# If weights are resharded after forward pass, then its enough to set the quantizer usages
# based on whether its forward or backward pass for the allgathered weights.
# If not resharded after forward pass, the same weights allgathered in forward
# are used again in backward. And hence if we need the columnwise data/scale_inv,
# we need to send them as well for allgather in forward pass itself.
if
reshard_after_forward
:
training_state
=
fsdp_state
.
_fsdp_param_group
.
_training_state
is_backward_pass
=
training_state
==
TrainingState
.
PRE_BACKWARD
# Allgather only the necessary tensors based on forward/backward pass
quantizer
.
set_usage
(
rowwise
=
not
is_backward_pass
,
columnwise
=
is_backward_pass
)
sharded_tensors
=
(
(
self
.
_columnwise_data
,
columnwise_scale_inv
)
if
is_backward_pass
else
sharded_tensors
)
else
:
if
quantizer
.
columnwise_usage
:
# If weights are not resharded after forward, then both
# rowwise and columnwise data/scale_inv need to be allgathered.
sharded_tensors
+=
(
self
.
_columnwise_data
,
columnwise_scale_inv
)
metadata
=
(
self
.
_fp8_dtype
,
quantizer
)
return
sharded_tensors
,
metadata
def
fsdp_post_all_gather
(
self
,
all_gather_outputs
:
Tuple
[
torch
.
Tensor
,
...],
metadata
:
Any
,
param_dtype
:
torch
.
dtype
,
*
,
out
:
Optional
[
MXFP8Tensor
]
=
None
,
):
"""Functions FSDP2 calls after all-gather of the
weights for both forward and backward passes.
Args:
all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank
are all-gathered and received here as a tuple.
metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the MXFP8Tensor.
param_dtype (torch.dtype): high precision dtype of the MXFP8Tensor.
out (Optional[torch.Tensor], optional): _description_. Defaults to None.
Returns:
Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: Allgathered MXFP8Tensor and tuple of internal tensors
used by the MXFP8Tensor that was being computed after allgather.
"""
fp8_dtype
,
quantizer
=
metadata
rowwise_data
,
rowwise_scale_inv
=
(
all_gather_outputs
[:
2
]
if
quantizer
.
rowwise_usage
else
(
None
,
None
)
)
columnwise_data
,
columnwise_scale_inv
=
(
all_gather_outputs
[
-
2
:]
if
quantizer
.
columnwise_usage
else
(
None
,
None
)
)
# Add padding to scale_inv tensors to be multiples of [128, 4]for rowwise and [4, 128] for columnwise
if
rowwise_scale_inv
is
not
None
:
# Pad rowwise_scale_inv to be a multiple of [128, 4]
current_shape
=
rowwise_scale_inv
.
shape
pad_dim0
=
(
128
-
current_shape
[
0
]
%
128
)
%
128
if
pad_dim0
>
0
:
rowwise_scale_inv
=
torch
.
nn
.
functional
.
pad
(
rowwise_scale_inv
,
(
0
,
0
,
0
,
pad_dim0
))
if
columnwise_scale_inv
is
not
None
:
# Pad columnwise_scale_inv to be a multiple of [4, 128]
current_shape
=
columnwise_scale_inv
.
shape
pad_dim0
=
(
4
-
current_shape
[
0
]
%
4
)
%
4
if
pad_dim0
>
0
:
columnwise_scale_inv
=
torch
.
nn
.
functional
.
pad
(
columnwise_scale_inv
,
(
0
,
0
,
0
,
pad_dim0
)
)
if
out
is
not
None
:
out
.
_rowwise_data
=
rowwise_data
out
.
_rowwise_scale_inv
=
rowwise_scale_inv
out
.
_columnwise_data
=
columnwise_data
out
.
_columnwise_scale_inv
=
columnwise_scale_inv
out
.
_quantizer
=
quantizer
else
:
out
=
MXFP8Tensor
(
rowwise_data
=
rowwise_data
,
rowwise_scale_inv
=
rowwise_scale_inv
,
columnwise_data
=
columnwise_data
,
columnwise_scale_inv
=
columnwise_scale_inv
,
fp8_dtype
=
fp8_dtype
,
dtype
=
param_dtype
,
shape
=
rowwise_data
.
shape
if
rowwise_data
is
not
None
else
columnwise_data
.
shape
,
quantizer
=
quantizer
,
)
return
out
,
all_gather_outputs
@
classmethod
def
_make_in_reduce_ex
(
cls
,
...
...
@@ -481,10 +813,14 @@ class _ViewFunc(torch.autograd.Function):
shape
[
i
]
=
d_inferred
break
if
shape
[
-
1
]
!=
ctx
.
shape
[
-
1
]:
raise
RuntimeError
(
"MXFP8Tensor does not support reshaping inner dimension "
warnings
.
warn
(
"MXFP8Tensor does not support reshaping inner dimension
.
"
f
"(attempted to reshape dims=
{
tuple
(
tensor
.
shape
)
}
to
{
tuple
(
shape
)
}
)"
"If you are using this for FSDP2 without compiled_autograd_enabled,"
"then ignore this warning. Since this view is not going to be used anywhere. "
,
stacklevel
=
2
,
)
return
tensor
.
dequantize
().
view
(
*
shape
)
# Construct new tensor if shape is provided
new_rowwise_data
=
None
...
...
transformer_engine/pytorch/tensor/nvfp4_tensor.py
View file @
c1a1c04e
...
...
@@ -6,7 +6,7 @@
from
__future__
import
annotations
from
collections.abc
import
Iterable
import
math
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
Optional
,
Tuple
,
Union
import
functools
import
torch
...
...
@@ -22,14 +22,15 @@ from ..utils import (
)
from
.storage.nvfp4_tensor_storage
import
NVFP4TensorStorage
,
_FromNVFP4Func
from
.quantized_tensor
import
QuantizedTensor
,
Quantizer
,
_IdentityFunc
from
..quantized_tensor
import
QuantizedTensor
,
Quantizer
from
._quantization_helpers
import
_IdentityFunc
aten
=
torch
.
ops
.
aten
def
get_no_random_sign_vector
()
->
torch
.
Tensor
:
"""Non-random sign vector for Hadamard transform."""
return
torch
.
tensor
([
1
],
dtype
=
torch
.
float32
)
return
torch
.
tensor
([
1
],
dtype
=
torch
.
float32
,
device
=
"cuda"
)
def
get_sign_from_vector
(
vector
:
torch
.
Tensor
)
->
int
:
...
...
@@ -41,7 +42,7 @@ def get_sign_from_vector(vector: torch.Tensor) -> int:
mask
=
0
for
i
,
v
in
enumerate
(
vector
):
mask
|=
(
v
==
-
1
)
<<
i
return
mask
return
mask
.
item
()
def
get_wgrad_sign_vector
()
->
torch
.
Tensor
:
...
...
@@ -53,6 +54,7 @@ def get_wgrad_sign_vector() -> torch.Tensor:
return
torch
.
tensor
(
[
1
,
1
,
1
,
-
1
,
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
-
1
,
-
1
],
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
...
...
@@ -81,6 +83,7 @@ def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
[
1
,
-
1
,
-
1
,
1
,
-
1
,
1
,
1
,
-
1
,
-
1
,
1
,
1
,
-
1
,
1
,
-
1
,
-
1
,
1
],
],
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
*
hadamard_scale
)
...
...
@@ -94,9 +97,9 @@ def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor:
signs
=
get_wgrad_sign_vector
()
else
:
signs
=
get_no_random_sign_vector
()
sign_matrix
=
signs
*
torch
.
eye
(
hadamard_dimension
,
dtype
=
torch
.
float32
)
sign_matrix
=
signs
*
torch
.
eye
(
hadamard_dimension
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
rht_matrix
=
sign_matrix
@
get_hadamard_matrix
(
hadamard_dimension
)
return
rht_matrix
.
to
(
dtype
=
torch
.
bfloat16
)
.
cuda
()
return
rht_matrix
.
to
(
dtype
=
torch
.
bfloat16
)
@
functools
.
lru_cache
(
maxsize
=
None
)
...
...
@@ -262,6 +265,7 @@ class NVFP4Quantizer(Quantizer):
*
,
dtype
:
torch
.
dtype
=
torch
.
float32
,
device
:
Optional
[
torch
.
device
]
=
None
,
pin_memory
:
bool
=
False
,
requires_grad
:
bool
=
False
,
)
->
NVFP4Tensor
:
...
...
@@ -285,11 +289,18 @@ class NVFP4Quantizer(Quantizer):
scale_inv
=
None
amax_rowwise
=
None
if
self
.
rowwise_usage
:
data
=
torch
.
empty
(
self
.
convert_shape_for_fp4
(
shape
),
dtype
=
torch
.
uint8
,
device
=
device
)
data
=
torch
.
empty
(
self
.
convert_shape_for_fp4
(
shape
),
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
,
)
scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
False
)
scale_inv
=
torch
.
empty
(
scale_shape
,
dtype
=
torch
.
uint8
,
device
=
device
)
scale_inv
=
torch
.
empty
(
scale_shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
)
# Allocate per tensor scale inverse. FP32 format.
amax_rowwise
=
torch
.
zeros
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
amax_rowwise
=
torch
.
zeros
(
1
,
dtype
=
torch
.
float32
,
device
=
device
,
pin_memory
=
pin_memory
)
# Allocate FP8 data transpose if needed
columnwise_data
=
None
...
...
@@ -303,12 +314,15 @@ class NVFP4Quantizer(Quantizer):
self
.
convert_shape_for_fp4
(
self
.
get_columnwise_shape
(
shape_2d
)),
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
,
)
columnwise_scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
True
)
columnwise_scale_inv
=
torch
.
empty
(
columnwise_scale_shape
,
dtype
=
torch
.
uint8
,
device
=
device
columnwise_scale_shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
)
amax_columnwise
=
torch
.
zeros
(
1
,
dtype
=
torch
.
float32
,
device
=
device
,
pin_memory
=
pin_memory
)
amax_columnwise
=
torch
.
zeros
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
# Construct FP8 tensor
return
NVFP4Tensor
(
...
...
@@ -495,6 +509,12 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
return
self
raise
ValueError
(
"NVFP4Tensor does not support different memory formats!"
)
def
get_usages
(
self
)
->
Dict
[
str
,
bool
]:
return
{
"rowwise"
:
self
.
_rowwise_data
is
not
None
,
"columnwise"
:
self
.
_columnwise_data
is
not
None
,
}
@
classmethod
def
__torch_dispatch__
(
cls
,
func
,
types
,
args
,
kwargs
=
None
):
...
...
@@ -517,16 +537,20 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
)
if
tensor
.
_rowwise_data
is
not
None
:
rowwise_data
=
data_init_func
(
tensor
.
_rowwise_data
)
rowwise_scale_inv
=
scale_inv_init_func
(
tensor
.
_rowwise_scale_inv
)
amax_rowwise
=
torch
.
zeros_like
(
tensor
.
_amax_rowwise
)
rowwise_data
=
data_init_func
(
tensor
.
_rowwise_data
,
*
args
[
1
:],
**
kwargs
)
rowwise_scale_inv
=
scale_inv_init_func
(
tensor
.
_rowwise_scale_inv
,
*
args
[
1
:],
**
kwargs
)
amax_rowwise
=
torch
.
zeros_like
(
tensor
.
_amax_rowwise
,
*
args
[
1
:],
**
kwargs
)
else
:
rowwise_data
,
rowwise_scale_inv
,
amax_rowwise
=
None
,
None
,
None
if
tensor
.
_columnwise_data
is
not
None
:
columnwise_data
=
data_init_func
(
tensor
.
_columnwise_data
)
columnwise_scale_inv
=
scale_inv_init_func
(
tensor
.
_columnwise_scale_inv
)
amax_columnwise
=
torch
.
zeros_like
(
tensor
.
_amax_columnwise
)
columnwise_data
=
data_init_func
(
tensor
.
_columnwise_data
,
*
args
[
1
:],
**
kwargs
)
columnwise_scale_inv
=
scale_inv_init_func
(
tensor
.
_columnwise_scale_inv
,
*
args
[
1
:],
**
kwargs
)
amax_columnwise
=
torch
.
zeros_like
(
tensor
.
_amax_columnwise
,
*
args
[
1
:],
**
kwargs
)
else
:
columnwise_data
,
columnwise_scale_inv
,
amax_columnwise
=
(
None
,
...
...
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py
View file @
c1a1c04e
...
...
@@ -14,12 +14,10 @@ from transformer_engine_torch import DType as TE_DType
from
transformer_engine_torch
import
Float8BlockScaleTensorFormat
from
transformer_engine.pytorch.fp8
import
blockwise_fp8_block_len
from
..quantized_tensor
import
QuantizedTensorStorage
from
..
.
quantized_tensor
import
QuantizedTensorStorage
,
Quantizer
from
...constants
import
TE_DType_To_Torch
from
..quantized_tensor
import
Quantizer
from
...utils
import
_empty_tensor
...
...
@@ -423,3 +421,10 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
return
return
def
get_usages
(
self
)
->
Dict
[
str
,
bool
]:
"""Get the usage of the tensor"""
return
{
"rowwise"
:
self
.
_rowwise_data
is
not
None
,
"columnwise"
:
self
.
_columnwise_data
is
not
None
,
}
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
View file @
c1a1c04e
...
...
@@ -12,12 +12,10 @@ import torch
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
DType
as
TE_DType
from
..quantized_tensor
import
QuantizedTensorStorage
from
..
.
quantized_tensor
import
QuantizedTensorStorage
,
Quantizer
from
...constants
import
TE_DType
as
torch_to_transformer_engine_dtype
from
..quantized_tensor
import
Quantizer
from
...utils
import
is_non_tn_fp8_gemm_supported
,
_empty_tensor
...
...
@@ -227,3 +225,12 @@ class Float8TensorStorage(QuantizedTensorStorage):
if
not
needs_data_transpose
:
self
.
_transpose
=
None
self
.
_transpose_invalid
=
True
def
get_usages
(
self
)
->
Dict
[
str
,
bool
]:
"""Get the usage of the tensor"""
usages
=
{
"rowwise"
:
self
.
_data
is
not
None
}
if
is_non_tn_fp8_gemm_supported
():
usages
[
"columnwise"
]
=
self
.
_data
is
not
None
else
:
usages
[
"columnwise"
]
=
self
.
_transpose
is
not
None
and
not
self
.
_transpose_invalid
return
usages
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
View file @
c1a1c04e
...
...
@@ -13,12 +13,10 @@ import torch
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
DType
as
TE_DType
from
..quantized_tensor
import
QuantizedTensorStorage
from
..
.
quantized_tensor
import
QuantizedTensorStorage
,
Quantizer
from
...constants
import
TE_DType
as
torch_to_transformer_engine_dtype
from
..quantized_tensor
import
Quantizer
from
...utils
import
_empty_tensor
...
...
@@ -256,3 +254,10 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
else
:
self
.
_columnwise_data
=
None
self
.
_columnwise_scale_inv
=
None
def
get_usages
(
self
)
->
Tuple
[
bool
,
bool
]:
"""Get the usage of the tensor"""
return
{
"rowwise"
:
self
.
_rowwise_data
is
not
None
,
"columnwise"
:
self
.
_columnwise_data
is
not
None
,
}
Prev
1
…
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment