Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
a207db1d
Commit
a207db1d
authored
Apr 01, 2025
by
yuguo
Browse files
Merge branch 'main' of
https://github.com/NVIDIA/TransformerEngine
parents
fbee8990
69365f88
Changes
101
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2287 additions
and
447 deletions
+2287
-447
transformer_engine/jax/quantize/helper.py
transformer_engine/jax/quantize/helper.py
+416
-0
transformer_engine/jax/quantize/metadata.py
transformer_engine/jax/quantize/metadata.py
+43
-0
transformer_engine/jax/quantize/quantizer.py
transformer_engine/jax/quantize/quantizer.py
+621
-0
transformer_engine/jax/quantize/scaling_modes.py
transformer_engine/jax/quantize/scaling_modes.py
+280
-0
transformer_engine/jax/quantize/tensor.py
transformer_engine/jax/quantize/tensor.py
+383
-0
transformer_engine/jax/setup.py
transformer_engine/jax/setup.py
+44
-1
transformer_engine/jax/sharding.py
transformer_engine/jax/sharding.py
+70
-67
transformer_engine/pytorch/attention.py
transformer_engine/pytorch/attention.py
+78
-58
transformer_engine/pytorch/csrc/common.h
transformer_engine/pytorch/csrc/common.h
+0
-1
transformer_engine/pytorch/csrc/extensions/quantizer.cpp
transformer_engine/pytorch/csrc/extensions/quantizer.cpp
+10
-13
transformer_engine/pytorch/distributed.py
transformer_engine/pytorch/distributed.py
+93
-70
transformer_engine/pytorch/dot_product_attention/inference.py
...sformer_engine/pytorch/dot_product_attention/inference.py
+14
-14
transformer_engine/pytorch/jit.py
transformer_engine/pytorch/jit.py
+22
-2
transformer_engine/pytorch/module/base.py
transformer_engine/pytorch/module/base.py
+7
-0
transformer_engine/pytorch/module/grouped_linear.py
transformer_engine/pytorch/module/grouped_linear.py
+22
-12
transformer_engine/pytorch/module/layernorm_linear.py
transformer_engine/pytorch/module/layernorm_linear.py
+58
-81
transformer_engine/pytorch/module/layernorm_mlp.py
transformer_engine/pytorch/module/layernorm_mlp.py
+105
-112
transformer_engine/pytorch/module/linear.py
transformer_engine/pytorch/module/linear.py
+15
-11
transformer_engine/pytorch/ops/op.py
transformer_engine/pytorch/ops/op.py
+1
-1
transformer_engine/pytorch/tensor/float8_tensor.py
transformer_engine/pytorch/tensor/float8_tensor.py
+5
-4
No files found.
transformer_engine/jax/quantize/helper.py
0 → 100644
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Config module for quantization metadata management
This module provides configuration and helper functions for managing quantization metadata
in JAX, including support for different scaling modes and datatypes.
"""
from
contextlib
import
contextmanager
from
enum
import
Enum
from
typing
import
Optional
,
Tuple
,
Dict
,
Union
import
jax
import
jax.numpy
as
jnp
from
flax.core.frozen_dict
import
FrozenDict
from
transformer_engine_jax
import
DType
from
transformer_engine_jax
import
get_cublasLt_version
from
transformer_engine_jax
import
(
get_cuda_version
,
get_device_compute_capability
,
)
from
transformer_engine.common
import
recipe
from
transformer_engine.jax.sharding
import
global_shard_guard
,
MeshResource
from
.scaling_modes
import
ScalingMode
from
..
import
cpp_extensions
as
tex
__all__ = ["QuantizeConfig", "fp8_autocast", "is_fp8_available", "update_collections"]

# Lazily-built caches for is_fp8_available(): start as None/"" sentinels and
# become {ScalingMode: bool} / {ScalingMode: str} dicts on first query.
_is_fp8_available = None
_reason_for_no_fp8 = ""

# A collection is either a plain mutable dict or an immutable flax FrozenDict.
Collection = Union[Dict, FrozenDict]
def
_check_delayed_scaling_fp8_support
(
gpu_arch
)
->
Tuple
[
bool
,
str
]:
"""Check if delayed scaling FP8 is supported on the given GPU architecture.
Args:
gpu_arch: The GPU architecture version
Returns:
A tuple of (bool, str) indicating support and any error message
"""
if
gpu_arch
>=
90
:
# hopper and above
return
True
,
""
if
gpu_arch
<
89
:
# pre-ada
return
False
,
"Device compute capability 8.9 or higher required for FP8 execution."
if
get_cublasLt_version
()
<
120103
:
return
False
,
"CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
if
get_cuda_version
()
<
12010
:
return
False
,
"Cuda version 12.1 or higher required for FP8 execution on Ada."
return
True
,
""
def
_check_block_scaling_fp8_support
(
gpu_arch
)
->
Tuple
[
bool
,
str
]:
"""Check if block scaling FP8 is supported on the given GPU architecture.
Args:
gpu_arch: The GPU architecture version
Returns:
A tuple of (bool, str) indicating support and any error message
"""
if
gpu_arch
>=
100
:
# blackwell and above
return
True
,
""
if
gpu_arch
<
99
:
# pre-blackwell
return
False
,
"Device compute capability 9.9 or higher required for MXFP8 execution."
if
get_cublasLt_version
()
<
120800
:
return
False
,
"CublasLt version 12.8.0 or higher required for MXFP8 execution."
if
get_cuda_version
()
<
12010
:
return
False
,
"Cuda version 12.8 or higher required for MXFP8 execution."
if
not
tex
.
jax_version_meet_requirement
(
"0.5.3"
):
return
False
,
"Jax version 0.5.3 or higher required for MXFP8 execution."
return
True
,
""
def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
    """Check if FP8 is supported for the given scaling mode and GPU.

    Args:
        scaling_mode: The scaling mode to check support for
        gpu_id: The ID of the GPU to check

    Returns:
        A tuple of (bool, str) indicating support and any error message
    """
    gpu_arch = get_device_compute_capability(gpu_id)
    # Dispatch to the mode-specific checker; unknown modes are rejected.
    if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
        checker = _check_delayed_scaling_fp8_support
    elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
        checker = _check_block_scaling_fp8_support
    else:
        return (False, "Unsupported scaling_mode!")
    return checker(gpu_arch)
def is_fp8_available(
    scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
    gpu_id=None,
) -> Tuple[bool, str]:
    """Check if FP8 is available for the given scaling mode and GPU.

    Args:
        scaling_mode: The scaling mode to check availability for (default: DELAYED_TENSOR_SCALING)
        gpu_id: Optional GPU ID to check specific device (default: None)

    Returns:
        A tuple of (bool, str) indicating availability and any error message
    """
    # Direct single-device query bypasses the module-level cache entirely.
    if gpu_id is not None:
        return _check_fp8_support(scaling_mode, gpu_id)

    global _is_fp8_available, _reason_for_no_fp8
    # Lazily replace the module-level sentinels with per-mode caches.
    if _is_fp8_available is None:
        _is_fp8_available, _reason_for_no_fp8 = {}, {}

    if scaling_mode not in _is_fp8_available:
        # Optimistic default; overwritten as soon as one device fails.
        _is_fp8_available[scaling_mode] = True
        _reason_for_no_fp8[scaling_mode] = ""
        # JAX doesn't provide the local GPU id, so probe every local device index.
        for local_gpu_id in range(len(jax.local_devices())):
            supported, why = _check_fp8_support(scaling_mode, local_gpu_id)
            if supported is False:
                _is_fp8_available[scaling_mode] = supported
                _reason_for_no_fp8[scaling_mode] = why
                return supported, why

    return _is_fp8_available[scaling_mode], _reason_for_no_fp8[scaling_mode]
def _format2dtypes(format_: recipe.Format):
    """Convert recipe.Format.dtype to corresponding JAX dtypes.

    Args:
        format_: The FP8 format to convert

    Returns:
        A tuple of (forward_dtype, backward_dtype) for the given format
    """
    if format_ == recipe.Format.E4M3:
        fwd = bwd = jnp.float8_e4m3fn
    elif format_ == recipe.Format.E5M2:
        fwd = bwd = jnp.float8_e5m2
    elif format_ == recipe.Format.HYBRID:
        # HYBRID: E4M3 forward, E5M2 backward.
        fwd, bwd = jnp.float8_e4m3fn, jnp.float8_e5m2
    else:
        # Unknown/no FP8 format: fall back to bfloat16 in both directions.
        fwd = bwd = jnp.bfloat16
    return fwd, bwd
class AmaxComputeAlgo(Enum):
    """Enumeration for AMAX computation algorithms.

    Attributes:
        MAX: Use the maximum value over the whole amax history
        MOST_RECENT: Use only the most recent amax value
    """

    MAX = "max"
    MOST_RECENT = "most_recent"
def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
    """Convert recipe.Recipe to ScalingMode.

    Args:
        fp8_recipe: The FP8 recipe to convert

    Returns:
        The corresponding ScalingMode

    Raises:
        ValueError: If the recipe type is not supported
    """
    # Ordered (recipe class, scaling mode) pairs; first isinstance match wins.
    recipe_to_mode = (
        (recipe.DelayedScaling, ScalingMode.NVTE_DELAYED_TENSOR_SCALING),
        (recipe.MXFP8BlockScaling, ScalingMode.NVTE_MXFP8_1D_SCALING),
    )
    for recipe_cls, mode in recipe_to_mode:
        if isinstance(fp8_recipe, recipe_cls):
            return mode
    raise ValueError("Invalid fp8_recipe!")
def update_collections(new: Collection, original: Collection) -> Collection:
    """Update collections with new values while preserving original structure.

    Args:
        new: New collection of values to add/update
        original: Original collection to update

    Returns:
        Updated collection with new values merged with original

    Raises:
        AssertionError: If either collection is not a dict or FrozenDict
    """
    assert isinstance(original, (dict, FrozenDict))
    assert isinstance(new, (dict, FrozenDict))

    was_plain_dict = not isinstance(original, FrozenDict)
    remainder = FrozenDict(original) if was_plain_dict else original
    # Drop every key that `new` overrides, then merge what is left back in.
    for key in new:
        if key in remainder:
            remainder, _ = remainder.pop(key)
    merged = FrozenDict({**new, **remainder})
    # Hand back the same container flavour the caller supplied.
    return merged.unfreeze() if was_plain_dict else merged
class QuantizeConfig:
    """Configuration class for quantization settings.

    This class manages global quantization settings including FP8 formats,
    scaling modes, and accumulation settings. All state lives on the class
    itself; ``initialize``/``finalize`` mutate it in place.

    Attributes:
        INITIALIZED: Whether the config has been initialized
        MARGIN: Margin value for quantization
        COLLECTION_NAME: Name of the collection for quantization metadata
        FP8_FORMAT: FP8 format to use
        FWD_DTYPE: Forward pass data type
        BWD_DTYPE: Backward pass data type
        FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass
        FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients
        FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients
        IF_QUANTIZE_2X: Whether 2x quantization is enabled
        SCALING_MODE: Scaling mode
        AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling
        AMAX_COMPUTE_ALGO: Algorithm for AMAX computation
    """

    INITIALIZED = False
    MARGIN: float = 0.0
    COLLECTION_NAME: str = "quantize_meta"
    FP8_FORMAT: recipe.Format = recipe.Format.HYBRID
    # NOTE(review): annotated as DType but _format2dtypes returns jnp dtypes —
    # confirm which type callers expect.
    FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0]
    BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1]
    FP8_2X_ACC_FPROP: bool = False
    FP8_2X_ACC_DGRAD: bool = False
    FP8_2X_ACC_WGRAD: bool = False
    IF_QUANTIZE_2X: bool = False
    SCALING_MODE: ScalingMode = ScalingMode.NVTE_NO_SCALING

    # DelayedScaling
    AMAX_HISTORY_LEN: int = 1024
    AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX

    @staticmethod
    def is_fp8_enabled():
        """Check if FP8 quantization is enabled.

        Returns:
            bool: True if quantization is enabled, False otherwise
        """
        return QuantizeConfig.INITIALIZED

    @classmethod
    def initialize(cls, fp8_recipe: recipe.Recipe) -> None:
        """Initialize the quantization configuration.

        Args:
            fp8_recipe: The FP8 recipe to use for initialization
        """
        cls.INITIALIZED = True
        cls.MARGIN = fp8_recipe.margin
        cls.FP8_FORMAT = fp8_recipe.fp8_format
        cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
        cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
        cls.IF_QUANTIZE_2X = True

    @classmethod
    def finalize(cls) -> None:
        """Reset the quantization configuration to default values."""
        cls.INITIALIZED = False
        cls.MARGIN = 0.0
        cls.FP8_FORMAT = recipe.Format.HYBRID
        cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
        # SCALING_MODE was previously reset twice in this method; once suffices.
        cls.SCALING_MODE = ScalingMode.NVTE_NO_SCALING
        cls.FP8_2X_ACC_FPROP = False
        cls.FP8_2X_ACC_DGRAD = False
        cls.FP8_2X_ACC_WGRAD = False
        cls.IF_QUANTIZE_2X = False
        # DelayedScaling defaults
        cls.AMAX_HISTORY_LEN = 1024
        cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
class DelayedScalingQuantizeConfig:
    """Configuration class for delayed scaling FP8 recipe.

    This class provides specific initialization and finalization for delayed scaling
    FP8 quantization mode.
    """

    @staticmethod
    def initialize(fp8_recipe: recipe.Recipe) -> None:
        """Initialize delayed scaling FP8 configuration.

        Args:
            fp8_recipe: The FP8 recipe to use for initialization

        Raises:
            AssertionError: If recipe parameters are not supported
        """
        # Reject recipe options TE/JAX cannot honor before touching any state.
        assert fp8_recipe.amax_compute_algo in [
            "max",
            "most_recent",
        ], "DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX."
        assert (
            fp8_recipe.scaling_factor_compute_algo is None
        ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
        assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."

        config = QuantizeConfig
        config.initialize(fp8_recipe)
        config.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len
        # Translate the recipe's string selector into the internal enum.
        algo_by_name = {
            "max": AmaxComputeAlgo.MAX,
            "most_recent": AmaxComputeAlgo.MOST_RECENT,
        }
        config.AMAX_COMPUTE_ALGO = algo_by_name[fp8_recipe.amax_compute_algo]
        config.FP8_2X_ACC_DGRAD = True
        config.FP8_2X_ACC_WGRAD = True

    @staticmethod
    def finalize() -> None:
        """Reset the delayed scaling configuration."""
        QuantizeConfig.finalize()
class BlockScalingQuantizeConfig:
    """Configuration class for block scaling FP8 recipe.

    This class provides specific initialization and finalization for block scaling
    FP8 quantization mode.
    """

    @staticmethod
    def initialize(fp8_recipe: recipe.Recipe) -> None:
        """Initialize block scaling FP8 configuration.

        Args:
            fp8_recipe: The FP8 recipe to use for initialization
        """
        config = QuantizeConfig
        config.initialize(fp8_recipe)
        # Block scaling keeps no amax history.
        config.AMAX_HISTORY_LEN = 0

    @staticmethod
    def finalize() -> None:
        """Reset the block scaling configuration."""
        QuantizeConfig.finalize()
@contextmanager
def fp8_autocast(
    enabled: bool = False,
    fp8_recipe: Optional[recipe.Recipe] = None,
    mesh_resource: Optional[MeshResource] = None,
) -> None:
    r"""Context manager for FP8 automatic mixed precision.

    This context manager enables FP8 quantization for the duration of its context.

    .. code-block:: python

        mesh_shape = (4, 2)
        dp_mesh_axis_name = 'data_parallel'
        tp_mesh_axis_name = 'tensor_parallel'
        devices = np.asarray(jax.devices()).reshape(*mesh_shape)

        with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
            mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

            with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
                rules = extend_logical_axis_rules(tuple())
                transformer = TransformerLayer()

                with partitioning.axis_rules(rules):
                    pjit(transformer.init, ...)(...)

    .. note::
        We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
        and :attr:`amax_compute_algo` (with value 'max' and 'most_recent') in
        recipe.DelayedScaling currently. Other parameters in recipe.DelayedScaling
        will trigger an assertion.

    Parameters
    ----------
    enabled: bool, default = False
        Whether or not to enable fp8
    fp8_recipe: recipe.DelayedScaling, default = None
        Recipe used for FP8 training.
    mesh_resource: MeshResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, then no data or tensor parallelism will be used.
    """
    if fp8_recipe is None:
        fp8_recipe = recipe.DelayedScaling()
    if mesh_resource is None:
        mesh_resource = MeshResource()

    # Select the recipe-specific config class; delayed scaling is the default.
    Config = DelayedScalingQuantizeConfig
    if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
        Config = BlockScalingQuantizeConfig

    try:
        with global_shard_guard(mesh_resource):
            if enabled:
                fp8_available, reason_for_no_fp8 = is_fp8_available(_get_scaling_mode(fp8_recipe))
                assert fp8_available, reason_for_no_fp8
                Config.initialize(fp8_recipe)
            yield
    finally:
        # Always restore defaults, even if the body raised or FP8 was never
        # enabled (finalize() simply rewrites the defaults in that case).
        Config.finalize()
transformer_engine/jax/quantize/metadata.py
0 → 100644
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Metadata classes for quantization in JAX.
This module provides classes for managing quantization metadata, including
scale factors and amax history for different tensor types.
"""
from
dataclasses
import
dataclass
import
jax.numpy
as
jnp
# Public API of this module.
__all__ = ["QuantizeMeta", "QuantizeMetaSet"]
@dataclass
class QuantizeMeta:
    """Metadata for quantization parameters.

    Attributes:
        scale: The scaling factor for quantization
        amax_history: History of maximum absolute values
    """

    scale: jnp.ndarray
    amax_history: jnp.ndarray
@dataclass
class QuantizeMetaSet:
    """Set of quantization metadata for different tensor types.

    Attributes:
        x: Quantization metadata for input tensors
        kernel: Quantization metadata for kernel tensors
        grad: Quantization metadata for gradient tensors
    """

    x: QuantizeMeta
    kernel: QuantizeMeta
    grad: QuantizeMeta
transformer_engine/jax/quantize/quantizer.py
0 → 100644
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor quantization classes for TE/JAX.
This module provides classes and utilities for quantizing tensors in JAX.
"""
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
field
from
functools
import
partial
from
typing
import
Union
,
Optional
import
jax
import
jax.numpy
as
jnp
from
jax.tree_util
import
register_pytree_node_class
from
transformer_engine_jax
import
QuantizeAxis
from
.scaling_modes
import
ScalingMode
from
.tensor
import
ScaledTensor1x
,
ScaledTensor2x
,
ScaledTensorFactory
from
.helper
import
(
QuantizeConfig
,
AmaxComputeAlgo
,
)
# Public API; QuantizeAxis is re-exported from transformer_engine_jax.
__all__ = [
    "QuantizeAxis",
    "Quantizer",
    "QuantizerSet",
    "DelayedScaleQuantizer",
    "BlockScaleQuantizer",
    "QuantizerFactory",
    "noop_quantizer_set",
]
@register_pytree_node_class
@dataclass
class Quantizer(ABC):
    """Base class for quantizers.

    This abstract class defines the interface for tensor quantization, providing
    methods for quantization and scale management. Registered as a JAX pytree so
    instances can cross jit/vmap boundaries.

    Attributes:
        q_dtype: The data type for quantized values
        scaling_mode: The scaling mode to use for quantization
        q_axis: The quantization axis (row-wise, column-wise, or both)
    """

    q_dtype: jnp.dtype
    scaling_mode: ScalingMode
    q_axis: QuantizeAxis

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        Returns:
            Tuple of (children, aux_data); the base class carries no array
            leaves, so everything is static aux_data.
        """
        children = ()
        aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstruct a quantizer from its flattened representation.

        Args:
            aux_data: Auxiliary data containing quantizer parameters
            children: Array leaves (empty for the base class)

        Returns:
            A reconstructed Quantizer instance
        """
        return cls(*aux_data, *children)

    def update(self, *args, **kwargs):
        """Update quantizer state (no-op in base class)."""
        del args, kwargs

    def is_2x2x(self) -> bool:
        """Check if quantizer uses both row-wise and column-wise quantization.

        Returns:
            True if using both row-wise and column-wise quantization
        """
        return self.q_axis == QuantizeAxis.ROWWISE_COLWISE

    @abstractmethod
    def get_layout(self) -> str:
        """Get the data layout.

        Returns:
            Data layout in string format
        """

    @abstractmethod
    def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
        """Core quantization function to be implemented by subclasses.

        Args:
            x: Input tensor to quantize
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values, default is x.dtype

        Returns:
            A ScaledTensor1x containing the quantized data
        """

    def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None):
        """Quantize a tensor using the internal _quantize_func().

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        # Both directions requested (explicitly or via the 2x2x axis setting):
        # quantize twice and pair the results.
        if (is_rowwise and is_colwise) or self.is_2x2x():
            rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
            colwise_tensor = self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)
            return ScaledTensor2x(rowwise_tensor, colwise_tensor)

        if is_colwise:
            return self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)

        # Default (also when neither flag is set): row-wise quantization.
        return self._quantize_func(x, dq_dtype=dq_dtype)

    def get_scale_shapes(self, data_shape, is_padded=True):
        """Get shapes for scale tensors.

        Args:
            data_shape: Shape of the input tensor
            is_padded: Whether to use padded shapes

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded)

    def get_scale_dtype(self):
        """Get the data type for scale tensors.

        Returns:
            The data type for scale tensors
        """
        return self.scaling_mode.get_scale_dtype()
@register_pytree_node_class
@dataclass
class DelayedScaleQuantizer(Quantizer):
    """Quantizer implementation using delayed scaling.

    This quantizer uses delayed scaling mode with float32 scales and maintains
    a history of maximum absolute values for dynamic scaling.

    Attributes:
        scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
        q_axis: Quantization axis (default: ROWWISE_COLWISE)
        scale: Current scaling factor, shape (1,), float32
        amax_history: History of maximum absolute values
    """

    scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING
    q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE
    # Scale starts at 1.0 and is recomputed from the amax history on update().
    scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
    # NOTE(review): the history length is read from QuantizeConfig at instance
    # creation time, so instances built before fp8_autocast() configures it use
    # whatever AMAX_HISTORY_LEN currently is — confirm this is intended.
    amax_history: jnp.ndarray = field(
        default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
    )

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        Returns:
            Tuple of (children, aux_data); scale and amax_history are the
            traced array leaves.
        """
        children = (self.scale, self.amax_history)
        aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
        return (children, aux_data)

    def get_layout(self) -> str:
        """Get the data layout string.

        Returns:
            Data layout in string format

        Raises:
            ValueError: If quantization axis is invalid
        """
        layout = "NT"
        if self.q_axis == QuantizeAxis.ROWWISE_COLWISE:
            return layout
        if self.q_axis == QuantizeAxis.ROWWISE:
            return layout[0]
        if self.q_axis == QuantizeAxis.COLWISE:
            return layout[1]
        raise ValueError(f"Invalid q_axis: {self.q_axis}")

    def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
        """Quantize function helper for delayed scaling FP8.

        Args:
            x: Input tensor to quantize
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values

        Returns:
            A ScaledTensor1x containing the quantized data
        """
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        # Compute in the scale's dtype (float32) to avoid overflow before clip.
        compute_dtype = self.scale.dtype
        dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
        scaled_x = x.astype(compute_dtype) * self.scale
        # quantize() in the old dot.py do this way, leave this code block here for future debugging
        # compute_dtype = x.dtype
        # dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
        # scaled_x = x * self.scale.astype(compute_dtype)
        clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
        scale_inv = 1.0 / self.scale
        # Record this tensor's amax and refresh scale/history for the next call.
        self.update(jnp.max(jnp.abs(x)).reshape((1,)))
        return ScaledTensorFactory.create_1x(
            data=clipped_scaled_x,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=dq_dtype,
        )

    def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None):
        """Quantize a tensor using the internal _quantize_func().

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization (None = derive from q_axis)
            is_colwise: Whether to use column-wise quantization (None = derive from q_axis)
            dq_dtype: Data type for dequantized values

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        # None means "use this quantizer's configured axis".
        is_rowwise = (
            is_rowwise
            if is_rowwise is not None
            else (self.q_axis == QuantizeAxis.ROWWISE or self.is_2x2x())
        )
        is_colwise = (
            is_colwise
            if is_colwise is not None
            else (self.q_axis == QuantizeAxis.COLWISE or self.is_2x2x())
        )
        # Quantize once row-wise; the column-wise tensor is derived from it by
        # transposing the quantized data (same scale), not by re-quantizing.
        rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
        colwise_tensor = None
        if is_colwise:
            colwise_tensor = ScaledTensorFactory.create_1x(
                data=jnp.transpose(rowwise_tensor.data, (-1, *range(rowwise_tensor.data.ndim - 1))),
                scale_inv=rowwise_tensor.scale_inv,
                scaling_mode=self.scaling_mode,
                dq_dtype=dq_dtype,
                is_colwise=True,
                layout="T",
            )

        if is_colwise and is_rowwise:
            return ScaledTensor2x(rowwise_tensor, colwise_tensor)
        if is_colwise:
            return colwise_tensor
        return rowwise_tensor

    @staticmethod
    @jax.jit
    def _update_amax_history(amax_history, new_amax):
        """Update AMAX history with new maximum value.

        Args:
            amax_history: Current AMAX history
            new_amax: New maximum value to add (written to slot 0)

        Returns:
            Updated AMAX history
        """
        amax_history = amax_history.at[0].set(new_amax[0])
        return amax_history

    @staticmethod
    @partial(jax.jit, static_argnums=(2,))
    def _compute_scale(amax_history, scale, q_dtype):
        """Compute new scale based on AMAX history.

        Args:
            amax_history: History of maximum absolute values
            scale: Current scale
            q_dtype: Quantization data type (static for jit)

        Returns:
            Updated scale value
        """
        # 2. Calculate the current scale
        fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)

        if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
            amax = jnp.max(amax_history, axis=-1, keepdims=True)
        else:
            amax = amax_history[0:1]

        sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
        # Keep the previous scale when amax is zero or non-finite.
        sf = jnp.where(amax > 0.0, sf, scale)
        sf = jnp.where(jnp.isfinite(amax), sf, scale)
        scale = scale.at[0].set(sf[0])
        return scale

    @staticmethod
    @jax.jit
    def _roll_and_reset_amax_history(amax_history):
        """Roll AMAX history one step and reset the first element to zero.

        Args:
            amax_history: Current AMAX history

        Returns:
            Updated AMAX history
        """
        updated_amax_history = jnp.roll(amax_history, -1, -1)
        amax_history = updated_amax_history.at[0].set(0.0)
        return amax_history

    def update(self, new_amax: jnp.ndarray):
        """Update AMAX history and compute new scale.

        Args:
            new_amax: New maximum absolute value to add to history
        """
        amax_history = self._update_amax_history(self.amax_history, new_amax)
        self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype)
        self.amax_history = self._roll_and_reset_amax_history(amax_history)
@register_pytree_node_class
@dataclass
class BlockScaleQuantizer(Quantizer):
    """Quantizer implementation using block-based scaling.

    This quantizer uses block scaling mode with FP8 scales and block-based
    quantization for improved efficiency.

    Attributes:
        scaling_mode: Set to NVTE_MXFP8_1D_SCALING
        q_axis: Quantization axis (default: ROWWISE_COLWISE)
    """

    scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING
    q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE

    def get_layout(self) -> str:
        """Get the data layout string.

        Returns:
            Data layout in string format
        """
        if self.is_2x2x():
            return "NN"
        return "N"

    def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
        """Quantize function helper for block scaling FP8.

        Args:
            x: Input tensor to quantize
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values

        Returns:
            A ScaledTensor1x containing the quantized data
        """
        # TODO(Phuong): use quantize_func from JAX
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        x_shape = x.shape
        scale_shape = self.scaling_mode.get_scale_shape(x_shape, is_colwise, is_padded=False)
        scale_dtype = self.scaling_mode.get_scale_dtype()
        # Split the last two axes into (num_blocks, block_size) pairs so each
        # scaling block becomes its own axis.
        x = x.reshape(
            *x_shape[:-2],
            scale_shape[-2],
            int(x_shape[-2] / scale_shape[-2]),
            scale_shape[-1],
            int(x_shape[-1] / scale_shape[-1]),
        )
        # Per-block amax over the intra-block axes.
        amax = jnp.max(jnp.abs(x), axis=(-3, -1), keepdims=True)
        MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
        scales = amax.astype(jnp.float32) / MAX
        # Scales are stored as e8m0 (power-of-two exponents), rounded up so the
        # scaled values never exceed the representable range.
        scales_q = self._cast_to_e8m0_with_rounding_up(scales)
        scaled_x = x / self._e8m0_to_dtype(scales_q, jnp.float32)
        clipped_x = jnp.clip(scaled_x, -MAX, MAX)
        x_q = clipped_x.astype(self.q_dtype).reshape(x_shape)
        scales_q = scales_q.reshape(scale_shape).view(scale_dtype)
        return ScaledTensorFactory.create_1x(
            x_q,
            scales_q,
            self.scaling_mode,
            is_colwise=is_colwise,
            dq_dtype=dq_dtype,
        )

    def _cast_to_e8m0_with_rounding_up(self, scales):
        """Cast scales to E8M0 format with rounding up.

        Args:
            scales: Input scales to convert (float32)

        Returns:
            Scales in E8M0 format (uint8 exponents)
        """
        # Reinterpret the float32 bits to extract exponent and mantissa.
        temp = scales.astype(jnp.float32).view(jnp.uint32)
        exp = temp >> 23
        mant = temp & 0x7FFFFF
        # Round up: any nonzero mantissa bumps the exponent, except when the
        # exponent is already saturated (0xFE) or the value is a small denormal.
        is_ru = jnp.logical_and(
            jnp.logical_and((mant > 0), (exp != 0xFE)),
            ~jnp.logical_and((exp == 0), (mant <= 0x400000)),
        )
        exp = jnp.where(is_ru, exp + 1, exp)
        new_scales = exp.astype(jnp.uint8)
        return new_scales

    def _e8m0_to_dtype(self, x, dtype):
        """Convert E8M0 format to specified data type.

        Args:
            x: Input in E8M0 format (uint8 exponents)
            dtype: Target data type

        Returns:
            Converted values in target data type
        """
        # Place the exponent back into float32 bit position 23.
        temp = x.astype(jnp.uint32)
        exp = temp << 23
        new_x = exp.view(jnp.float32)
        # Exponent 0 reinterprets as float 0.0; substitute a tiny positive value
        # so later division by the scale never divides by zero.
        near_zero_value = 2**-15 if dtype == jnp.float16 else 2**-127
        new_x = jnp.where(new_x == 0, jnp.array(near_zero_value, jnp.float32), new_x)
        return new_x.astype(dtype)
@register_pytree_node_class
@dataclass
class QuantizerSet:
    """Set of quantizers for different tensor types.

    This class manages quantizers for input tensors, kernel tensors, and
    gradient tensors. Registered as a JAX pytree with the three quantizers
    as children.

    Attributes:
        x: Quantizer for input tensors
        kernel: Quantizer for kernel tensors
        dgrad: Quantizer for gradient tensors
    """

    x: Optional[Quantizer]
    kernel: Optional[Quantizer]
    dgrad: Optional[Quantizer]

    def tree_flatten(self):
        """Flatten the quantizer set for JAX tree operations.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = (self.x, self.kernel, self.dgrad)
        aux_data = ()
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstruct a quantizer set from its flattened representation.

        Args:
            aux_data: Unused auxiliary data (always empty)
            children: Tuple of quantizers

        Returns:
            A reconstructed QuantizerSet instance
        """
        return cls(*aux_data, *children)
@dataclass
class QuantizerFactory:
    """Factory class for creating quantizers.

    This class provides static methods to create individual quantizers and
    sets of quantizers with various configurations.
    """

    # Maps each supported ScalingMode to the quantizer class implementing it.
    quantizer_type_map = {
        ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
        ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScaleQuantizer,
    }
@
staticmethod
def
create
(
n_quantizers
:
int
=
1
,
scaling_mode
:
ScalingMode
=
None
,
q_dtype
:
jnp
.
dtype
=
None
,
q_axis
:
QuantizeAxis
=
None
,
**
kwargs
,
)
->
Quantizer
:
"""Create one or more quantizers with specified parameters.
Args:
n_quantizers: Number of quantizers to create
scaling_mode: Scaling mode to use
q_dtype: Quantization data type
q_axis: Quantization axis
**kwargs: Additional arguments for quantizer initialization
Returns:
A single quantizer or tuple of quantizers
"""
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted
# assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING
if
scaling_mode
in
(
ScalingMode
.
NVTE_NO_SCALING
,
ScalingMode
.
NVTE_INVALID_SCALING
):
quantizers
=
[
None
]
*
n_quantizers
else
:
quantizers
=
[]
for
_
in
range
(
n_quantizers
):
quantizer_type
=
QuantizerFactory
.
quantizer_type_map
.
get
(
scaling_mode
)
quantizers
.
append
(
quantizer_type
(
q_dtype
=
q_dtype
,
scaling_mode
=
scaling_mode
,
q_axis
=
q_axis
,
**
kwargs
)
)
return
quantizers
[
0
]
if
len
(
quantizers
)
==
1
else
tuple
(
quantizers
)
@
staticmethod
def
_create_set
(
scaling_mode
,
fwd_dtype
,
bwd_dtype
,
is_2x2x
,
**
kwargs
)
->
QuantizerSet
:
"""Create a set of quantizers for forward and backward passes.
Args:
scaling_mode: Scaling mode to use
fwd_dtype: Data type for forward pass
bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization
**kwargs: Additional arguments for quantizer initialization
Returns:
A QuantizerSet instance
"""
if
is_2x2x
:
q_axis_x
=
q_axis_kernel
=
q_axis_dgrad
=
QuantizeAxis
.
ROWWISE_COLWISE
else
:
q_axis_x
=
QuantizeAxis
.
ROWWISE
q_axis_kernel
=
QuantizeAxis
.
COLWISE
q_axis_dgrad
=
None
if
"quantize_meta_set"
in
kwargs
:
quantize_meta_set
=
kwargs
.
get
(
"quantize_meta_set"
)
args_x
=
{
"scale"
:
quantize_meta_set
.
x
.
scale
,
"amax_history"
:
quantize_meta_set
.
x
.
amax_history
,
}
args_kernel
=
{
"scale"
:
quantize_meta_set
.
kernel
.
scale
,
"amax_history"
:
quantize_meta_set
.
kernel
.
amax_history
,
}
args_grad
=
{
"scale"
:
quantize_meta_set
.
grad
.
scale
,
"amax_history"
:
quantize_meta_set
.
grad
.
amax_history
,
}
else
:
args_x
=
args_kernel
=
args_grad
=
{}
q_x
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_axis_x
,
**
args_x
)
q_kernel
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_axis_kernel
,
**
args_kernel
)
q_dgrad
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
bwd_dtype
,
q_axis_dgrad
,
**
args_grad
)
return
QuantizerSet
(
x
=
q_x
,
kernel
=
q_kernel
,
dgrad
=
q_dgrad
)
@
staticmethod
def
create_set
(
n_quantizer_sets
:
int
=
1
,
scaling_mode
:
ScalingMode
=
None
,
fwd_dtype
:
jnp
.
dtype
=
None
,
bwd_dtype
:
jnp
.
dtype
=
None
,
is_2x2x
:
bool
=
None
,
**
kwargs
,
)
->
tuple
[
Union
[
tuple
[
Quantizer
],
None
]]:
"""Create one or more sets of quantizers.
Args:
n_quantizer_sets: Number of quantizer sets to create
scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE
fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
**kwargs: Additional arguments for quantizer initialization
Returns:
A single quantizer set or tuple of quantizer sets
"""
scaling_mode
=
scaling_mode
or
QuantizeConfig
.
SCALING_MODE
fwd_dtype
=
fwd_dtype
or
QuantizeConfig
.
FWD_DTYPE
bwd_dtype
=
bwd_dtype
or
QuantizeConfig
.
BWD_DTYPE
is_2x2x
=
is_2x2x
or
QuantizeConfig
.
IF_QUANTIZE_2X
q_set
=
[]
for
_
in
range
(
n_quantizer_sets
):
q_set
.
append
(
QuantizerFactory
.
_create_set
(
scaling_mode
,
fwd_dtype
,
bwd_dtype
,
is_2x2x
,
**
kwargs
)
)
return
q_set
[
0
]
if
len
(
q_set
)
==
1
else
tuple
(
q_set
)
# Module-level singleton: with NVTE_NO_SCALING, QuantizerFactory.create()
# produces None for every member, so this is QuantizerSet(x=None, kernel=None,
# dgrad=None) — for callers that need the QuantizerSet interface with no
# actual quantization.
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NVTE_NO_SCALING)
transformer_engine/jax/quantize/scaling_modes.py
0 → 100644
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Scaling mode implementations for quantization in JAX.
This module provides implementations of different scaling modes for tensor quantization,
including delayed scaling and block scaling strategies.
"""
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
Tuple
,
Dict
from
functools
import
reduce
import
operator
from
jax.tree_util
import
register_pytree_node_class
import
jax.numpy
as
jnp
__all__
=
[
"ScalingMode"
]
class ScalingModeMetadataImpl(ABC):
    """Abstract interface describing the scale-tensor metadata of a scaling mode.

    Concrete implementations answer two questions for a quantization scheme:
    which dtype its scale tensors use, and what shape a scale tensor has for
    a given data shape.
    """

    @abstractmethod
    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors.

        Returns:
            The data type used for scale tensors
        """

    @abstractmethod
    def get_scale_shape(
        self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors.

        Args:
            data_shape: The shape of the tensor being quantized
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape

        Returns:
            The shape for scale tensors
        """
class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
    """Metadata implementation for delayed (per-tensor) scaling.

    Delayed scaling keeps a single float32 scale per tensor, so the scale
    shape is always ``(1,)`` regardless of the data shape or orientation.
    """

    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors in delayed scaling.

        Returns:
            The data type used for scale tensors (float32)
        """
        return jnp.float32

    def get_scale_shape(
        self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors in delayed scaling.

        Args:
            data_shape: The shape of the tensor being scaled (unused)
            is_colwise: Whether the scaling is column-wise (unused)
            is_padded: Whether to return padded shape (irrelevant for a scalar scale)

        Returns:
            The shape for scale tensors - (1,)
        """
        # One scalar scale covers the whole tensor, so both arguments are moot.
        del data_shape, is_colwise
        return (1,)
class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
    """Implementation for block scaling mode.

    This implementation provides metadata for block scaling mode, which uses
    block-based scaling with specific alignment requirements.

    Attributes:
        _block_dims: Dimensions of the scaling blocks
        _block_alignment: Alignment requirements for blocks
    """

    def __init__(self, block_dims: Tuple[int]):
        """Initialize block scaling mode implementation.

        Args:
            block_dims: Dimensions of the scaling blocks
        """
        self._block_dims = block_dims
        # (row, column) alignment applied to the *scale* tensor shape when
        # is_padded=True. Presumably dictated by the MXFP8 kernel's 128x4
        # scale-layout requirement — TODO confirm against the CUDA side.
        self._block_alignment = (128, 4)

    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors in block scaling.

        Returns:
            The data type used for scale tensors (float8_e8m0fnu)
        """
        return jnp.float8_e8m0fnu

    def get_scale_shape(
        self, data_shape: Tuple[int, ...], is_colwise: bool = False, is_padded: bool = True
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors in block scaling.

        Args:
            data_shape: The shape of the tensor being quantized
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape

        Returns:
            The shape for scale tensors: one scale per block, with the leading
            dimensions flattened/aligned and then redistributed over the
            original leading axes.
        """
        # Without padding, alignment degenerates to (1, 1) (i.e. no rounding up).
        block_alignment = self._block_alignment if is_padded else (1, 1)
        # Column-wise scaling swaps the roles of the two block dimensions
        # (and of the alignment pair).
        if is_colwise:
            block_y, block_x = self._block_dims
            alignment_y, alignment_x = block_alignment
        else:
            block_x, block_y = self._block_dims
            alignment_x, alignment_y = block_alignment
        # Second-to-last axis is treated as the sequence axis.
        seq_axis = len(data_shape) - 2
        # Input must already be padded to a whole number of blocks in both
        # the sequence axis and the last axis.
        assert (
            data_shape[seq_axis] % block_x == 0
        ), f"Input data of shape {data_shape} should be padded by {block_x} in axes={seq_axis}"
        assert (
            data_shape[-1] % block_y == 0
        ), f"Input data of shape {data_shape} should be padded by {block_y} in axis -1"
        # NOTE: this overpads if dim > 2 and dims before seq_axis are greater than 1
        n_block_seq = data_shape[seq_axis] // block_x
        n_block_y = data_shape[-1] // block_y
        # Flatten all leading axes (batch dims) together with the per-sequence
        # block count into one dimension before alignment.
        n_flat_first_dim = reduce(operator.mul, data_shape[:seq_axis], 1) * n_block_seq
        # Padding: round both dimensions up to their alignment multiples.
        n_flat_first_dim = ((n_flat_first_dim + alignment_x - 1) // alignment_x) * alignment_x
        n_block_y = ((n_block_y + alignment_y - 1) // alignment_y) * alignment_y
        # Redistribute the (aligned) flat dimension back over the original
        # leading axes, dividing each one out in turn; whatever remains
        # becomes the second-to-last output dimension.
        out_shape = ()
        for i in range(seq_axis):
            d = data_shape[i]
            out_shape += (d,)
            assert n_flat_first_dim % d == 0
            n_flat_first_dim //= d
        out_shape += (n_flat_first_dim, n_block_y)
        return out_shape
# (Phuong): Map the NVTEScalingMode value to the ScalingMode
@dataclass(frozen=True)
@register_pytree_node_class
class ScalingMode(Enum):
    """Enumeration of tensor scaling modes with their corresponding metadata implementations.

    This class defines the available scaling modes for tensor quantization:
    - NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
    - NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
    - NVTE_INVALID_SCALING: Invalid scaling mode
    - NVTE_NO_SCALING: No scaling applied

    NOTE(review): defining __eq__ below would normally set __hash__ to None
    (making members unusable as dict keys, e.g. in SCALING_MODES_TO_IMPL);
    @dataclass(frozen=True) appears to restore a generated __hash__ — confirm
    this interplay is intentional before touching the decorators.
    """

    NVTE_DELAYED_TENSOR_SCALING = 0
    NVTE_MXFP8_1D_SCALING = 1
    NVTE_INVALID_SCALING = 2
    NVTE_NO_SCALING = 3

    def _get_impl(self) -> ScalingModeMetadataImpl:
        """Get the implementation for this scaling mode.

        Returns:
            The scaling mode implementation

        Raises:
            ValueError: If the scaling mode is invalid (not in SCALING_MODES_TO_IMPL)
        """
        impl = SCALING_MODES_TO_IMPL.get(self)
        if impl is None:
            raise ValueError("Invalid scaling mode")
        return impl

    def get_scale_dtype(self):
        """Get the data type for scale tensors in this mode.

        Returns:
            The data type for scale tensors
        """
        return self._get_impl().get_scale_dtype()

    def get_scale_shape_2x(self, data_shape, is_padded=True) -> Tuple[Tuple[int]]:
        """Get shapes for both row-wise and column-wise scaling.

        Args:
            data_shape: Shape of the data tensor
            is_padded: Whether to use padded shapes

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        rowwise_scale_shape = self.get_scale_shape(data_shape, is_colwise=False, is_padded=is_padded)
        colwise_scale_shape = self.get_scale_shape(data_shape, is_colwise=True, is_padded=is_padded)
        return (rowwise_scale_shape, colwise_scale_shape)

    def get_scale_shape(self, data_shape, is_colwise, is_padded=True) -> Tuple[int]:
        """Get the shape for scale tensors in this mode.

        Args:
            data_shape: Shape of the data tensor
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes

        Returns:
            The shape for scale tensors
        """
        return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded)

    def __eq__(self, other):
        """Compare this scaling mode with another by value.

        Args:
            other: The other scaling mode to compare with

        Returns:
            True if the modes are equal, False otherwise (non-ScalingMode
            objects always compare unequal rather than raising).
        """
        if not isinstance(other, ScalingMode):
            return False
        return self.value == other.value

    def tree_flatten(self):
        """Flatten this scaling mode for JAX tree operations.

        Returns:
            Tuple of (children, aux_data) for tree operations

        NOTE(review): `(self.value)` is NOT a tuple — aux_data is the bare
        int value, and tree_unflatten below relies on exactly that.
        """
        return (), (self.value)

    @classmethod
    def tree_unflatten(cls, aux_data, _children):
        """Reconstruct a scaling mode from its flattened representation.

        Args:
            aux_data: The bare mode value produced by tree_flatten
            _children: Unused children data

        Returns:
            A reconstructed ScalingMode instance
        """
        return cls(aux_data)
# Maps each ScalingMode to the metadata implementation describing its scale
# tensors (dtype and shape). Looked up by ScalingMode._get_impl(); modes
# absent from this table raise ValueError there (NVTE_INVALID_SCALING is
# deliberately absent).
SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
    ScalingMode.NVTE_DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
    # MXFP8 uses 1x32 scaling blocks.
    ScalingMode.NVTE_MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
    # WAR: NVTE_NO_SCALING reuses the delayed-scaling metadata as a workaround
    # (presumably until a dedicated no-scaling implementation exists — confirm).
    ScalingMode.NVTE_NO_SCALING: DelayedScalingModeMetadataImpl(),
}
transformer_engine/jax/quantize/tensor.py
0 → 100644
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor classes for TE/JAX
This module provides tensor classes for handling quantized tensors in JAX, including
both single-scale (1x) and double-scale (2x) quantization schemes. It supports
rowwise and colwise quantization modes with proper scaling and dequantization.
"""
from
dataclasses
import
dataclass
from
typing
import
Callable
,
Tuple
from
abc
import
ABC
,
abstractmethod
import
jax.numpy
as
jnp
from
jax.tree_util
import
register_pytree_node_class
from
transformer_engine_jax
import
QuantizeAxis
from
.scaling_modes
import
ScalingMode
from
.dequantizer
import
Dequantizer
from
..sharding
import
(
with_sharding_constraint_by_logical_axes
as
original_with_sharding_constraint_by_logical_axes
,
)
__all__
=
[
"ScaledTensor"
,
"ScaledTensor1x"
,
"ScaledTensor2x"
,
"ScaledTensorFactory"
,
"with_sharding_constraint_by_logical_axes"
,
]
@register_pytree_node_class
@dataclass
class ScaledTensor(ABC):
    """Abstract base class for scaled tensors.

    This class defines the interface for all scaled tensor implementations,
    providing methods for dequantization and accessing row/column-wise
    components. Note that tree_flatten is not defined here — concrete
    subclasses provide it (and are themselves registered as pytree nodes).
    """

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstructs the tensor from its flattened representation.

        Args:
            aux_data: Auxiliary data needed for reconstruction
            children: The flattened tensor components

        Returns:
            A reconstructed tensor instance
        """
        # Children come first in the constructor, aux_data (static fields) after.
        return cls(*children, *aux_data)

    @abstractmethod
    def dequantize(self):
        """Dequantizes the tensor back to its original precision.

        Returns:
            The dequantized tensor
        """

    @abstractmethod
    def get_rowwise_tensor(self):
        """Returns the row-wise component of the tensor.

        Returns:
            The row-wise tensor component

        Raises:
            ValueError: If called on a tensor that doesn't support row-wise access
        """

    @abstractmethod
    def get_colwise_tensor(self):
        """Returns the column-wise component of the tensor.

        Returns:
            The column-wise tensor component

        Raises:
            ValueError: If called on a tensor that doesn't support column-wise access
        """
@register_pytree_node_class
@dataclass
class ScaledTensor1x(ScaledTensor):
    """Single-scale quantized tensor implementation.

    This class represents a tensor quantized with a single scaling factor,
    supporting both row-wise and column-wise quantization modes.

    Attributes:
        data: The quantized tensor data
        scale_inv: The inverse scaling factors
        scaling_mode: The scaling mode used for quantization
        dq_dtype: The data type for dequantized values
        _dq_func: The dequantization function
        is_colwise: Whether the tensor uses column-wise quantization
        layout: The layout specification for the tensor
    """

    # NOTE: field order is the pytree/constructor contract — do not reorder.
    data: jnp.ndarray
    scale_inv: jnp.ndarray
    scaling_mode: ScalingMode
    dq_dtype: jnp.dtype
    _dq_func: Callable
    is_colwise: bool
    layout: str

    def __post_init__(self):
        """Validates and adjusts the scale_inv shape after initialization.

        Ensures the scale_inv shape matches the expected shape based on the
        scaling mode and quantization direction. If scale_inv arrives in the
        unpadded shape, it is zero-padded up to the padded shape; any other
        shape is rejected via assertion.
        """
        expected_scale_shape = self.scaling_mode.get_scale_shape(
            self.data.shape, self.is_colwise, is_padded=True
        )
        expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
            self.data.shape, self.is_colwise, is_padded=False
        )
        if self.scale_inv.shape != expected_scale_shape:
            assert self.scale_inv.shape == expected_unpadded_scale_shape, (
                f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
                f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got"
                f" {self.scale_inv.shape}"
            )
            # Pad each trailing region up to the padded shape (never at the front).
            pad_width = tuple(
                (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape)
            )
            # This actually pad scale_inv with nan, should we pad it with 127 directly instead?
            # NOTE(review): constant_values=0 is what the code passes; the
            # effective padded value depends on how 0 casts to the scale dtype
            # (e.g. e8m0) — confirm the comment above against actual behavior.
            self.scale_inv = jnp.pad(
                self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0
            )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations;
            arrays are children, everything else is static aux_data.
        """
        children = (self.data, self.scale_inv)
        aux_data = (self.scaling_mode, self.dq_dtype, self._dq_func, self.is_colwise, self.layout)
        return (children, aux_data)

    def dequantize(self):
        """Dequantizes the tensor using the stored dequantization function.

        Returns:
            The dequantized tensor
        """
        return self._dq_func(self)

    def get_rowwise_tensor(self):
        """Returns the tensor if it's row-wise quantized.

        Returns:
            The row-wise tensor (self)

        Raises:
            ValueError: If called on a column-wise quantized tensor
        """
        if not self.is_colwise:
            return self

        raise ValueError("Calling get_rowwise_tensor() from a colwise ScaledTensor1x!")

    def get_colwise_tensor(self):
        """Returns the tensor if it's column-wise quantized.

        Returns:
            The column-wise tensor (self)

        Raises:
            ValueError: If called on a row-wise quantized tensor
        """
        if self.is_colwise:
            return self

        raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")
@register_pytree_node_class
@dataclass
class ScaledTensor2x(ScaledTensor):
    """Double-scale quantized tensor: a row-wise and a column-wise view paired.

    Attributes:
        rowwise_tensor: The row-wise quantized component
        colwise_tensor: The column-wise quantized component
    """

    rowwise_tensor: ScaledTensor1x
    colwise_tensor: ScaledTensor1x

    def tree_flatten(self):
        """Flatten into pytree components.

        Returns:
            ``(children, aux_data)`` — both sub-tensors are children; there
            is no static auxiliary data.
        """
        return (self.rowwise_tensor, self.colwise_tensor), ()

    def get_rowwise_tensor(self):
        """Return the row-wise quantized component."""
        return self.rowwise_tensor

    def get_colwise_tensor(self):
        """Return the column-wise quantized component."""
        return self.colwise_tensor

    def dequantize(self):
        """Dequantize via the row-wise component.

        Returns:
            The dequantized tensor.
        """
        return self.rowwise_tensor.dequantize()
@dataclass
class ScaledTensorFactory:
    """Factory for building scaled tensor instances.

    Offers static constructors for single-scale (1x) and double-scale (2x)
    quantized tensors, plus a dispatching ``create`` keyed on the
    quantization axis.
    """

    @staticmethod
    def create_1x(data, scale_inv, scaling_mode, dq_dtype=jnp.bfloat16, is_colwise=False, layout="N"):
        """Build a single-scale quantized tensor.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            is_colwise: Whether to use column-wise quantization (default: False)
            layout: The layout specification (default: "N")

        Returns:
            A ScaledTensor1x instance
        """
        # Look up the mode-specific dequantization routine.
        dequant_fn = Dequantizer.funcs.get(scaling_mode)
        return ScaledTensor1x(data, scale_inv, scaling_mode, dq_dtype, dequant_fn, is_colwise, layout)

    @staticmethod
    def create_2x(
        data,
        scale_inv,
        colwise_data,
        colwise_scale_inv,
        scaling_mode,
        dq_dtype=jnp.bfloat16,
        layout="NN",
    ):
        """Build a double-scale quantized tensor from its two halves.

        Args:
            data: The row-wise quantized data
            scale_inv: The row-wise inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            layout: Two-character layout spec, one per component (default: "NN")

        Returns:
            A ScaledTensor2x instance
        """
        # Each half is an ordinary 1x tensor; the layout string contributes
        # one character per component.
        row_half = ScaledTensorFactory.create_1x(
            data, scale_inv, scaling_mode, dq_dtype, is_colwise=False, layout=layout[0]
        )
        col_half = ScaledTensorFactory.create_1x(
            colwise_data, colwise_scale_inv, scaling_mode, dq_dtype, is_colwise=True, layout=layout[1]
        )
        return ScaledTensor2x(row_half, col_half)

    @staticmethod
    def create(
        data: jnp.ndarray,
        scale_inv: jnp.ndarray,
        colwise_data: jnp.ndarray,
        colwise_scale_inv: jnp.ndarray,
        scaling_mode: ScalingMode,
        dq_dtype: jnp.dtype = jnp.bfloat16,
        layout: str = "NN",
        q_axis: QuantizeAxis = QuantizeAxis.ROWWISE,
    ):
        """Build a scaled tensor, dispatching on the quantization axis.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            colwise_data: The column-wise quantized data (used only for ROWWISE_COLWISE)
            colwise_scale_inv: The column-wise inverse scaling factors (ditto)
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            layout: The layout specification (default: "NN")
            q_axis: The quantization axis (default: ROWWISE)

        Returns:
            A ScaledTensor2x for ROWWISE_COLWISE, otherwise a ScaledTensor1x.
        """
        if q_axis == QuantizeAxis.ROWWISE_COLWISE:
            return ScaledTensorFactory.create_2x(
                data,
                scale_inv,
                colwise_data,
                colwise_scale_inv,
                scaling_mode,
                dq_dtype,
                layout=layout,
            )

        return ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            scaling_mode,
            dq_dtype,
            is_colwise=(q_axis == QuantizeAxis.COLWISE),
            layout=layout[0],
        )
def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
    """Apply logical-axis sharding constraints, ScaledTensor-aware.

    Scaled tensors are rebuilt with the constraint applied to their quantized
    data (recursively for the 2x case); anything else is forwarded to the
    underlying sharding helper.

    Args:
        x: The tensor (plain or scaled) to constrain
        logical_axis_names: Tuple of logical axis names for sharding

    Returns:
        The tensor with applied sharding constraints
    """
    # 2x: constrain each half independently and re-pair them.
    if isinstance(x, ScaledTensor2x):
        return ScaledTensor2x(
            rowwise_tensor=with_sharding_constraint_by_logical_axes(
                x.rowwise_tensor, logical_axis_names
            ),
            colwise_tensor=with_sharding_constraint_by_logical_axes(
                x.colwise_tensor, logical_axis_names
            ),
        )

    # 1x: constrain only the quantized data; scale_inv and metadata pass through.
    if isinstance(x, ScaledTensor1x):
        constrained_data = with_sharding_constraint_by_logical_axes(x.data, logical_axis_names)
        return ScaledTensor1x(
            data=constrained_data,
            scale_inv=x.scale_inv,
            scaling_mode=x.scaling_mode,
            dq_dtype=x.dq_dtype,
            _dq_func=x._dq_func,
            is_colwise=x.is_colwise,
            layout=x.layout,
        )

    # Plain arrays fall through to the generic sharding helper.
    return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)
transformer_engine/jax/setup.py
View file @
a207db1d
...
@@ -2,7 +2,22 @@
...
@@ -2,7 +2,22 @@
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""Installation script for TE jax extensions."""
"""Installation script for Transformer Engine JAX extensions.
This module handles the build and installation of the JAX-specific components
of Transformer Engine. It manages:
- JAX extension compilation with pybind11
- Common header file management
- Build tool dependencies
- Package metadata and dependencies
The script supports both development and release builds, with different
behaviors for:
- Build tool management
- Header file copying
- Extension compilation
- Package distribution
"""
# pylint: disable=wrong-import-position,wrong-import-order
# pylint: disable=wrong-import-position,wrong-import-order
...
@@ -41,6 +56,34 @@ CMakeBuildExtension = get_build_ext(BuildExtension, True)
...
@@ -41,6 +56,34 @@ CMakeBuildExtension = get_build_ext(BuildExtension, True)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
"""Main entry point for JAX extension installation.
This section handles:
1. Common header file management
- Creates a temporary directory for common headers
- Copies necessary header files from the common library
2. Extension module setup
- Configures the JAX-specific C++ extension
- Sets up build paths and dependencies
3. Package configuration
- Sets package metadata
- Configures build and install requirements
- Sets up extension modules
4. Cleanup
- Removes temporary directories after build
- Cleans up build tools if not in release mode
Environment variables:
- NVTE_RELEASE_BUILD: Controls release build behavior
- NVTE_PROJECT_BUILDING: Set to "1" during build
Note:
The script requires JAX to be installed for building.
It will raise a RuntimeError if JAX is not available.
"""
# Extensions
# Extensions
common_headers_dir
=
"common_headers"
common_headers_dir
=
"common_headers"
copy_common_headers
(
current_file_path
.
parent
,
str
(
current_file_path
/
common_headers_dir
))
copy_common_headers
(
current_file_path
.
parent
,
str
(
current_file_path
/
common_headers_dir
))
...
...
transformer_engine/jax/sharding.py
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""
"""Sharding utilities for Transformer Engine in JAX.
Sharding Meta for xmap with CustomCall
This module provides utilities for managing tensor sharding in distributed training,
including support for various parallelism strategies like data parallelism (DP),
tensor parallelism (TP), pipeline parallelism (PP), and full-sharded data
parallelism (FSDP). It includes functions for sharding constraints, mesh management,
and collective operations.
"""
"""
import
os
import
os
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
...
@@ -181,27 +186,17 @@ def get_mesh_axis_rank(axis: str, mesh=None):
...
@@ -181,27 +186,17 @@ def get_mesh_axis_rank(axis: str, mesh=None):
@
dataclass
@
dataclass
class
MeshResource
:
class
MeshResource
:
"""
"""A data container for managing mesh resources in distributed training.
A data container to indicate which axis in Mesh for data parallelism and
which for tensor parallelism.
This class defines the mapping between logical axes and physical mesh axes
for different types of parallelism in distributed training.
Parameters
----------
Attributes:
dp_resource : str, default = None
dp_resource: Axis name for data parallelism (batch sharding), default is None
The axis name in Mesh used to shard batches along.
tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None
If it is None, then data parallelism is disabled.
fsdp_resource: Axis name for full-sharded data parallelism, default is None
tp_resource : str, default = None
pp_resource: Axis name for pipeline parallelism (layer sharding), default is None
The axis name in Mesh used to split the hidden dimensions along.
cp_resource: Axis name for context parallelism (sequence sharding), default is None
If it is None, then tensor parallelism is disabled.
fsdp_resource : str, default = None
The axis name in Mesh used to split the batch and weights along.
If it is None, then full-sharded data parallelism is disabled.
pp_resource : str, default = None
The axis name in Mesh used to split model layers along.
If it is None, then pipeline parallelism is disabled.
cp_resource : str, default = None
The axis name in Mesh used to split sequence (context) dimensions along
in the attention. If it is None, then context parallelism is disabled.
"""
"""
dp_resource
:
str
=
None
dp_resource
:
str
=
None
...
@@ -216,36 +211,55 @@ _GLOBAL_MESH_RESOURCE = MeshResource()
...
@@ -216,36 +211,55 @@ _GLOBAL_MESH_RESOURCE = MeshResource()
@
contextmanager
@
contextmanager
def
global_shard_guard
(
resource
:
MeshResource
):
def
global_shard_guard
(
resource
:
MeshResource
):
"""
"""Context manager for setting global sharding configuration.
A context manager to switch the global MeshResource
This context manager allows temporarily setting the global mesh resource
configuration for sharding operations.
Args:
resource: MeshResource instance defining the sharding configuration
"""
"""
global
_GLOBAL_MESH_RESOURCE
global
_GLOBAL_MESH_RESOURCE
prev_gmr
=
_GLOBAL_MESH_RESOURCE
old_resources
=
_GLOBAL_MESH_RESOURCE
try
:
try
:
_GLOBAL_MESH_RESOURCE
=
resource
_GLOBAL_MESH_RESOURCE
=
resource
yield
yield
finally
:
finally
:
_GLOBAL_MESH_RESOURCE
=
prev_gmr
_GLOBAL_MESH_RESOURCE
=
old_resources
def
global_mesh_resource
()
->
MeshResource
:
def
global_mesh_resource
()
->
MeshResource
:
"""
"""Get the current global mesh resource configuration.
A getter of the global MeshResource
Returns:
The current MeshResource instance
"""
"""
return
_GLOBAL_MESH_RESOURCE
return
_GLOBAL_MESH_RESOURCE
def
all_reduce_sum_along_dp_fsdp
(
x
:
jnp
.
array
,
mesh
:
jax
.
sharding
.
Mesh
):
def
all_reduce_sum_along_dp_fsdp
(
x
:
jnp
.
array
,
mesh
:
jax
.
sharding
.
Mesh
):
"""
"""Perform all-reduce sum operation along data parallelism and FSDP axes.
All-Reduce (Sum) along DP and FSDP mesh axes.
Args:
x: Input tensor to reduce
mesh: JAX mesh for distributed computation
Returns:
Reduced tensor
"""
"""
x
=
lax_paral_op
(
x
,
jax
.
lax
.
psum
,
global_mesh_resource
().
dp_resource
,
mesh
)
x
=
lax_paral_op
(
x
,
jax
.
lax
.
psum
,
global_mesh_resource
().
dp_resource
,
mesh
)
return
lax_paral_op
(
x
,
jax
.
lax
.
psum
,
global_mesh_resource
().
fsdp_resource
,
mesh
)
return
lax_paral_op
(
x
,
jax
.
lax
.
psum
,
global_mesh_resource
().
fsdp_resource
,
mesh
)
def
all_reduce_max_along_all_axes_except_PP
(
x
:
jnp
.
array
,
mesh
:
jax
.
sharding
.
Mesh
):
def
all_reduce_max_along_all_axes_except_PP
(
x
:
jnp
.
array
,
mesh
:
jax
.
sharding
.
Mesh
):
"""
"""Perform all-reduce max operation along all axes except pipeline parallelism.
All-Reduce (Max) along all mesh axes.
Args:
x: Input tensor to reduce
mesh: JAX mesh for distributed computation
Returns:
Reduced tensor
"""
"""
all_axes
=
get_all_mesh_axes
()
all_axes
=
get_all_mesh_axes
()
for
axis
in
all_axes
:
for
axis
in
all_axes
:
...
@@ -261,21 +275,16 @@ global_shard_resource = global_mesh_resource
...
@@ -261,21 +275,16 @@ global_shard_resource = global_mesh_resource
class
MajorShardingType
(
Enum
):
class
MajorShardingType
(
Enum
):
r
"""
"""Enumeration of major sharding types for distributed training.
The major sharding type to indicate sharding pattern.
.. warning::
This enum defines the basic sharding patterns available for distributed
MajorShardingType is deprecating in the near feature.
training. Note that this class is deprecated and will be removed in the future.
Values
Values:
----------
SINGLE: Single process training
SINGLE:
DP: Data parallel training
Single process training.
TP: Standard tensor parallel training
DP:
DPTP: Data and standard tensor parallel training
Data parallel training.
TP:
Standard tensor parallel training.
DPTP:
Data and Standard tensor parallel training.
"""
"""
SINGLE
=
0
SINGLE
=
0
...
@@ -285,25 +294,19 @@ class MajorShardingType(Enum):
...
@@ -285,25 +294,19 @@ class MajorShardingType(Enum):
class
ShardingType
(
Enum
):
class
ShardingType
(
Enum
):
"""
"""Enumeration of detailed sharding types for distributed training.
The sharding type to indicate sharding pattern.
.. warning::
This enum defines specific sharding patterns for distributed training,
ShardingType is deprecating in the near feature.
including combinations of data parallelism and different tensor parallelism
strategies. Note that this class is deprecated and will be removed in the future.
Values
----------
Values:
SINGLE:
SINGLE: No sharding
No sharding.
DP: Sharding along data parallelism
DP:
TP_COL: Sharding along column-split tensor parallelism
Sharding along data parallelism.
TP_ROW: Sharding along row-split tensor parallelism
TP_COL:
DP_TP_COL: Sharding along data and column-split tensor parallelism
Sharding along column-split tensor parallelism.
DP_TP_ROW: Sharding along data and row-split tensor parallelism
TP_ROW:
Sharding along row-split tensor parallelism.
DP_TP_COL:
Sharding along data and column-split tensor parallelism.
DP_TP_ROW:
Sharding along data and row-split tensor parallelism.
"""
"""
SINGLE
=
(
MajorShardingType
.
SINGLE
,
"single"
)
SINGLE
=
(
MajorShardingType
.
SINGLE
,
"single"
)
...
...
transformer_engine/pytorch/attention.py
View file @
a207db1d
...
@@ -690,9 +690,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -690,9 +690,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# partial result quantizer
# partial result quantizer
for
i
in
range
(
cp_size
):
for
i
in
range
(
cp_size
):
S_quantizer_per_step
[
i
]
=
S_quantizer
.
copy
()
S_quantizer_per_step
[
i
]
=
S_quantizer
.
copy
()
S_quantizer_per_step
[
i
].
amax
=
amax_per_step
[
0
][
i
]
S_quantizer_per_step
[
i
].
amax
=
amax_per_step
[
0
][
i
]
.
reshape
((
1
,))
O_CP_quantizer_per_step
[
i
]
=
O_CP_quantizer
.
copy
()
O_CP_quantizer_per_step
[
i
]
=
O_CP_quantizer
.
copy
()
O_CP_quantizer_per_step
[
i
].
amax
=
amax_per_step
[
1
][
i
]
O_CP_quantizer_per_step
[
i
].
amax
=
amax_per_step
[
1
][
i
]
.
reshape
((
1
,))
else
:
else
:
assert
False
,
"FP8 is only supported with Fused Attention!"
assert
False
,
"FP8 is only supported with Fused Attention!"
else
:
else
:
...
@@ -1361,16 +1361,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1361,16 +1361,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
i
>
1
:
if
i
>
1
:
flash_attn_streams
[(
i
-
1
)
%
2
].
wait_event
(
fwd_results_correction_done
)
flash_attn_streams
[(
i
-
1
)
%
2
].
wait_event
(
fwd_results_correction_done
)
if
use_fused_attention
:
# [b, np, sq, 1] -> [b, np, sq] or
# [t, np, 1] -> [t, np]
softmax_lse_per_step
[
i
-
1
].
squeeze_
(
-
1
)
if
softmax_lse_in_packed_format
:
softmax_lse_per_step
[
i
-
1
]
=
(
softmax_lse_per_step
[
i
-
1
].
transpose
(
0
,
1
).
contiguous
()
)
with
torch
.
cuda
.
stream
(
flash_attn_streams
[(
i
-
1
)
%
2
]):
with
torch
.
cuda
.
stream
(
flash_attn_streams
[(
i
-
1
)
%
2
]):
if
use_fused_attention
:
# [b, np, sq, 1] -> [b, np, sq] or
# [t, np, 1] -> [t, np]
softmax_lse_per_step
[
i
-
1
].
squeeze_
(
-
1
)
if
softmax_lse_in_packed_format
:
softmax_lse_per_step
[
i
-
1
]
=
(
softmax_lse_per_step
[
i
-
1
].
transpose
(
0
,
1
).
contiguous
()
)
if
fp8
:
if
fp8
:
out_per_step
[
i
-
1
]
=
out_per_step
[
i
-
1
].
dequantize
(
dtype
=
torch
.
float32
)
out_per_step
[
i
-
1
]
=
out_per_step
[
i
-
1
].
dequantize
(
dtype
=
torch
.
float32
)
if
i
==
1
:
if
i
==
1
:
...
@@ -1479,8 +1478,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1479,8 +1478,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
fp8
and
use_fused_attention
:
if
fp8
and
use_fused_attention
:
amax_cp_fwd
=
amax_per_step
.
amax
(
dim
=
1
)
amax_cp_fwd
=
amax_per_step
.
amax
(
dim
=
1
)
S_quantizer
.
amax
=
amax_cp_fwd
[
0
]
S_quantizer
.
amax
.
copy_
(
amax_cp_fwd
[
0
]
)
O_CP_quantizer
.
amax
=
amax_cp_fwd
[
1
]
O_CP_quantizer
.
amax
.
copy_
(
amax_cp_fwd
[
1
]
)
out_fp8
=
None
out_fp8
=
None
out_f16
=
out
.
to
(
qkv_dtype
)
out_f16
=
out
.
to
(
qkv_dtype
)
...
@@ -1513,16 +1512,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1513,16 +1512,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx
.
save_for_backward
(
*
tensors_to_save
)
ctx
.
save_for_backward
(
*
tensors_to_save
)
ctx
.
tensor_objects
=
tensor_objects
ctx
.
tensor_objects
=
tensor_objects
ctx
.
qkv_dtype
=
qkv_dtype
ctx
.
QKV_quantizer
=
QKV_quantizer
ctx
.
O_quantizer
=
O_quantizer
ctx
.
O_CP_quantizer
=
O_CP_quantizer
ctx
.
S_quantizer
=
S_quantizer
ctx
.
dQKV_quantizer
=
dQKV_quantizer
ctx
.
dQKV_CP_quantizer
=
dQKV_CP_quantizer
ctx
.
dO_quantizer
=
dO_quantizer
ctx
.
dP_quantizer
=
dP_quantizer
ctx
.
cp_group_a2a
=
cp_group_a2a
ctx
.
cp_group_a2a
=
cp_group_a2a
ctx
.
cp_size_a2a
=
cp_size_a2a
ctx
.
cp_size_a2a
=
cp_size_a2a
ctx
.
rank_a2a
=
rank_a2a
ctx
.
rank_a2a
=
rank_a2a
...
@@ -1546,6 +1535,22 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1546,6 +1535,22 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx
.
is_input_fp8
=
is_input_fp8
ctx
.
is_input_fp8
=
is_input_fp8
ctx
.
is_output_fp8
=
is_output_fp8
ctx
.
is_output_fp8
=
is_output_fp8
ctx
.
use_flash_attn_3
=
use_flash_attn_3
ctx
.
use_flash_attn_3
=
use_flash_attn_3
ctx
.
qkv_dtype
=
qkv_dtype
ctx
.
dQKV_quantizer
=
dQKV_quantizer
ctx
.
dQKV_CP_quantizer
=
dQKV_CP_quantizer
ctx
.
dO_quantizer
=
dO_quantizer
ctx
.
dP_quantizer
=
dP_quantizer
ctx
.
QKV_quantizer
=
QKV_quantizer
ctx
.
O_quantizer
=
O_quantizer
ctx
.
S_quantizer
=
S_quantizer
if
ctx
.
fp8
:
ctx
.
QKV_quantizer
=
QKV_quantizer
.
copy
()
ctx
.
QKV_quantizer
.
scale
=
QKV_quantizer
.
scale
.
clone
()
ctx
.
O_quantizer
=
O_quantizer
.
copy
()
ctx
.
O_quantizer
.
scale
=
O_quantizer
.
scale
.
clone
()
ctx
.
S_quantizer
=
S_quantizer
.
copy
()
ctx
.
S_quantizer
.
scale
=
S_quantizer
.
scale
.
clone
()
nvtx_range_pop
(
"transformer_engine.AttnFuncWithCPAndKVP2P.forward"
)
nvtx_range_pop
(
"transformer_engine.AttnFuncWithCPAndKVP2P.forward"
)
return
out_ret
return
out_ret
...
@@ -1632,32 +1637,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1632,32 +1637,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
ctx
.
use_fused_attention
:
if
ctx
.
use_fused_attention
:
fused_attn_backend
=
FusedAttnBackend
[
"FP8"
]
fused_attn_backend
=
FusedAttnBackend
[
"FP8"
]
dqkv_fp8_torch_dtype
=
get_fp8_torch_dtype
(
ctx
.
fp8_meta
[
"recipe"
],
fprop_tensor
=
False
)
dq_fp8
=
torch
.
empty
(
(
cp_size
,
*
q
.
shape
),
dtype
=
dqkv_fp8_torch_dtype
,
device
=
q
.
device
)
dkv_fp8
=
torch
.
empty
(
(
cp_size
,
*
kv
.
shape
),
dtype
=
dqkv_fp8_torch_dtype
,
device
=
kv
.
device
)
dkv_fp8_
=
torch
.
empty_like
(
dkv_fp8
)
if
ctx
.
is_output_fp8
:
if
ctx
.
is_output_fp8
:
assert
isinstance
(
dout
,
Float8Tensor
),
"dout must be Float8Tensors for FP8 MHA!"
assert
isinstance
(
dout
,
Float8Tensor
),
"dout must be Float8Tensors for FP8 MHA!"
ctx
.
dO_quantizer
=
dout
.
_quantizer
ctx
.
dO_quantizer
=
dout
.
_quantizer
else
:
else
:
dout
=
ctx
.
dO_quantizer
(
dout
)
dout
=
ctx
.
dO_quantizer
(
dout
)
fused_attn_dqkv_dtype
=
dout
.
_fp8_dtype
fused_attn_dqkv_dtype
=
TE_DType
[
dout
.
_data
.
dtype
]
dout
=
dout
.
_data
dq_fp8
=
torch
.
empty
((
cp_size
,
*
q
.
shape
),
dtype
=
dout
.
_data
.
dtype
,
device
=
q
.
device
)
dkv_fp8
=
torch
.
empty
(
(
cp_size
,
*
kv
.
shape
),
dtype
=
dout
.
_data
.
dtype
,
device
=
kv
.
device
)
dkv_fp8_
=
torch
.
empty_like
(
dkv_fp8
)
p2p_comm_buffers
=
[[
kv
,
dkv_fp8
],
[
torch
.
empty_like
(
kv
),
dkv_fp8_
]]
p2p_comm_buffers
=
[[
kv
,
dkv_fp8
],
[
torch
.
empty_like
(
kv
),
dkv_fp8_
]]
dout
=
dout
.
_data
fp8_meta_kwargs
=
{}
fp8_meta_kwargs
=
{}
fp8_meta_kwargs
[
"s_quantizer"
]
=
ctx
.
S_quantizer
fp8_meta_kwargs
[
"s_quantizer"
]
=
ctx
.
S_quantizer
amax_per_step
=
torch
.
zeros
((
2
,
cp_size
),
dtype
=
torch
.
float32
,
device
=
q
.
device
)
amax_per_step
=
torch
.
zeros
((
2
,
cp_size
),
dtype
=
torch
.
float32
,
device
=
q
.
device
)
for
i
in
range
(
cp_size
):
for
i
in
range
(
cp_size
):
dP_quantizer_per_step
[
i
]
=
ctx
.
dP_quantizer
.
copy
()
dP_quantizer_per_step
[
i
]
=
ctx
.
dP_quantizer
.
copy
()
dP_quantizer_per_step
[
i
].
amax
=
amax_per_step
[
0
][
i
]
dP_quantizer_per_step
[
i
].
amax
=
amax_per_step
[
0
][
i
]
.
reshape
((
1
,))
dQKV_CP_quantizer_per_step
[
i
]
=
ctx
.
dQKV_CP_quantizer
.
copy
()
dQKV_CP_quantizer_per_step
[
i
]
=
ctx
.
dQKV_CP_quantizer
.
copy
()
dQKV_CP_quantizer_per_step
[
i
].
amax
=
amax_per_step
[
1
][
i
]
dQKV_CP_quantizer_per_step
[
i
].
amax
=
amax_per_step
[
1
][
i
]
.
reshape
((
1
,))
else
:
else
:
assert
False
,
"FP8 is only supported with Fused Attention!"
assert
False
,
"FP8 is only supported with Fused Attention!"
else
:
else
:
...
@@ -1838,7 +1838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1838,7 +1838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part
,
v_part
,
out_part
,
out_part
,
dout_part
,
dout_part
,
ctx
.
qkv
_dtype
,
dout
_dtype
,
fused_attn_dqkv_dtype
,
fused_attn_dqkv_dtype
,
aux_ctx_tensors
,
aux_ctx_tensors
,
fused_attn_backend
,
fused_attn_backend
,
...
@@ -1962,7 +1962,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -1962,7 +1962,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part
,
v_part
,
out_part
,
out_part
,
dout_part
,
dout_part
,
ctx
.
qkv
_dtype
,
dout
_dtype
,
fused_attn_dqkv_dtype
,
fused_attn_dqkv_dtype
,
aux_ctx_tensors
,
aux_ctx_tensors
,
fused_attn_backend
,
fused_attn_backend
,
...
@@ -2090,7 +2090,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2090,7 +2090,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part
,
v_part
,
out_part
,
out_part
,
dout_part
,
dout_part
,
ctx
.
qkv
_dtype
,
dout
_dtype
,
fused_attn_dqkv_dtype
,
fused_attn_dqkv_dtype
,
aux_ctx_tensors
,
aux_ctx_tensors
,
fused_attn_backend
,
fused_attn_backend
,
...
@@ -2195,7 +2195,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2195,7 +2195,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
v_part
,
v_part
,
out_part
,
out_part
,
dout_part
,
dout_part
,
ctx
.
qkv
_dtype
,
dout
_dtype
,
fused_attn_dqkv_dtype
,
fused_attn_dqkv_dtype
,
aux_ctx_tensors
,
aux_ctx_tensors
,
fused_attn_backend
,
fused_attn_backend
,
...
@@ -2395,8 +2395,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
...
@@ -2395,8 +2395,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if
ctx
.
fp8
and
ctx
.
use_fused_attention
:
if
ctx
.
fp8
and
ctx
.
use_fused_attention
:
amax_cp_bwd
=
amax_per_step
.
amax
(
dim
=
1
)
amax_cp_bwd
=
amax_per_step
.
amax
(
dim
=
1
)
ctx
.
dP_quantizer
.
amax
=
amax_cp_bwd
[
0
]
ctx
.
dP_quantizer
.
amax
.
copy_
(
amax_cp_bwd
[
0
]
)
ctx
.
dQKV_CP_quantizer
.
amax
=
amax_cp_bwd
[
1
]
ctx
.
dQKV_CP_quantizer
.
amax
.
copy_
(
amax_cp_bwd
[
1
]
)
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]:
if
ctx
.
qkv_format
in
[
"bshd"
,
"sbhd"
]:
# [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
# [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
# [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
# [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
...
@@ -3229,14 +3229,6 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
...
@@ -3229,14 +3229,6 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx
.
save_for_backward
(
*
tensors_to_save
)
ctx
.
save_for_backward
(
*
tensors_to_save
)
ctx
.
tensor_objects
=
tensor_objects
ctx
.
tensor_objects
=
tensor_objects
ctx
.
qkv_dtype
=
qkv_dtype
ctx
.
QKV_quantizer
=
QKV_quantizer
ctx
.
O_quantizer
=
O_quantizer
ctx
.
S_quantizer
=
S_quantizer
ctx
.
dQKV_quantizer
=
dQKV_quantizer
ctx
.
dO_quantizer
=
dO_quantizer
ctx
.
dP_quantizer
=
dP_quantizer
ctx
.
batch_size
=
batch_size
ctx
.
batch_size
=
batch_size
ctx
.
cp_group
=
cp_group
ctx
.
cp_group
=
cp_group
ctx
.
cp_stream
=
cp_stream
ctx
.
cp_stream
=
cp_stream
...
@@ -3255,6 +3247,21 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
...
@@ -3255,6 +3247,21 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx
.
is_input_fp8
=
is_input_fp8
ctx
.
is_input_fp8
=
is_input_fp8
ctx
.
is_output_fp8
=
is_output_fp8
ctx
.
is_output_fp8
=
is_output_fp8
ctx
.
use_flash_attn_3
=
use_flash_attn_3
ctx
.
use_flash_attn_3
=
use_flash_attn_3
ctx
.
qkv_dtype
=
qkv_dtype
ctx
.
dQKV_quantizer
=
dQKV_quantizer
ctx
.
dO_quantizer
=
dO_quantizer
ctx
.
dP_quantizer
=
dP_quantizer
ctx
.
QKV_quantizer
=
QKV_quantizer
ctx
.
O_quantizer
=
O_quantizer
ctx
.
S_quantizer
=
S_quantizer
if
ctx
.
fp8
:
ctx
.
QKV_quantizer
=
QKV_quantizer
.
copy
()
ctx
.
QKV_quantizer
.
scale
=
QKV_quantizer
.
scale
.
clone
()
ctx
.
O_quantizer
=
O_quantizer
.
copy
()
ctx
.
O_quantizer
.
scale
=
O_quantizer
.
scale
.
clone
()
ctx
.
S_quantizer
=
S_quantizer
.
copy
()
ctx
.
S_quantizer
.
scale
=
S_quantizer
.
scale
.
clone
()
nvtx_range_pop
(
"transformer_engine.AttnFuncWithCPAndQKVOA2A.forward"
)
nvtx_range_pop
(
"transformer_engine.AttnFuncWithCPAndQKVOA2A.forward"
)
return
out_ret
return
out_ret
...
@@ -3291,7 +3298,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
...
@@ -3291,7 +3298,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx
.
dO_quantizer
=
dout
.
_quantizer
ctx
.
dO_quantizer
=
dout
.
_quantizer
else
:
else
:
dout
=
ctx
.
dO_quantizer
(
dout
)
dout
=
ctx
.
dO_quantizer
(
dout
)
fused_attn_dqkv_dtype
=
dout
.
_
fp8_
dtype
fused_attn_dqkv_dtype
=
TE_DType
[
dout
.
_
data
.
dtype
]
dout
=
dout
.
_data
dout
=
dout
.
_data
fp8_meta_kwargs
=
{}
fp8_meta_kwargs
=
{}
fp8_meta_kwargs
[
"s_quantizer"
]
=
ctx
.
S_quantizer
fp8_meta_kwargs
[
"s_quantizer"
]
=
ctx
.
S_quantizer
...
@@ -3401,7 +3408,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
...
@@ -3401,7 +3408,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
v_part
,
v_part
,
out_part
,
out_part
,
dout_part
,
dout_part
,
ctx
.
qkv
_dtype
,
dout
_dtype
,
fused_attn_dqkv_dtype
,
fused_attn_dqkv_dtype
,
aux_ctx_tensors
,
aux_ctx_tensors
,
fused_attn_backend
,
fused_attn_backend
,
...
@@ -4748,6 +4755,9 @@ class FusedAttnFunc(torch.autograd.Function):
...
@@ -4748,6 +4755,9 @@ class FusedAttnFunc(torch.autograd.Function):
ctx
.
dO_quantizer
=
dO_quantizer
ctx
.
dO_quantizer
=
dO_quantizer
ctx
.
dP_quantizer
=
dP_quantizer
ctx
.
dP_quantizer
=
dP_quantizer
ctx
.
S_quantizer
=
S_quantizer
ctx
.
S_quantizer
=
S_quantizer
if
ctx
.
fp8
:
ctx
.
S_quantizer
=
S_quantizer
.
copy
()
ctx
.
S_quantizer
.
scale
=
S_quantizer
.
scale
.
clone
()
ctx
.
max_seqlen_q
=
max_seqlen_q
ctx
.
max_seqlen_q
=
max_seqlen_q
ctx
.
max_seqlen_kv
=
max_seqlen_kv
ctx
.
max_seqlen_kv
=
max_seqlen_kv
...
@@ -4963,8 +4973,6 @@ class FusedAttnFunc(torch.autograd.Function):
...
@@ -4963,8 +4973,6 @@ class FusedAttnFunc(torch.autograd.Function):
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
)
)
# else, return (dqkv, dbias)
# else, return (dqkv, dbias)
return
(
return
(
...
@@ -4995,8 +5003,6 @@ class FusedAttnFunc(torch.autograd.Function):
...
@@ -4995,8 +5003,6 @@ class FusedAttnFunc(torch.autograd.Function):
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
)
)
...
@@ -5126,6 +5132,16 @@ class FusedAttention(torch.nn.Module):
...
@@ -5126,6 +5132,16 @@ class FusedAttention(torch.nn.Module):
# get q_format and kv_format for training and inference
# get q_format and kv_format for training and inference
qkv_format
,
q_format
,
kv_format
=
dpa_utils
.
get_qkv_format
(
qkv_layout
,
inference_params
)
qkv_format
,
q_format
,
kv_format
=
dpa_utils
.
get_qkv_format
(
qkv_layout
,
inference_params
)
# cuDNN can work with 0-length sequences in the batch for both bshd/sbhd and thd formats
# however, for bshd/sbhd, q/k/v tensors need to have the same batch size as indicated by
# cu_seqlens, whereas thd does not have this requirement
# e.g. if q_format = bshd, and q.shape = [3, 1, 16, 64], we should have k.shape[0] =
# v.shape[0] = q.shape[0], and cu_seqlens_q.shape = cu_seqlens_kv.shape = [4]
if
q_format
in
[
"bshd"
,
"sbhd"
]
or
kv_format
in
[
"bshd"
,
"sbhd"
]:
batch_size
=
query_layer
.
shape
[
0
]
if
q_format
==
"bshd"
else
query_layer
.
shape
[
1
]
cu_seqlens_q
=
cu_seqlens_q
[:
batch_size
+
1
]
cu_seqlens_kv
=
cu_seqlens_kv
[:
batch_size
+
1
]
page_table
=
None
page_table
=
None
if
inference_params
is
None
:
if
inference_params
is
None
:
if
qkv_format
in
[
"sbhd"
,
"bshd"
]:
if
qkv_format
in
[
"sbhd"
,
"bshd"
]:
...
@@ -6209,7 +6225,11 @@ class DotProductAttention(TransformerEngineBaseModule):
...
@@ -6209,7 +6225,11 @@ class DotProductAttention(TransformerEngineBaseModule):
# raise exception if no backend is available
# raise exception if no backend is available
if
sum
([
use_flash_attention
,
use_fused_attention
,
use_unfused_attention
])
==
0
:
if
sum
([
use_flash_attention
,
use_fused_attention
,
use_unfused_attention
])
==
0
:
raise
ValueError
(
"No dot product attention support for the provided inputs!"
)
raise
ValueError
(
"No dot product attention backend is available for the provided inputs. Please"
" run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for"
" disabling all backends."
)
# run attention
# run attention
if
use_flash_attention
:
if
use_flash_attention
:
...
...
transformer_engine/pytorch/csrc/common.h
View file @
a207db1d
...
@@ -153,7 +153,6 @@ class Float8CurrentScalingQuantizer : public Quantizer {
...
@@ -153,7 +153,6 @@ class Float8CurrentScalingQuantizer : public Quantizer {
DType
dtype
;
DType
dtype
;
bool
with_amax_reduction
;
bool
with_amax_reduction
;
c10
::
intrusive_ptr
<
dist_group_type
>
amax_reduction_group
;
c10
::
intrusive_ptr
<
dist_group_type
>
amax_reduction_group
;
int
amax_reduction_size
;
bool
force_pow_2_scales
=
false
;
bool
force_pow_2_scales
=
false
;
float
amax_epsilon
=
0.0
;
float
amax_epsilon
=
0.0
;
...
...
transformer_engine/pytorch/csrc/extensions/quantizer.cpp
View file @
a207db1d
...
@@ -145,24 +145,21 @@ Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& q
...
@@ -145,24 +145,21 @@ Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& q
const
at
::
Tensor
&
scale
=
quantizer
.
attr
(
"scale"
).
cast
<
at
::
Tensor
>
();
const
at
::
Tensor
&
scale
=
quantizer
.
attr
(
"scale"
).
cast
<
at
::
Tensor
>
();
const
at
::
Tensor
&
amax
=
quantizer
.
attr
(
"amax"
).
cast
<
at
::
Tensor
>
();
const
at
::
Tensor
&
amax
=
quantizer
.
attr
(
"amax"
).
cast
<
at
::
Tensor
>
();
const
DType
type
=
quantizer
.
attr
(
"dtype"
).
cast
<
DType
>
();
const
DType
type
=
quantizer
.
attr
(
"dtype"
).
cast
<
DType
>
();
// For current scaling, need several other components:
// 1. with_amax_reduction: bool
// 2. amax_reduction_group: torch.distributed.ProcessGroup or None
// 3. amax_reduction_size: int
const
bool
with_amax_reduction
=
quantizer
.
attr
(
"with_amax_reduction"
).
cast
<
bool
>
();
const
py
::
object
amax_reduction_group_obj
=
quantizer
.
attr
(
"amax_reduction_group"
);
const
c10
::
intrusive_ptr
<
dist_group_type
>
amax_reduction_group
=
amax_reduction_group_obj
.
is_none
()
?
nullptr
:
amax_reduction_group_obj
.
cast
<
c10
::
intrusive_ptr
<
dist_group_type
>>
();
const
int
amax_reduction_size
=
quantizer
.
attr
(
"amax_reduction_size"
).
cast
<
int
>
();
this
->
amax
=
amax
;
this
->
amax
=
amax
;
this
->
scale
=
scale
;
this
->
scale
=
scale
;
this
->
dtype
=
type
;
this
->
dtype
=
type
;
// Get amax reduction group if needed
const
bool
with_amax_reduction
=
quantizer
.
attr
(
"with_amax_reduction"
).
cast
<
bool
>
();
c10
::
intrusive_ptr
<
dist_group_type
>
amax_reduction_group
;
if
(
with_amax_reduction
)
{
auto
group
=
quantizer
.
attr
(
"_canonicalized_amax_reduction_group"
)();
NVTE_CHECK
(
!
group
.
is_none
(),
"Float8CurrentScalingQuantizer could not canonicalize amax reduction group"
);
amax_reduction_group
=
group
.
cast
<
c10
::
intrusive_ptr
<
dist_group_type
>>
();
}
this
->
with_amax_reduction
=
with_amax_reduction
;
this
->
with_amax_reduction
=
with_amax_reduction
;
this
->
amax_reduction_group
=
amax_reduction_group
;
this
->
amax_reduction_group
=
amax_reduction_group
;
this
->
amax_reduction_size
=
amax_reduction_size
;
// fp8 current scaling specific quantization params
// fp8 current scaling specific quantization params
this
->
force_pow_2_scales
=
quantizer
.
attr
(
"force_pow_2_scales"
).
cast
<
bool
>
();
this
->
force_pow_2_scales
=
quantizer
.
attr
(
"force_pow_2_scales"
).
cast
<
bool
>
();
...
...
transformer_engine/pytorch/distributed.py
View file @
a207db1d
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
"""Methods needed for distributed training (DP/TP)."""
"""Methods needed for distributed training (DP/TP)."""
from
__future__
import
annotations
from
__future__
import
annotations
from
collections.abc
import
Iterable
from
contextlib
import
contextmanager
,
AbstractContextManager
,
ContextDecorator
from
contextlib
import
contextmanager
,
AbstractContextManager
,
ContextDecorator
from
functools
import
lru_cache
from
functools
import
lru_cache
import
math
import
math
...
@@ -876,10 +877,14 @@ def _all_gather_fp8(
...
@@ -876,10 +877,14 @@ def _all_gather_fp8(
# we cannot directly gather the transposed fp8 tensor
# we cannot directly gather the transposed fp8 tensor
# so we need to disable columnwise usage for the quantizer
# so we need to disable columnwise usage for the quantizer
# and then set it back to the original value after quantizing
# and then set it back to the original value after quantizing
init_rowwise_usage
=
quantizer
.
rowwise_usage
init_columnwise_usage
=
quantizer
.
columnwise_usage
init_columnwise_usage
=
quantizer
.
columnwise_usage
quantizer
.
set_usage
(
columnwise
=
False
)
quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
inp
=
quantizer
(
inp
)
inp
=
quantizer
(
inp
)
quantizer
.
set_usage
(
columnwise
=
init_columnwise_usage
)
quantizer
.
set_usage
(
rowwise
=
init_rowwise_usage
,
columnwise
=
init_columnwise_usage
,
)
# Construct output tensor
# Construct output tensor
out
:
Float8TensorBase
out
:
Float8TensorBase
...
@@ -936,9 +941,34 @@ def _all_gather_mxfp8(
...
@@ -936,9 +941,34 @@ def _all_gather_mxfp8(
)
->
tuple
[
MXFP8TensorBase
,
Optional
[
torch
.
distributed
.
Work
]]:
)
->
tuple
[
MXFP8TensorBase
,
Optional
[
torch
.
distributed
.
Work
]]:
"""All-gather MXFP8 tensor along first dimension."""
"""All-gather MXFP8 tensor along first dimension."""
# Tensor dims
# Input tensor attributes
in_shape
:
Iterable
[
int
]
device
:
torch
.
device
dtype
:
torch
.
dtype
if
isinstance
(
inp
,
torch
.
Tensor
):
in_shape
=
inp
.
size
()
device
=
inp
.
device
dtype
=
inp
.
dtype
elif
isinstance
(
inp
,
MXFP8TensorBase
):
if
inp
.
_rowwise_data
is
not
None
:
in_shape
=
inp
.
_rowwise_data
.
device
.
size
()
device
=
inp
.
_rowwise_data
.
device
dtype
=
inp
.
_rowwise_data
.
dtype
elif
inp
.
_columnwise_data
is
not
None
:
in_shape
=
inp
.
_columnwise_data
.
device
.
size
()
device
=
inp
.
_columnwise_data
.
device
dtype
=
inp
.
_columnwise_data
.
dtype
else
:
raise
ValueError
(
"Got MXFP8 input tensor without any data"
)
dtype
=
torch
.
bfloat16
else
:
raise
ValueError
(
"Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
f
"found
{
inp
.
__class__
.
__name__
}
)"
)
# Output tensor shape
world_size
=
get_distributed_world_size
(
process_group
)
world_size
=
get_distributed_world_size
(
process_group
)
in_shape
=
list
(
inp
.
size
())
if
out_shape
is
None
:
if
out_shape
is
None
:
out_shape
=
[
in_shape
[
0
]
*
world_size
]
+
in_shape
[
1
:]
out_shape
=
[
in_shape
[
0
]
*
world_size
]
+
in_shape
[
1
:]
...
@@ -951,25 +981,19 @@ def _all_gather_mxfp8(
...
@@ -951,25 +981,19 @@ def _all_gather_mxfp8(
):
):
out
=
torch
.
empty
(
out
=
torch
.
empty
(
out_shape
,
out_shape
,
dtype
=
inp
.
dtype
,
dtype
=
dtype
,
device
=
inp
.
device
,
device
=
device
,
memory_format
=
torch
.
contiguous_format
,
memory_format
=
torch
.
contiguous_format
,
)
)
torch
.
distributed
.
all_gather_into_tensor
(
out
,
inp
,
group
=
process_group
)
torch
.
distributed
.
all_gather_into_tensor
(
out
,
inp
,
group
=
process_group
)
out
=
quantizer
(
out
)
out
=
quantizer
(
out
)
return
out
,
None
return
out
,
None
inp_dtype
=
inp
.
dtype
inp_device
=
inp
.
device
# Cast input tensor to MXFP8 with required data
# Cast input tensor to MXFP8 with required data
if
not
isinstance
(
inp
,
MXFP8TensorBase
):
if
not
isinstance
(
inp
,
MXFP8TensorBase
):
inp
=
quantizer
(
inp
)
inp
=
quantizer
(
inp
)
elif
(
elif
(
quantizer
.
rowwise_usage
and
inp
.
_rowwise_data
is
None
)
or
(
inp
.
rowwise_data
is
None
quantizer
.
columnwise_usage
and
inp
.
_columnwise_data
is
None
and
quantizer
.
rowwise_usage
or
inp
.
columnwise_data
is
None
and
quantizer
.
columnwise_usage
):
):
warnings
.
warn
(
warnings
.
warn
(
"Input and quantizer do not have matching usages. "
"Input and quantizer do not have matching usages. "
...
@@ -978,65 +1002,64 @@ def _all_gather_mxfp8(
...
@@ -978,65 +1002,64 @@ def _all_gather_mxfp8(
inp
=
quantizer
(
inp
.
dequantize
())
inp
=
quantizer
(
inp
.
dequantize
())
# Construct MXFP8 output tensor
# Construct MXFP8 output tensor
out
=
quantizer
.
make_empty
(
out_shape
,
dtype
=
inp_dtype
,
device
=
inp_device
)
out
=
quantizer
.
make_empty
(
out_shape
,
dtype
=
dtype
,
device
=
device
)
# Async op handle
handle
=
None
# Gather MXFP8 data for row-wise usage
if
quantizer
.
rowwise_usage
:
# Remove padding from MXFP8 scale-inverses
in_scale_inv
=
inp
.
_rowwise_scale_inv
out_scale_inv
=
out
.
_rowwise_scale_inv
flattened_in_shape0
=
math
.
prod
(
in_shape
[:
-
1
])
if
in_scale_inv
.
size
(
0
)
!=
flattened_in_shape0
:
in_scale_inv
=
in_scale_inv
[:
flattened_in_shape0
]
out_scale_inv
[
flattened_in_shape0
*
world_size
:].
zero_
()
out_scale_inv
=
out_scale_inv
[:
flattened_in_shape0
*
world_size
]
# Launch all-gathers
if
handle
is
not
None
:
handle
.
wait
()
torch
.
distributed
.
all_gather_into_tensor
(
out_scale_inv
,
in_scale_inv
,
group
=
process_group
,
)
handle
=
torch
.
distributed
.
all_gather_into_tensor
(
out
.
_rowwise_data
,
inp
.
_rowwise_data
,
group
=
process_group
,
async_op
=
async_op
,
)
# Gather MXFP8 data for column-wise usage
if
quantizer
.
columnwise_usage
:
# Remove padding from MXFP8 scale-inverses
# Coalesce NCCL collectives
in_scale_inv
=
inp
.
_columnwise_scale_inv
with
torch
.
distributed
.
_coalescing_manager
(
out_scale_inv
=
out
.
_columnwise_scale_inv
group
=
process_group
,
flattened_in_shape0
=
math
.
prod
(
in_shape
[:
-
1
])
//
32
device
=
device
,
if
in_scale_inv
.
size
(
0
)
!=
flattened_in_shape0
:
async_ops
=
async_op
,
in_scale_inv
=
in_scale_inv
[:
flattened_in_shape0
]
)
as
coalescing_manager
:
out_scale_inv
[
flattened_in_shape0
*
world_size
:].
zero_
()
out_scale_inv
=
out_scale_inv
[:
flattened_in_shape0
*
world_size
]
# Gather MXFP8 data for row-wise usage
if
quantizer
.
rowwise_usage
:
# Remove padding from MXFP8 scale-inverses
in_scale_inv
=
inp
.
_rowwise_scale_inv
out_scale_inv
=
out
.
_rowwise_scale_inv
flattened_in_shape0
=
math
.
prod
(
in_shape
[:
-
1
])
if
in_scale_inv
.
size
(
0
)
!=
flattened_in_shape0
:
in_scale_inv
=
in_scale_inv
[:
flattened_in_shape0
]
out_scale_inv
[
flattened_in_shape0
*
world_size
:].
zero_
()
out_scale_inv
=
out_scale_inv
[:
flattened_in_shape0
*
world_size
]
# Launch all-gathers
torch
.
distributed
.
all_gather_into_tensor
(
out_scale_inv
,
in_scale_inv
,
group
=
process_group
,
)
torch
.
distributed
.
all_gather_into_tensor
(
out
.
_rowwise_data
,
inp
.
_rowwise_data
,
group
=
process_group
,
)
# Launch all-gathers
# Gather MXFP8 data for column-wise usage
if
handle
is
not
None
:
if
quantizer
.
columnwise_usage
:
handle
.
wait
()
torch
.
distributed
.
all_gather_into_tensor
(
# Remove padding from MXFP8 scale-inverses
out_scale_inv
,
in_scale_inv
=
inp
.
_columnwise_scale_inv
in_scale_inv
,
out_scale_inv
=
out
.
_columnwise_scale_inv
group
=
process_group
,
flattened_in_shape0
=
math
.
prod
(
in_shape
[:
-
1
])
//
32
)
if
in_scale_inv
.
size
(
0
)
!=
flattened_in_shape0
:
handle
=
torch
.
distributed
.
all_gather_into_tensor
(
in_scale_inv
=
in_scale_inv
[:
flattened_in_shape0
]
out
.
_columnwise_data
,
out_scale_inv
[
flattened_in_shape0
*
world_size
:].
zero_
()
inp
.
_columnwise_data
,
out_scale_inv
=
out_scale_inv
[:
flattened_in_shape0
*
world_size
]
group
=
process_group
,
async_op
=
async_op
,
# Launch all-gathers
)
torch
.
distributed
.
all_gather_into_tensor
(
out_scale_inv
,
in_scale_inv
,
group
=
process_group
,
)
torch
.
distributed
.
all_gather_into_tensor
(
out
.
_columnwise_data
,
inp
.
_columnwise_data
,
group
=
process_group
,
)
handle
=
coalescing_manager
if
async_op
else
None
return
out
,
handle
return
out
,
handle
...
...
transformer_engine/pytorch/dot_product_attention/inference.py
View file @
a207db1d
...
@@ -100,7 +100,7 @@ class InferenceParams:
...
@@ -100,7 +100,7 @@ class InferenceParams:
----------
----------
max_batch_size: int
max_batch_size: int
Maximum batch size in inference
Maximum batch size in inference
max_seq
l
en
_kv
: int
max_seq
u
en
ce_length
: int
Maximum sequence length in inference
Maximum sequence length in inference
num_heads_kv: int
num_heads_kv: int
Number of attention heads in keys and values
Number of attention heads in keys and values
...
@@ -117,7 +117,7 @@ class InferenceParams:
...
@@ -117,7 +117,7 @@ class InferenceParams:
page_size: int, default = None
page_size: int, default = None
Page size of the KV cache. Required for is_paged = True.
Page size of the KV cache. Required for is_paged = True.
max_ctx_len: int, default = None
max_ctx_len: int, default = None
Maximum context length in inference. 1 <= max_ctx_len <= max_seq
l
en
_kv
.
Maximum context length in inference. 1 <= max_ctx_len <= max_seq
u
en
ce_length
.
qkv_format: str, default = "bshd"
qkv_format: str, default = "bshd"
Format of the incoming query/key/value tensors in current iteration
Format of the incoming query/key/value tensors in current iteration
custom_cache_manager: KVCacheManager, default = None
custom_cache_manager: KVCacheManager, default = None
...
@@ -127,7 +127,7 @@ class InferenceParams:
...
@@ -127,7 +127,7 @@ class InferenceParams:
def
__init__
(
def
__init__
(
self
,
self
,
max_batch_size
:
int
,
max_batch_size
:
int
,
max_seq
l
en
_kv
:
int
,
max_seq
u
en
ce_length
:
int
,
num_heads_kv
:
int
=
16
,
num_heads_kv
:
int
=
16
,
head_dim_k
:
int
=
64
,
head_dim_k
:
int
=
64
,
dtype
:
torch
.
dtype
=
torch
.
bfloat16
,
dtype
:
torch
.
dtype
=
torch
.
bfloat16
,
...
@@ -140,7 +140,7 @@ class InferenceParams:
...
@@ -140,7 +140,7 @@ class InferenceParams:
custom_cache_manager
:
KVCacheManager
=
None
,
custom_cache_manager
:
KVCacheManager
=
None
,
):
):
self
.
max_batch_size
=
max_batch_size
self
.
max_batch_size
=
max_batch_size
self
.
max_seq
l
en
_kv
=
max_seq
l
en
_kv
self
.
max_seq
u
en
ce_length
=
max_seq
u
en
ce_length
self
.
num_heads_kv
=
num_heads_kv
self
.
num_heads_kv
=
num_heads_kv
self
.
head_dim_k
=
head_dim_k
self
.
head_dim_k
=
head_dim_k
self
.
dtype
=
dtype
self
.
dtype
=
dtype
...
@@ -153,7 +153,7 @@ class InferenceParams:
...
@@ -153,7 +153,7 @@ class InferenceParams:
)
)
self
.
cache_manager
=
cache_manager
(
self
.
cache_manager
=
cache_manager
(
max_batch_size
=
self
.
max_batch_size
,
max_batch_size
=
self
.
max_batch_size
,
max_seqlen
=
self
.
max_seq
l
en
_kv
,
max_seqlen
=
self
.
max_seq
u
en
ce_length
,
num_heads
=
self
.
num_heads_kv
,
num_heads
=
self
.
num_heads_kv
,
head_dim_k
=
self
.
head_dim_k
,
head_dim_k
=
self
.
head_dim_k
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
...
@@ -163,9 +163,9 @@ class InferenceParams:
...
@@ -163,9 +163,9 @@ class InferenceParams:
assert
page_size
is
not
None
,
"Paged KV cache requires page_size is not None."
assert
page_size
is
not
None
,
"Paged KV cache requires page_size is not None."
self
.
page_size
=
page_size
self
.
page_size
=
page_size
assert
(
assert
(
max_seq
l
en
_kv
%
page_size
==
0
max_seq
u
en
ce_length
%
page_size
==
0
),
"Paged KV cache requires max_seq
l
en
_kv
% page_size = 0."
),
"Paged KV cache requires max_seq
u
en
ce_length
% page_size = 0."
max_pages_per_seq
=
max_seq
l
en
_kv
//
page_size
max_pages_per_seq
=
max_seq
u
en
ce_length
//
page_size
assert
(
assert
(
total_num_pages
==
self
.
max_batch_size
*
max_pages_per_seq
total_num_pages
==
self
.
max_batch_size
*
max_pages_per_seq
),
"Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq."
),
"Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq."
...
@@ -181,7 +181,7 @@ class InferenceParams:
...
@@ -181,7 +181,7 @@ class InferenceParams:
head_dim_k
=
self
.
head_dim_k
,
head_dim_k
=
self
.
head_dim_k
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
max_batch_size
=
self
.
max_batch_size
,
max_batch_size
=
self
.
max_batch_size
,
max_seqlen
=
self
.
max_seq
l
en
_kv
,
max_seqlen
=
self
.
max_seq
u
en
ce_length
,
head_dim_v
=
self
.
head_dim_v
,
head_dim_v
=
self
.
head_dim_v
,
)
)
...
@@ -231,7 +231,7 @@ class InferenceParams:
...
@@ -231,7 +231,7 @@ class InferenceParams:
f
"dtype=
{
self
.
dtype
}
, "
f
"dtype=
{
self
.
dtype
}
, "
f
"is_paged=
{
self
.
is_paged
}
, "
f
"is_paged=
{
self
.
is_paged
}
, "
f
"max_batch_size=
{
self
.
max_batch_size
}
, "
f
"max_batch_size=
{
self
.
max_batch_size
}
, "
f
"max_seqlen=
{
self
.
max_seq
l
en
_kv
}
, "
f
"max_seqlen=
{
self
.
max_seq
u
en
ce_length
}
, "
f
"num_heads=
{
self
.
num_heads_kv
}
, "
f
"num_heads=
{
self
.
num_heads_kv
}
, "
f
"head_dim_k=
{
self
.
head_dim_k
}
, "
f
"head_dim_k=
{
self
.
head_dim_k
}
, "
f
"head_dim_v=
{
self
.
head_dim_v
}
"
f
"head_dim_v=
{
self
.
head_dim_v
}
"
...
@@ -241,8 +241,8 @@ class InferenceParams:
...
@@ -241,8 +241,8 @@ class InferenceParams:
"""
"""
Allocate memory for the cache. For layer layer_number,
Allocate memory for the cache. For layer layer_number,
- NonPagedKVCacheManager:
- NonPagedKVCacheManager:
- K cache: [max_batch_size, max_seq
l
en
_kv
, num_heads_kv, head_dim_k]
- K cache: [max_batch_size, max_seq
u
en
ce_length
, num_heads_kv, head_dim_k]
- V cache: [max_batch_size, max_seq
l
en
_kv
, num_heads_kv, head_dim_v]
- V cache: [max_batch_size, max_seq
u
en
ce_length
, num_heads_kv, head_dim_v]
- PagedKVCacheManager:
- PagedKVCacheManager:
- K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k]
- K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k]
- V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v]
- V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v]
...
@@ -348,7 +348,7 @@ class InferenceParams:
...
@@ -348,7 +348,7 @@ class InferenceParams:
Updated cumulative sequence lengths for key and value, [batch_size + 1]
Updated cumulative sequence lengths for key and value, [batch_size + 1]
max_seqlen_q: int
max_seqlen_q: int
Update maximum sequence length for query
Update maximum sequence length for query
max_seq
l
en
_kv
: int
max_seq
u
en
ce_length
: int
Update maximum sequence length for key and value
Update maximum sequence length for key and value
qkv_format: str
qkv_format: str
Updated qkv_format, e.g. 'thd' format becomes 'thd_2bshd' after step()
Updated qkv_format, e.g. 'thd' format becomes 'thd_2bshd' after step()
...
@@ -373,7 +373,7 @@ class InferenceParams:
...
@@ -373,7 +373,7 @@ class InferenceParams:
v_cache
,
v_cache
,
self
.
cu_seqlens_q
,
self
.
cu_seqlens_q
,
self
.
cu_seqlens_kv
,
self
.
cu_seqlens_kv
,
self
.
max_seq
l
en
_kv
,
self
.
max_seq
u
en
ce_length
,
self
.
output_qkv_format
,
self
.
output_qkv_format
,
)
)
...
...
transformer_engine/pytorch/jit.py
View file @
a207db1d
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
"""NVFuser functions and JIT utilities"""
"""NVFuser functions and JIT utilities"""
import
os
import
os
from
functools
import
wraps
from
typing
import
Callable
,
Optional
,
Tuple
from
typing
import
Callable
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -11,15 +12,34 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
...
@@ -11,15 +12,34 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
# pylint: disable=unnecessary-lambda-assignment
# pylint: disable=unnecessary-lambda-assignment
def
lazy_compile
(
func
):
"""Lazy compile a function with torch.compile
This decorator defers the compilation of a function until the first call, speeding up the
overall module's import time if these functions are not used.
"""
compiled_func
=
None
@
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
nonlocal
compiled_func
if
compiled_func
is
None
:
compiled_func
=
torch
.
compile
(
func
)
return
compiled_func
(
*
args
,
**
kwargs
)
return
wrapper
jit_fuser
=
lambda
func
:
func
jit_fuser
=
lambda
func
:
func
if
torch
.
__version__
>=
"2"
and
bool
(
int
(
os
.
getenv
(
"NVTE_TORCH_COMPILE"
,
"1"
))):
if
torch
.
__version__
>=
"2"
and
bool
(
int
(
os
.
getenv
(
"NVTE_TORCH_COMPILE"
,
"1"
))):
jit_fuser
=
torch
.
compile
jit_fuser
=
lazy_
compile
# See: https://github.com/NVIDIA/TransformerEngine/issues/597
# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser
=
torch
.
jit
.
script
dropout_fuser
=
torch
.
jit
.
script
if
torch
.
__version__
>=
"2.2"
and
bool
(
int
(
os
.
getenv
(
"NVTE_TORCH_COMPILE"
,
"1"
))):
if
torch
.
__version__
>=
"2.2"
and
bool
(
int
(
os
.
getenv
(
"NVTE_TORCH_COMPILE"
,
"1"
))):
dropout_fuser
=
torch
.
compile
dropout_fuser
=
lazy_
compile
# Decorator to disable Torch Dynamo
# Decorator to disable Torch Dynamo
...
...
transformer_engine/pytorch/module/base.py
View file @
a207db1d
...
@@ -1018,6 +1018,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
...
@@ -1018,6 +1018,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
out
=
None
out
=
None
if
cache_name
is
not
None
:
if
cache_name
is
not
None
:
out
=
self
.
_fp8_workspaces
.
get
(
cache_name
,
None
)
out
=
self
.
_fp8_workspaces
.
get
(
cache_name
,
None
)
if
quantizer
is
not
None
and
isinstance
(
out
,
MXFP8TensorBase
):
if
quantizer
.
rowwise_usage
and
out
.
_rowwise_data
is
None
:
out
=
None
del
self
.
_fp8_workspaces
[
cache_name
]
elif
quantizer
.
columnwise_usage
and
out
.
_columnwise_data
is
None
:
out
=
None
del
self
.
_fp8_workspaces
[
cache_name
]
# Gather cached Fp8 workspace if it's distributed
# Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
...
...
transformer_engine/pytorch/module/grouped_linear.py
View file @
a207db1d
...
@@ -78,8 +78,8 @@ class _GroupedLinear(torch.autograd.Function):
...
@@ -78,8 +78,8 @@ class _GroupedLinear(torch.autograd.Function):
skip_fp8_weight_update
,
skip_fp8_weight_update
,
*
weights_and_biases
,
*
weights_and_biases
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# pylint: disable=missing-function-docstring
# pylint: disable=missing-function-docstring
num_gemms
=
len
(
m_splits
)
num_gemms
=
len
(
m_splits
)
weights
=
weights_and_biases
[:
num_gemms
]
weights
=
weights_and_biases
[:
num_gemms
]
biases
=
weights_and_biases
[
num_gemms
:]
biases
=
weights_and_biases
[
num_gemms
:]
...
@@ -180,7 +180,12 @@ class _GroupedLinear(torch.autograd.Function):
...
@@ -180,7 +180,12 @@ class _GroupedLinear(torch.autograd.Function):
ctx
.
weights_shape_1
=
weights
[
0
].
shape
[
1
]
ctx
.
weights_shape_1
=
weights
[
0
].
shape
[
1
]
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
*
inputmats
,
*
weights_fp8
,
*
biases
)
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
*
inputmats
,
*
weights_fp8
,
*
weights
,
*
biases
,
)
ctx
.
save_for_backward
(
*
tensors_to_save
)
ctx
.
save_for_backward
(
*
tensors_to_save
)
ctx
.
tensor_objects
=
tensor_objects
ctx
.
tensor_objects
=
tensor_objects
...
@@ -220,7 +225,8 @@ class _GroupedLinear(torch.autograd.Function):
...
@@ -220,7 +225,8 @@ class _GroupedLinear(torch.autograd.Function):
N
=
ctx
.
num_gemms
N
=
ctx
.
num_gemms
inputmats
=
saved_tensors
[:
N
]
inputmats
=
saved_tensors
[:
N
]
weights
=
saved_tensors
[
N
:
2
*
N
]
weights
=
saved_tensors
[
N
:
2
*
N
]
biases
=
saved_tensors
[
2
*
N
:
3
*
N
]
origin_weights
=
saved_tensors
[
2
*
N
:
3
*
N
]
biases
=
saved_tensors
[
3
*
N
:
4
*
N
]
main_grads
=
ctx
.
main_grads
main_grads
=
ctx
.
main_grads
if
ctx
.
cpu_offloading
and
ctx
.
fuse_wgrad_accumulation
:
# TOSO
if
ctx
.
cpu_offloading
and
ctx
.
fuse_wgrad_accumulation
:
# TOSO
...
@@ -311,21 +317,24 @@ class _GroupedLinear(torch.autograd.Function):
...
@@ -311,21 +317,24 @@ class _GroupedLinear(torch.autograd.Function):
# Deallocate input tensor
# Deallocate input tensor
clear_tensor_data
(
*
inputmats
)
clear_tensor_data
(
*
inputmats
)
def
handle_custom_ddp_from_mcore
(
w
,
wgrad
):
def
handle_custom_ddp_from_mcore
(
w
eight
,
wgrad
):
if
ctx
.
weights_requires_grad
:
if
ctx
.
weights_requires_grad
:
if
ctx
.
fuse_wgrad_accumulation
and
hasattr
(
w
,
"grad_added_to_main_grad"
):
# Handle custom DDP from mcore.
w
.
grad_added_to_main_grad
=
True
if
ctx
.
fuse_wgrad_accumulation
and
hasattr
(
if
getattr
(
w
,
"zero_out_wgrad"
,
False
):
weight
,
"grad_added_to_main_grad"
):
weight
.
grad_added_to_main_grad
=
True
if
getattr
(
weight
,
"zero_out_wgrad"
,
False
):
wgrad
=
torch
.
zeros
(
wgrad
=
torch
.
zeros
(
w
.
main_grad
.
shape
,
w
eight
.
main_grad
.
shape
,
dtype
=
w
.
dtype
,
dtype
=
w
eight
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
else
:
else
:
wgrad
=
torch
.
empty
(
wgrad
=
torch
.
empty
(
w
.
main_grad
.
shape
,
w
eight
.
main_grad
.
shape
,
dtype
=
w
.
dtype
,
dtype
=
w
eight
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
...
@@ -336,7 +345,8 @@ class _GroupedLinear(torch.autograd.Function):
...
@@ -336,7 +345,8 @@ class _GroupedLinear(torch.autograd.Function):
return
wgrad
return
wgrad
wgrad_list
=
[
wgrad_list
=
[
handle_custom_ddp_from_mcore
(
w
,
wgrad
)
for
w
,
wgrad
in
zip
(
weights
,
wgrad_list
)
handle_custom_ddp_from_mcore
(
weight
,
wgrad
)
for
weight
,
wgrad
in
zip
(
origin_weights
,
wgrad_list
)
]
]
else
:
else
:
wgrad_list
=
[
None
]
*
ctx
.
num_gemms
wgrad_list
=
[
None
]
*
ctx
.
num_gemms
...
...
transformer_engine/pytorch/module/layernorm_linear.py
View file @
a207db1d
...
@@ -55,7 +55,6 @@ from ..tensor.quantized_tensor import (
...
@@ -55,7 +55,6 @@ from ..tensor.quantized_tensor import (
prepare_for_saving
,
prepare_for_saving
,
restore_from_saved
,
restore_from_saved
,
)
)
from
..tensor.float8_tensor
import
Float8CurrentScalingQuantizer
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
..tensor._internal.mxfp8_tensor_base
import
MXFP8TensorBase
from
..tensor._internal.mxfp8_tensor_base
import
MXFP8TensorBase
from
..cpu_offload
import
is_cpu_offload_enabled
,
set_offloading_param
from
..cpu_offload
import
is_cpu_offload_enabled
,
set_offloading_param
...
@@ -137,6 +136,11 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -137,6 +136,11 @@ class _LayerNormLinear(torch.autograd.Function):
ln_bias
=
cast_if_needed
(
ln_bias
,
activation_dtype
)
ln_bias
=
cast_if_needed
(
ln_bias
,
activation_dtype
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.norm_input_cast"
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.norm_input_cast"
)
# Avoid quantized norm kernel if norm output will be returned
with_quantized_norm
=
(
fp8
and
not
return_layernorm_output
and
not
return_layernorm_output_gathered
)
tp_world_size
=
get_distributed_world_size
(
tp_group
)
tp_world_size
=
get_distributed_world_size
(
tp_group
)
ub_overlap_ag_fprop
=
(
ub_overlap_ag_fprop
=
(
ub_overlap_ag_fprop
and
is_grad_enabled
and
not
return_layernorm_output
ub_overlap_ag_fprop
and
is_grad_enabled
and
not
return_layernorm_output
...
@@ -146,6 +150,7 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -146,6 +150,7 @@ class _LayerNormLinear(torch.autograd.Function):
backward_needs_input
=
is_grad_enabled
and
weight_requires_grad
backward_needs_input
=
is_grad_enabled
and
weight_requires_grad
with_input_all_gather
=
parallel_mode
==
"column"
and
sequence_parallel
with_input_all_gather
=
parallel_mode
==
"column"
and
sequence_parallel
# Check if Userbuffers is supported
if
fp8
:
if
fp8
:
if
any
([
ub_overlap_ag_fprop
,
ub_overlap_rs_fprop
])
and
not
(
if
any
([
ub_overlap_ag_fprop
,
ub_overlap_rs_fprop
])
and
not
(
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_per_tensor_scaling
()
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_per_tensor_scaling
()
...
@@ -155,104 +160,74 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -155,104 +160,74 @@ class _LayerNormLinear(torch.autograd.Function):
" current scaling"
" current scaling"
)
)
# Configure quantizer for norm output
if
fp8
:
if
input_quantizer
is
None
:
if
input_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
raise
ValueError
(
"Missing quantizer for input tensor"
)
columnwise_usage
=
backward_needs_input
# Configure quantizer for normalization output
if
(
with_quantized_norm
=
fp8
and
not
return_layernorm_output
columnwise_usage
if
with_quantized_norm
:
and
with_input_all_gather
if
with_input_all_gather
:
and
not
isinstance
(
input_quantizer
,
MXFP8Quantizer
)
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
):
if
isinstance
(
input_quantizer
,
MXFP8Quantizer
):
columnwise_usage
=
False
with_quantized_norm
=
False
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
columnwise_usage
)
else
:
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
backward_needs_input
,
)
# Reduce duplicated transpose in `_fix_gathered_fp8_transpose`
if
(
fp8
and
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_per_tensor_scaling
()
and
ub_bulk_dgrad
):
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
ub_obj_fprop
=
None
ln_out
=
None
# For DelayScaling, output of normalization will be in fp8.
# For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8.
if
ub_overlap_ag_fprop
and
not
isinstance
(
input_quantizer
,
Float8CurrentScalingQuantizer
):
ub_obj_fprop
=
get_ub
(
ub_name
+
"_fprop"
)
ln_out
=
ub_obj_fprop
.
get_buffer
(
input_quantizer
,
local_chunk
=
True
)
elif
with_quantized_norm
:
if
with_input_all_gather
:
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
ln_out
=
input_quantizer
.
make_empty
(
inputmat
.
shape
,
dtype
=
inputmat
.
dtype
,
device
=
"cuda"
)
else
:
ln_out
=
torch
.
empty_like
(
inputmat
,
dtype
=
inputmat
.
dtype
,
memory_format
=
torch
.
contiguous_format
,
device
=
"cuda"
)
# Apply normalization
# Apply normalization
nvtx_range_push
(
f
"
{
nvtx_label
}
.norm"
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.norm"
)
ln_out
,
mu
,
rsigma
=
apply_normalization
(
ln_out
,
mu
,
rsigma
=
apply_normalization
(
inputmat
,
inputmat
,
ln_out
,
None
,
#
ln_out
ln_weight
,
ln_weight
,
ln_bias
,
ln_bias
,
eps
,
eps
,
input_quantizer
if
with_quantized_norm
else
None
,
input_quantizer
if
with_quantized_norm
else
None
,
inp
.
dtype
,
inp
utmat
.
dtype
,
normalization
,
normalization
,
fwd_ln_sm_margin
,
fwd_ln_sm_margin
,
zero_centered_gamma
,
zero_centered_gamma
,
)
)
ln_out_return
=
ln_out
if
return_layernorm_output
else
None
ln_out_return
=
None
if
return_layernorm_output
or
return_layernorm_output_gathered
:
ln_out_return
=
ln_out
nvtx_range_pop
(
f
"
{
nvtx_label
}
.norm"
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.norm"
)
# For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer.
# So the output of normalization is in high precision, and we need to quantize it to FP8 and put in the buffer.
if
ub_overlap_ag_fprop
and
isinstance
(
input_quantizer
,
Float8CurrentScalingQuantizer
):
ub_obj_fprop
=
get_ub
(
ub_name
+
"_fprop"
)
ln_out_local
=
ln_out
ln_out
=
ub_obj_fprop
.
get_buffer
(
input_quantizer
,
local_chunk
=
True
)
input_quantizer
.
quantize
(
ln_out_local
,
out
=
ln_out
)
# Prepare GEMM input
# Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push
(
f
"
{
nvtx_label
}
.gemm_input_cast_comm"
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.gemm_input_cast_comm"
)
if
with_input_all_gather
and
not
ub_overlap_ag_fprop
:
ln_out_total
=
None
with_quantized_all_gather
=
fp8
ub_obj_fprop
=
None
if
return_layernorm_output
and
return_layernorm_output_gathered
:
if
with_input_all_gather
:
with_quantized_all_gather
=
False
if
return_layernorm_output_gathered
:
if
fp8
:
# Perform all-gather in high precision if gathered
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
# norm output will be returned
# ln_out in this has two possibilities:
ln_out_total
,
_
=
gather_along_first_dim
(
ln_out
,
tp_group
)
# 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel
# 2. in high precision, then we need to cast it and then gather in FP8
# the output ln_out_total will be in FP8, and it's a full tensor
ln_out_total
,
_
=
gather_along_first_dim
(
ln_out
,
tp_group
,
quantizer
=
(
input_quantizer
if
with_quantized_all_gather
else
None
),
)
if
return_layernorm_output
and
return_layernorm_output_gathered
:
ln_out_return
=
ln_out_total
ln_out_return
=
ln_out_total
if
fp8
and
not
with_quantized_all_gather
:
if
fp8
:
ln_out_total
=
input_quantizer
(
ln_out_total
)
ln_out
=
input_quantizer
(
ln_out
)
else
:
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
if
ub_overlap_ag_fprop
:
ln_out_total
=
input_quantizer
(
ln_out_total
)
ln_out_total
=
ub_obj_fprop
.
get_buffer
(
input_quantizer
)
else
:
else
:
if
fp8
:
if
fp8
:
if
not
isinstance
(
ln_out
,
QuantizedTensor
):
if
not
with_quantized_norm
:
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
backward_needs_input
)
ln_out
=
input_quantizer
(
ln_out
)
ln_out
=
input_quantizer
(
ln_out
)
elif
backward_needs_input
:
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
ln_out
.
update_usage
(
rowwise_usage
=
True
,
columnwise_usage
=
True
)
if
ub_overlap_ag_fprop
:
ln_out_total
=
ln_out
# Copy into Userbuffers buffer
ub_obj_fprop
=
get_ub
(
ub_name
+
"_fprop"
)
ub_obj_fprop
.
get_buffer
(
input_quantizer
,
local_chunk
=
True
).
copy_
(
ln_out
)
ln_out_total
=
ub_obj_fprop
.
get_buffer
(
input_quantizer
)
else
:
# All-gather with NCCL
ln_out_total
,
_
=
gather_along_first_dim
(
ln_out
,
tp_group
,
quantizer
=
(
input_quantizer
if
fp8
else
None
),
)
else
:
if
fp8
and
not
with_quantized_norm
:
ln_out
=
input_quantizer
(
ln_out
)
ln_out_total
=
ln_out
nvtx_range_pop
(
f
"
{
nvtx_label
}
.gemm_input_cast_comm"
)
nvtx_range_pop
(
f
"
{
nvtx_label
}
.gemm_input_cast_comm"
)
# Cast weight to expected dtype
# Cast weight to expected dtype
...
@@ -341,7 +316,7 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -341,7 +316,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight
.
requires_grad
and
parallel_mode
==
"column"
and
sequence_parallel
weight
.
requires_grad
and
parallel_mode
==
"column"
and
sequence_parallel
)
)
# Input with column-wise usage is needed for
d
grad GEMM.
# Input with column-wise usage is needed for
w
grad GEMM.
if
backward_needs_input
:
if
backward_needs_input
:
if
isinstance
(
ln_out
,
QuantizedTensor
):
if
isinstance
(
ln_out
,
QuantizedTensor
):
# For sequence parallel in vanilla FP8, rowwise data is
# For sequence parallel in vanilla FP8, rowwise data is
...
@@ -350,6 +325,11 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -350,6 +325,11 @@ class _LayerNormLinear(torch.autograd.Function):
if
isinstance
(
ln_out
,
MXFP8TensorBase
)
or
not
ctx
.
ln_out_needs_gather
:
if
isinstance
(
ln_out
,
MXFP8TensorBase
)
or
not
ctx
.
ln_out_needs_gather
:
ln_out
.
update_usage
(
rowwise_usage
=
False
)
ln_out
.
update_usage
(
rowwise_usage
=
False
)
# Weight with column-wise usage is needed for dgrad GEMM.
if
inp
.
requires_grad
:
if
isinstance
(
weightmat
,
QuantizedTensor
):
weightmat
.
update_usage
(
columnwise_usage
=
True
)
if
cpu_offloading
:
if
cpu_offloading
:
if
fp8
and
weightmat
is
not
None
:
if
fp8
and
weightmat
is
not
None
:
set_offloading_param
(
weightmat
,
"weight_offloading"
,
True
)
set_offloading_param
(
weightmat
,
"weight_offloading"
,
True
)
...
@@ -392,7 +372,7 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -392,7 +372,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight
,
weight
,
bias
,
bias
,
ln_weight
,
ln_weight
,
ln_out
.
clone
()
if
ub_overlap_ag_fprop
else
ln_out
,
# avoid saving a UB buffer
ln_out
,
mu
,
mu
,
rsigma
,
rsigma
,
)
)
...
@@ -603,7 +583,7 @@ class _LayerNormLinear(torch.autograd.Function):
...
@@ -603,7 +583,7 @@ class _LayerNormLinear(torch.autograd.Function):
quantizer
=
None
quantizer
=
None
if
ctx
.
fp8
:
if
ctx
.
fp8
:
quantizer
=
ctx
.
input_quantizer
quantizer
=
ctx
.
input_quantizer
quantizer
.
set_usage
(
rowwise
=
Tru
e
,
columnwise
=
True
)
quantizer
.
set_usage
(
rowwise
=
Fals
e
,
columnwise
=
True
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
ln_out_total
,
ln_out_total_work
=
gather_along_first_dim
(
ln_out_total
,
ln_out_total_work
=
gather_along_first_dim
(
ln_out
,
ln_out
,
...
@@ -1436,9 +1416,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
...
@@ -1436,9 +1416,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self
.
quantizers
[
"scaling_fwd"
][
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
amax_reduction_group
=
self
.
tp_group
].
amax_reduction_group
=
self
.
tp_group
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
amax_reduction_size
=
self
.
tp_size
else
:
else
:
# set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here)
# set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here)
self
.
quantizers
[
"scaling_bwd"
][
self
.
quantizers
[
"scaling_bwd"
][
...
...
transformer_engine/pytorch/module/layernorm_mlp.py
View file @
a207db1d
...
@@ -61,7 +61,6 @@ from ..tensor.float8_tensor import Float8Tensor
...
@@ -61,7 +61,6 @@ from ..tensor.float8_tensor import Float8Tensor
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
._common
import
apply_normalization
,
_fix_gathered_fp8_transpose
from
._common
import
apply_normalization
,
_fix_gathered_fp8_transpose
from
..cpu_offload
import
is_cpu_offload_enabled
,
set_offloading_param
from
..cpu_offload
import
is_cpu_offload_enabled
,
set_offloading_param
from
..tensor.float8_tensor
import
Float8CurrentScalingQuantizer
from
..tensor.quantized_tensor
import
(
from
..tensor.quantized_tensor
import
(
QuantizedTensor
,
QuantizedTensor
,
Quantizer
,
Quantizer
,
...
@@ -208,112 +207,81 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -208,112 +207,81 @@ class _LayerNormMLP(torch.autograd.Function):
if
ln_bias
is
not
None
:
if
ln_bias
is
not
None
:
ln_bias
=
cast_if_needed
(
ln_bias
,
activation_dtype
)
ln_bias
=
cast_if_needed
(
ln_bias
,
activation_dtype
)
# for fp8 DelayedScaling: layernorm output = FP8
# Avoid quantized norm kernel if norm output will be returned
# only output of the linear is returned
with_quantized_norm
=
(
# for return_layernorm_output: layernorm output = High precision, then cast to FP8
fp8
and
not
return_layernorm_output
and
not
return_layernorm_output_gathered
# high precision layernorm output and output of the linear are returned
)
with_quantized_norm
=
fp8
and
not
return_layernorm_output
tp_world_size
=
get_distributed_world_size
(
tp_group
)
tp_world_size
=
get_distributed_world_size
(
tp_group
)
ub_overlap_ag
=
ub_overlap_ag
and
is_grad_enabled
and
not
return_layernorm_output
ub_overlap_ag
=
ub_overlap_ag
and
is_grad_enabled
and
not
return_layernorm_output
_gathered
ub_overlap_rs
=
ub_overlap_rs
and
is_grad_enabled
ub_overlap_rs
=
ub_overlap_rs
and
is_grad_enabled
with_input_all_gather_nccl
=
sequence_parallel
and
not
ub_overlap_ag
backwards_needs_fc1_input
=
is_grad_enabled
and
fc1_weight
.
requires_grad
backwards_needs_fc1_input
=
is_grad_enabled
and
fc1_weight
.
requires_grad
# Configure quantizer for normalization output
# Configure quantizer for norm output
if
fp8
and
fc1_input_quantizer
is
None
:
if
fp8
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
if
fc1_input_quantizer
is
None
:
if
with_quantized_norm
:
raise
ValueError
(
"Missing quantizer for FC1 input tensor"
)
if
with_input_all_gather_nccl
:
columnwise_usage
=
backwards_needs_fc1_input
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
if
(
if
isinstance
(
fc1_input_quantizer
,
MXFP8Quantizer
):
columnwise_usage
with_quantized_norm
=
False
and
sequence_parallel
else
:
and
not
isinstance
(
fc1_input_quantizer
,
MXFP8Quantizer
)
fc1_input_quantizer
.
set_usage
(
):
rowwise
=
True
,
columnwise_usage
=
False
columnwise
=
backwards_needs_fc1_input
,
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
columnwise_usage
)
)
# Reduce duplicated transpose in `_fix_gathered_fp8_transpose`
if
(
fp8
and
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_per_tensor_scaling
()
and
ub_bulk_dgrad
):
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
ub_obj_lnout
=
None
ln_out
=
None
# For DelayScaling, output of normalization will be in fp8.
# For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8.
if
ub_overlap_ag
and
not
isinstance
(
fc1_input_quantizer
,
Float8CurrentScalingQuantizer
):
ub_obj_lnout
=
get_ub
(
"fc1_fprop"
)
ln_out
=
ub_obj_lnout
.
get_buffer
(
fc1_input_quantizer
,
local_chunk
=
True
)
elif
not
with_quantized_norm
:
ln_out
=
torch
.
empty_like
(
inputmat
,
dtype
=
inputmat
.
dtype
,
memory_format
=
torch
.
contiguous_format
,
device
=
"cuda"
)
# Apply normalization
# Apply normalization
ln_out
,
mu
,
rsigma
=
apply_normalization
(
ln_out
,
mu
,
rsigma
=
apply_normalization
(
inputmat
,
inputmat
,
ln_out
,
None
,
#
ln_out
ln_weight
,
ln_weight
,
ln_bias
,
ln_bias
,
eps
,
eps
,
fc1_input_quantizer
if
with_quantized_norm
else
None
,
fc1_input_quantizer
if
with_quantized_norm
else
None
,
inp
.
dtype
,
inp
utmat
.
dtype
,
normalization
,
normalization
,
fwd_ln_sm_margin
,
fwd_ln_sm_margin
,
zero_centered_gamma
,
zero_centered_gamma
,
)
)
ln_out_return
=
None
ln_out_return
=
ln_out
if
return_layernorm_output
else
None
if
return_layernorm_output
or
return_layernorm_output_gathered
:
ln_out_return
=
ln_out
# For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer.
# So the output of normalization is in high precision, and we need to quantize it to FP8 and put in the buffer.
if
ub_overlap_ag
and
isinstance
(
fc1_input_quantizer
,
Float8CurrentScalingQuantizer
):
ub_obj_lnout
=
get_ub
(
"fc1_fprop"
)
ln_out_local
=
ln_out
ln_out
=
ub_obj_lnout
.
get_buffer
(
fc1_input_quantizer
,
local_chunk
=
True
)
fc1_input_quantizer
.
quantize
(
ln_out_local
,
out
=
ln_out
)
# Prepare GEMM input
# Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication
# Note: Cast to expected dtype and perform tensor-parallel communication
ln_out_gathered
=
False
ln_out_total
=
None
with_quantized_all_gather
=
fp8
ub_obj_lnout
=
None
if
with_input_all_gather_nccl
:
if
sequence_parallel
:
if
return_layernorm_output
and
return_layernorm_output_gathered
:
if
return_layernorm_output_gathered
:
with_quantized_all_gather
=
False
# Perform all-gather in high precision if gathered
if
fp8
:
# norm output will be returned
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
ln_out_total
,
_
=
gather_along_first_dim
(
ln_out
,
tp_group
)
# ln_out in this has two possibilities:
ln_out_return
=
ln_out_total
# 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel
if
fp8
:
# 2. in high precision, then we need to cast it and then gather in FP8
ln_out
=
fc1_input_quantizer
(
ln_out
)
# the output ln_out_total will be in FP8, and it's a full tensor
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
ln_out_total
,
_
=
gather_along_first_dim
(
ln_out_total
=
fc1_input_quantizer
(
ln_out_total
)
ln_out
,
tp_group
,
quantizer
=
(
fc1_input_quantizer
if
with_quantized_all_gather
else
None
),
)
ln_out_gathered
=
True
else
:
with_quantized_all_gather
=
False
if
ub_overlap_ag
:
ln_out_total
=
ub_obj_lnout
.
get_buffer
(
fc1_input_quantizer
,
False
)
else
:
else
:
if
fp8
:
if
fp8
:
if
not
isinstance
(
ln_out
,
QuantizedTensor
):
if
not
with_quantized_norm
:
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
backwards_needs_fc1_input
)
ln_out
=
fc1_input_quantizer
(
ln_out
)
ln_out
=
fc1_input_quantizer
(
ln_out
)
elif
backwards_needs_fc1_input
:
fc1_input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
ln_out
.
update_usage
(
rowwise_usage
=
True
,
columnwise_usage
=
True
)
if
ub_overlap_ag
:
# here ln_out is in FP8 low precision, the cast was either done by fc1_input_quantizer
# Copy into Userbuffers buffer
# or fused into the layernorm kernel
ub_obj_lnout
=
get_ub
(
"fc1_fprop"
)
# ln_out_total represents the full fp8 tensor, in this case, it's the same as ln_out
ub_obj_lnout
.
get_buffer
(
fc1_input_quantizer
,
local_chunk
=
True
).
copy_
(
ln_out
)
ln_out_total
=
ln_out
ln_out_total
=
ub_obj_lnout
.
get_buffer
(
fc1_input_quantizer
)
else
:
# All-gather with NCCL
ln_out_total
,
_
=
gather_along_first_dim
(
ln_out
,
tp_group
,
quantizer
=
(
fc1_input_quantizer
if
fp8
else
None
),
)
else
:
if
fp8
and
not
with_quantized_norm
:
ln_out
=
fc1_input_quantizer
(
ln_out
)
ln_out_total
=
ln_out
# Cast weights to expected dtype
# Cast weights to expected dtype
if
not
fp8
:
if
not
fp8
:
...
@@ -423,7 +391,6 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -423,7 +391,6 @@ class _LayerNormMLP(torch.autograd.Function):
dim_size
[
0
]
=
dim_size
[
0
]
//
tp_world_size
dim_size
[
0
]
=
dim_size
[
0
]
//
tp_world_size
dim_size
[
1
]
=
fc2_weight
.
size
(
0
)
dim_size
[
1
]
=
fc2_weight
.
size
(
0
)
rs_out
=
torch
.
empty
(
dim_size
,
dtype
=
activation_dtype
,
device
=
device
)
rs_out
=
torch
.
empty
(
dim_size
,
dtype
=
activation_dtype
,
device
=
device
)
fc2_out
=
ub_obj_fc2out
.
get_buffer
(
output_quantizer
)
else
:
else
:
dim_size
=
list
(
act_out
.
size
())
dim_size
=
list
(
act_out
.
size
())
dim_size
[
1
]
=
fc2_weight
.
size
(
0
)
dim_size
[
1
]
=
fc2_weight
.
size
(
0
)
...
@@ -443,6 +410,14 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -443,6 +410,14 @@ class _LayerNormMLP(torch.autograd.Function):
ub_type
=
tex
.
CommOverlapType
.
RS
if
ub_overlap_rs
else
None
,
ub_type
=
tex
.
CommOverlapType
.
RS
if
ub_overlap_rs
else
None
,
extra_output
=
rs_out
,
extra_output
=
rs_out
,
)
)
# Weight with column-wise usage is needed for dgrad GEMM.
if
is_grad_enabled
and
inp
.
requires_grad
:
if
isinstance
(
fc1_weight_final
,
QuantizedTensor
):
fc1_weight_final
.
update_usage
(
columnwise_usage
=
True
)
if
isinstance
(
fc2_weight_final
,
QuantizedTensor
):
fc2_weight_final
.
update_usage
(
columnwise_usage
=
True
)
if
not
is_grad_enabled
:
if
not
is_grad_enabled
:
clear_tensor_data
(
act_out
,
fc1_out_without_bias
,
fc1_out
)
clear_tensor_data
(
act_out
,
fc1_out_without_bias
,
fc1_out
)
else
:
else
:
...
@@ -490,13 +465,15 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -490,13 +465,15 @@ class _LayerNormMLP(torch.autograd.Function):
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
inputmat
,
inputmat
,
ln_weight
,
ln_weight
,
ln_out
.
clone
()
if
ub_overlap_ag
else
ln_out
,
# avoid saving a UB buffer
ln_out
,
fc1_weight_final
,
fc1_weight_final
,
fc1_weight
,
fc1_bias
,
fc1_bias
,
fc1_out
,
fc1_out
,
fc1_out_without_bias
,
fc1_out_without_bias
,
act_out
,
act_out
,
fc2_weight_final
,
fc2_weight_final
,
fc2_weight
,
fc2_bias
,
fc2_bias
,
mu
,
mu
,
rsigma
,
rsigma
,
...
@@ -537,7 +514,7 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -537,7 +514,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx
.
bias_gelu_fusion
=
bias_gelu_fusion
ctx
.
bias_gelu_fusion
=
bias_gelu_fusion
ctx
.
return_layernorm_output
=
return_layernorm_output
ctx
.
return_layernorm_output
=
return_layernorm_output
ctx
.
return_layernorm_output_gathered
=
(
ctx
.
return_layernorm_output_gathered
=
(
return_layernorm_output_gathered
and
ln_out_gathered
return_layernorm_output_gathered
and
sequence_parallel
)
)
ctx
.
set_parallel_mode
=
set_parallel_mode
ctx
.
set_parallel_mode
=
set_parallel_mode
ctx
.
bwd_ln_sm_margin
=
bwd_ln_sm_margin
ctx
.
bwd_ln_sm_margin
=
bwd_ln_sm_margin
...
@@ -609,11 +586,13 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -609,11 +586,13 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight
,
ln_weight
,
ln_out
,
ln_out
,
fc1_weight
,
fc1_weight
,
origin_fc1_weight
,
fc1_bias
,
fc1_bias
,
fc1_out
,
fc1_out
,
fc1_out_without_bias
,
fc1_out_without_bias
,
act_out
,
act_out
,
fc2_weight
,
fc2_weight
,
origin_fc2_weight
,
fc2_bias
,
fc2_bias
,
mu
,
mu
,
rsigma
,
rsigma
,
...
@@ -632,7 +611,7 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -632,7 +611,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
)
fc2_weight_main_grad
=
(
fc2_weight_main_grad
=
(
ctx
.
fc2_main_grad
ctx
.
fc2_main_grad
if
fc2_weight
is
not
None
if
origin_
fc2_weight
is
not
None
and
ctx
.
fuse_wgrad_accumulation
and
ctx
.
fuse_wgrad_accumulation
and
ctx
.
fc2_weight_requires_grad
and
ctx
.
fc2_weight_requires_grad
else
None
else
None
...
@@ -641,8 +620,8 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -641,8 +620,8 @@ class _LayerNormMLP(torch.autograd.Function):
# For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# we need to connect them into one.
# we need to connect them into one.
if
ctx
.
fuse_wgrad_accumulation
:
if
ctx
.
fuse_wgrad_accumulation
:
fc1_weight
.
main_grad
=
fc1_weight_main_grad
origin_
fc1_weight
.
main_grad
=
fc1_weight_main_grad
fc2_weight
.
main_grad
=
fc2_weight_main_grad
origin_
fc2_weight
.
main_grad
=
fc2_weight_main_grad
# TODO: Fix this # pylint: disable=fixme
# TODO: Fix this # pylint: disable=fixme
# Gather saved autograd context tensors when running with FSDP
# Gather saved autograd context tensors when running with FSDP
...
@@ -697,7 +676,7 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -697,7 +676,7 @@ class _LayerNormMLP(torch.autograd.Function):
quantizer
=
None
quantizer
=
None
if
ctx
.
fp8
:
if
ctx
.
fp8
:
quantizer
=
ctx
.
fc1_input_quantizer
quantizer
=
ctx
.
fc1_input_quantizer
quantizer
.
set_usage
(
rowwise
=
Tru
e
,
columnwise
=
True
)
quantizer
.
set_usage
(
rowwise
=
Fals
e
,
columnwise
=
True
)
ln_out_total
,
ln_out_total_work
=
gather_along_first_dim
(
ln_out_total
,
ln_out_total_work
=
gather_along_first_dim
(
ln_out
,
ln_out
,
ctx
.
tp_group
,
ctx
.
tp_group
,
...
@@ -759,14 +738,18 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -759,14 +738,18 @@ class _LayerNormMLP(torch.autograd.Function):
act_out
,
act_out
,
grad_output
,
grad_output
,
get_workspace
(),
get_workspace
(),
out_dtype
=
ctx
.
activation_dtype
,
out_dtype
=
(
origin_fc2_weight
.
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
),
quantization_params
=
None
,
# wgrad in high precision
quantization_params
=
None
,
# wgrad in high precision
layout
=
"NT"
,
layout
=
"NT"
,
grad
=
True
,
grad
=
True
,
bias
=
fc2_bias
if
fc2_bias_grad
is
None
else
None
,
bias
=
fc2_bias
if
fc2_bias_grad
is
None
else
None
,
accumulate
=
accumulate_wgrad_into_param_main_grad
,
accumulate
=
accumulate_wgrad_into_param_main_grad
,
use_split_accumulator
=
_2X_ACC_WGRAD
,
use_split_accumulator
=
_2X_ACC_WGRAD
,
out
=
fc2_weight
.
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
out
=
origin_
fc2_weight
.
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
)
)
if
fc2_bias_grad
is
None
:
if
fc2_bias_grad
is
None
:
fc2_bias_grad
=
fc2_bias_grad_
fc2_bias_grad
=
fc2_bias_grad_
...
@@ -919,12 +902,16 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -919,12 +902,16 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total
,
ln_out_total
,
dact
,
dact
,
get_workspace
(),
get_workspace
(),
out_dtype
=
ctx
.
activation_dtype
,
out_dtype
=
(
origin_fc1_weight
.
main_grad
.
dtype
if
ctx
.
fuse_wgrad_accumulation
else
ctx
.
activation_dtype
),
layout
=
"NT"
,
layout
=
"NT"
,
grad
=
fuse_gemm_and_bias_fc1_wgrad
,
grad
=
fuse_gemm_and_bias_fc1_wgrad
,
bias
=
fc1_bias
if
fuse_gemm_and_bias_fc1_wgrad
else
None
,
bias
=
fc1_bias
if
fuse_gemm_and_bias_fc1_wgrad
else
None
,
accumulate
=
accumulate_wgrad_into_param_main_grad
,
accumulate
=
accumulate_wgrad_into_param_main_grad
,
out
=
fc1_weight
.
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
out
=
origin_
fc1_weight
.
main_grad
if
ctx
.
fuse_wgrad_accumulation
else
None
,
ub
=
ub_obj_fc1_wgrad
,
ub
=
ub_obj_fc1_wgrad
,
ub_type
=
tex
.
CommOverlapType
.
RS
if
ctx
.
ub_bulk_wgrad
else
None
,
ub_type
=
tex
.
CommOverlapType
.
RS
if
ctx
.
ub_bulk_wgrad
else
None
,
extra_output
=
fc1_dgrad_rs_out
,
extra_output
=
fc1_dgrad_rs_out
,
...
@@ -985,16 +972,21 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -985,16 +972,21 @@ class _LayerNormMLP(torch.autograd.Function):
if
ctx
.
fc1_weight_requires_grad
:
if
ctx
.
fc1_weight_requires_grad
:
# Handle custom DDP from mcore.
# Handle custom DDP from mcore.
if
ctx
.
fuse_wgrad_accumulation
and
hasattr
(
fc1_weight
,
"grad_added_to_main_grad"
):
if
ctx
.
fuse_wgrad_accumulation
and
hasattr
(
fc1_weight
,
"grad_added_to_main_grad"
):
fc1_weight
.
grad_added_to_main_grad
=
True
origin_
fc1_weight
.
grad_added_to_main_grad
=
True
if
getattr
(
fc1_weight
,
"zero_out_wgrad"
,
False
):
if
getattr
(
origin_
fc1_weight
,
"zero_out_wgrad"
,
False
):
fc1_wgrad
=
torch
.
zeros
(
fc1_wgrad
=
torch
.
zeros
(
fc1_weight
.
main_grad
.
shape
,
origin_
fc1_weight
.
main_grad
.
shape
,
dtype
=
fc1_weight
.
dtype
,
dtype
=
origin_
fc1_weight
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
else
:
else
:
fc1_wgrad
=
None
fc1_wgrad
=
torch
.
empty
(
origin_fc1_weight
.
main_grad
.
shape
,
dtype
=
origin_fc1_weight
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
)
elif
ctx
.
fuse_wgrad_accumulation
:
elif
ctx
.
fuse_wgrad_accumulation
:
fc1_wgrad
=
None
fc1_wgrad
=
None
else
:
else
:
...
@@ -1002,17 +994,24 @@ class _LayerNormMLP(torch.autograd.Function):
...
@@ -1002,17 +994,24 @@ class _LayerNormMLP(torch.autograd.Function):
if
ctx
.
fc2_weight_requires_grad
:
if
ctx
.
fc2_weight_requires_grad
:
# Handle custom DDP from mcore.
# Handle custom DDP from mcore.
if
ctx
.
fuse_wgrad_accumulation
and
hasattr
(
fc2_weight
,
"grad_added_to_main_grad"
):
if
ctx
.
fuse_wgrad_accumulation
and
hasattr
(
fc2_weight
.
grad_added_to_main_grad
=
True
origin_fc2_weight
,
"grad_added_to_main_grad"
if
getattr
(
fc2_weight
,
"zero_out_wgrad"
,
False
):
):
origin_fc2_weight
.
grad_added_to_main_grad
=
True
if
getattr
(
origin_fc2_weight
,
"zero_out_wgrad"
,
False
):
fc2_wgrad
=
torch
.
zeros
(
fc2_wgrad
=
torch
.
zeros
(
fc2_weight
.
main_grad
.
shape
,
origin_
fc2_weight
.
main_grad
.
shape
,
dtype
=
fc2_weight
.
dtype
,
dtype
=
origin_
fc2_weight
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
else
:
else
:
fc2_wgrad
=
None
fc2_wgrad
=
torch
.
empty
(
origin_fc2_weight
.
main_grad
.
shape
,
dtype
=
origin_fc2_weight
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
,
)
elif
ctx
.
fuse_wgrad_accumulation
:
elif
ctx
.
fuse_wgrad_accumulation
:
fc2_wgrad
=
None
fc2_wgrad
=
None
else
:
else
:
...
@@ -1602,9 +1601,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
...
@@ -1602,9 +1601,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
self
.
quantizers
[
"scaling_fwd"
][
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
amax_reduction_group
=
self
.
tp_group
].
amax_reduction_group
=
self
.
tp_group
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
amax_reduction_size
=
self
.
tp_size
else
:
else
:
# grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer
# grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer
self
.
quantizers
[
"scaling_bwd"
][
self
.
quantizers
[
"scaling_bwd"
][
...
@@ -1628,6 +1624,3 @@ class LayerNormMLP(TransformerEngineBaseModule):
...
@@ -1628,6 +1624,3 @@ class LayerNormMLP(TransformerEngineBaseModule):
self
.
quantizers
[
"scaling_bwd"
][
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
amax_reduction_group
=
self
.
tp_group
].
amax_reduction_group
=
self
.
tp_group
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
amax_reduction_size
=
self
.
tp_size
transformer_engine/pytorch/module/linear.py
View file @
a207db1d
...
@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import (
...
@@ -56,6 +56,7 @@ from ..tensor.quantized_tensor import (
prepare_for_saving
,
prepare_for_saving
,
restore_from_saved
,
restore_from_saved
,
)
)
from
..tensor.mxfp8_tensor
import
MXFP8Quantizer
from
..tensor._internal.mxfp8_tensor_base
import
MXFP8TensorBase
from
..tensor._internal.mxfp8_tensor_base
import
MXFP8TensorBase
from
..cpu_offload
import
is_cpu_offload_enabled
,
set_offloading_param
from
..cpu_offload
import
is_cpu_offload_enabled
,
set_offloading_param
...
@@ -140,9 +141,13 @@ class _Linear(torch.autograd.Function):
...
@@ -140,9 +141,13 @@ class _Linear(torch.autograd.Function):
if
input_quantizer
is
None
:
if
input_quantizer
is
None
:
raise
ValueError
(
"Missing quantizer for input tensor"
)
raise
ValueError
(
"Missing quantizer for input tensor"
)
if
with_input_all_gather_nccl
:
if
with_input_all_gather_nccl
:
assert
not
isinstance
(
if
not
isinstance
(
inputmat
,
QuantizedTensor
):
inputmat
,
QuantizedTensor
columnwise_usage
=
backward_needs_input
and
isinstance
(
),
"All gather of fp8 input is not supported"
input_quantizer
,
MXFP8Quantizer
)
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
columnwise_usage
)
inputmat
=
input_quantizer
(
inputmat
)
own_quantized_input
=
True
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
input_quantizer
.
set_usage
(
rowwise
=
True
,
columnwise
=
False
)
inputmat_total
,
_
=
gather_along_first_dim
(
inputmat_total
,
_
=
gather_along_first_dim
(
inputmat
,
inputmat
,
...
@@ -269,9 +274,14 @@ class _Linear(torch.autograd.Function):
...
@@ -269,9 +274,14 @@ class _Linear(torch.autograd.Function):
# to gather the input. For MXFP8, columnwise only data
# to gather the input. For MXFP8, columnwise only data
# can be allgathered.
# can be allgathered.
if
isinstance
(
inputmat
,
MXFP8TensorBase
)
or
not
ctx
.
backward_input_needs_gather
:
if
isinstance
(
inputmat
,
MXFP8TensorBase
)
or
not
ctx
.
backward_input_needs_gather
:
inputmat
.
update_usage
(
rowwise_usage
=
False
)
inputmat
.
update_usage
(
rowwise_usage
=
False
,
columnwise_usage
=
True
)
saved_inputmat
=
inputmat
saved_inputmat
=
inputmat
# Weight with column-wise usage is needed for dgrad GEMM.
if
inp
.
requires_grad
:
if
isinstance
(
weightmat
,
QuantizedTensor
):
weightmat
.
update_usage
(
columnwise_usage
=
True
)
if
cpu_offloading
:
if
cpu_offloading
:
set_offloading_param
(
weight
,
"weight_offloading"
,
True
)
set_offloading_param
(
weight
,
"weight_offloading"
,
True
)
set_offloading_param
(
weightmat
,
"weight_offloading"
,
True
)
set_offloading_param
(
weightmat
,
"weight_offloading"
,
True
)
...
@@ -489,7 +499,7 @@ class _Linear(torch.autograd.Function):
...
@@ -489,7 +499,7 @@ class _Linear(torch.autograd.Function):
quantizer
=
None
quantizer
=
None
if
ctx
.
fp8
:
if
ctx
.
fp8
:
quantizer
=
ctx
.
input_quantizer
quantizer
=
ctx
.
input_quantizer
quantizer
.
set_usage
(
rowwise
=
Tru
e
,
columnwise
=
True
)
quantizer
.
set_usage
(
rowwise
=
Fals
e
,
columnwise
=
True
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
nvtx_range_push
(
f
"
{
nvtx_label
}
.column_parallel_comm_input"
)
inputmat_total
,
inputmat_total_work
=
gather_along_first_dim
(
inputmat_total
,
inputmat_total_work
=
gather_along_first_dim
(
inputmat
,
inputmat
,
...
@@ -1211,9 +1221,6 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1211,9 +1221,6 @@ class Linear(TransformerEngineBaseModule):
self
.
quantizers
[
"scaling_fwd"
][
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
amax_reduction_group
=
self
.
tp_group
].
amax_reduction_group
=
self
.
tp_group
self
.
quantizers
[
"scaling_fwd"
][
tex
.
FP8FwdTensors
.
GEMM1_INPUT
].
amax_reduction_size
=
self
.
tp_size
else
:
else
:
# set grad_output_quantizer with amax epsilon and power_2_scale
# set grad_output_quantizer with amax epsilon and power_2_scale
self
.
quantizers
[
"scaling_bwd"
][
self
.
quantizers
[
"scaling_bwd"
][
...
@@ -1231,6 +1238,3 @@ class Linear(TransformerEngineBaseModule):
...
@@ -1231,6 +1238,3 @@ class Linear(TransformerEngineBaseModule):
self
.
quantizers
[
"scaling_bwd"
][
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
amax_reduction_group
=
self
.
tp_group
].
amax_reduction_group
=
self
.
tp_group
self
.
quantizers
[
"scaling_bwd"
][
tex
.
FP8BwdTensors
.
GRAD_OUTPUT1
].
amax_reduction_size
=
self
.
tp_size
transformer_engine/pytorch/ops/op.py
View file @
a207db1d
...
@@ -283,7 +283,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
...
@@ -283,7 +283,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
recipe_state
=
fp8_meta
[
fp8_meta_key
]
recipe_state
=
fp8_meta
[
fp8_meta_key
]
# Reallocate amax history if needed
# Reallocate amax history if needed
if
recipe
.
mxfp8
():
if
not
recipe
.
delayed
():
continue
continue
current_length
=
recipe_state
.
amax_history
.
size
(
0
)
current_length
=
recipe_state
.
amax_history
.
size
(
0
)
...
...
transformer_engine/pytorch/tensor/float8_tensor.py
View file @
a207db1d
...
@@ -11,7 +11,7 @@ import torch
...
@@ -11,7 +11,7 @@ import torch
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
DType
as
TE_DType
from
transformer_engine_torch
import
DType
as
TE_DType
from
..utils
import
devices_match
,
non_tn_fp8_gemm_supported
from
..utils
import
canonicalize_process_group
,
devices_match
,
non_tn_fp8_gemm_supported
from
._internal.float8_tensor_base
import
Float8TensorBase
,
_FromFloat8Func
from
._internal.float8_tensor_base
import
Float8TensorBase
,
_FromFloat8Func
from
.quantized_tensor
import
QuantizedTensor
,
Quantizer
,
_IdentityFunc
from
.quantized_tensor
import
QuantizedTensor
,
Quantizer
,
_IdentityFunc
from
..constants
import
dist_group_type
from
..constants
import
dist_group_type
...
@@ -194,7 +194,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
...
@@ -194,7 +194,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
"""amax reduction options"""
"""amax reduction options"""
with_amax_reduction
:
bool
with_amax_reduction
:
bool
amax_reduction_group
:
Optional
[
dist_group_type
]
amax_reduction_group
:
Optional
[
dist_group_type
]
amax_reduction_size
:
Optional
[
int
]
"""Options about how to quantize the tensor"""
"""Options about how to quantize the tensor"""
force_pow_2_scales
:
bool
force_pow_2_scales
:
bool
amax_epsilon
:
float
amax_epsilon
:
float
...
@@ -208,7 +207,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
...
@@ -208,7 +207,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
columnwise
:
bool
=
True
,
columnwise
:
bool
=
True
,
with_amax_reduction
:
bool
=
False
,
with_amax_reduction
:
bool
=
False
,
amax_reduction_group
:
Optional
[
dist_group_type
]
=
None
,
amax_reduction_group
:
Optional
[
dist_group_type
]
=
None
,
amax_reduction_size
:
Optional
[
int
]
=
1
,
force_pow_2_scales
:
bool
=
False
,
force_pow_2_scales
:
bool
=
False
,
amax_epsilon
:
float
=
0.0
,
amax_epsilon
:
float
=
0.0
,
)
->
None
:
)
->
None
:
...
@@ -218,7 +216,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
...
@@ -218,7 +216,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
self
.
dtype
=
fp8_dtype
self
.
dtype
=
fp8_dtype
self
.
with_amax_reduction
=
with_amax_reduction
self
.
with_amax_reduction
=
with_amax_reduction
self
.
amax_reduction_group
=
amax_reduction_group
self
.
amax_reduction_group
=
amax_reduction_group
self
.
amax_reduction_size
=
amax_reduction_size
self
.
force_pow_2_scales
=
force_pow_2_scales
self
.
force_pow_2_scales
=
force_pow_2_scales
self
.
amax_epsilon
=
amax_epsilon
self
.
amax_epsilon
=
amax_epsilon
...
@@ -327,6 +324,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
...
@@ -327,6 +324,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
quantizer
=
self
,
quantizer
=
self
,
)
)
def
_canonicalized_amax_reduction_group
(
self
)
->
dist_group_type
:
"""Get process group for amax reduction"""
return
canonicalize_process_group
(
self
.
amax_reduction_group
)
class
Float8Tensor
(
Float8TensorBase
,
QuantizedTensor
):
class
Float8Tensor
(
Float8TensorBase
,
QuantizedTensor
):
"""Experimental tensor class with FP8 data
"""Experimental tensor class with FP8 data
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment