Unverified Commit cf9a7c2f authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Refactor + MXFP8 + GroupedGEMM (#1627)



* refactor + mxfp8

* added grouped gemm

* rename linear to dense

* added cublas init phase for groupedGemm

* relax the tol of test encoder multiprocessing mxfp8 by 0.001
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Hua Huang <huah@nvidia.com>
Co-authored-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent be055eb0
...@@ -577,3 +577,11 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT ...@@ -577,3 +577,11 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
} }
} }
namespace transformer_engine {
// Singleton manager for the cuBLASLt handle; Instance() owns the handle so
// creation happens exactly once per process.
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
// Eagerly trigger handle creation (and its memory allocation) ahead of time,
// e.g. before CUDA-graph capture. The returned handle is intentionally
// discarded — only the initialization side effect matters.
void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }
}  // namespace transformer_engine
...@@ -119,6 +119,13 @@ namespace transformer_engine { ...@@ -119,6 +119,13 @@ namespace transformer_engine {
constexpr int num_streams = 4; constexpr int num_streams = 4;
/*! \brief TE/JAX cudaGraph requires the cuBLAS initialization to happen outside of the capturing
 * region. This function is a helper to call cublasCreate(), which allocates memory for the handle.
 * The function will be called in the initialization phase of the related XLA custom calls.
 */
void nvte_cublas_handle_init();
} // namespace transformer_engine } // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_H_ #endif // TRANSFORMER_ENGINE_GEMM_H_
...@@ -149,6 +149,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor ...@@ -149,6 +149,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
void nvte_enable_cudnn_norm_fwd(bool enable); void nvte_enable_cudnn_norm_fwd(bool enable);
void nvte_enable_cudnn_norm_bwd(bool enable); void nvte_enable_cudnn_norm_bwd(bool enable);
/*! \brief Normalization algorithm selector (LayerNorm or RMSNorm). */
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -80,7 +80,8 @@ enum NVTEScalingMode { ...@@ -80,7 +80,8 @@ enum NVTEScalingMode {
/*! Single scale per block of 32 elements consecutive in either /*! Single scale per block of 32 elements consecutive in either
rowwise or columnwise direction */ rowwise or columnwise direction */
NVTE_MXFP8_1D_SCALING = 1, NVTE_MXFP8_1D_SCALING = 1,
NVTE_INVALID_SCALING NVTE_INVALID_SCALING = 2,
NVTE_NO_SCALING = 3
}; };
/*! \brief TE Tensor type /*! \brief TE Tensor type
...@@ -346,6 +347,13 @@ enum class DType { ...@@ -346,6 +347,13 @@ enum class DType {
kNumTypes kNumTypes
}; };
/*! \brief Check if TE datatype is FP8
 *
 * Return true if the TE datatype is an FP8 type
 * \param[in] t TE Datatype of interest
 */
bool is_fp8_dtype(const DType t);
/*! \struct TensorWrapper /*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class. * \brief C++ wrapper for the NVTETensor class.
*/ */
......
...@@ -11,7 +11,9 @@ ...@@ -11,7 +11,9 @@
transformer_engine::ubuf_built_with_mpi*; transformer_engine::ubuf_built_with_mpi*;
*transformer_engine::rtc*; *transformer_engine::rtc*;
transformer_engine::nvte_cudnn_handle_init*; transformer_engine::nvte_cudnn_handle_init*;
transformer_engine::nvte_cublas_handle_init*;
transformer_engine::typeToSize*; transformer_engine::typeToSize*;
transformer_engine::is_fp8_dtype*;
*transformer_engine::CommOverlapBase*; *transformer_engine::CommOverlapBase*;
*transformer_engine::CommOverlapP2PBase*; *transformer_engine::CommOverlapP2PBase*;
*transformer_engine::CommOverlapCore* *transformer_engine::CommOverlapCore*
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <cudnn.h> #include <cudnn.h>
#include <cudnn_frontend.h> #include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h> #include <cudnn_frontend_utils.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include <functional> #include <functional>
...@@ -137,7 +138,6 @@ struct BackwardKernelParams : public KernelParamsBase { ...@@ -137,7 +138,6 @@ struct BackwardKernelParams : public KernelParamsBase {
}; };
enum class NVTE_Norm_Backend { Te, Cudnn }; enum class NVTE_Norm_Backend { Te, Cudnn };
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
enum class NVTE_Norm_Stage { Forward, Backward }; enum class NVTE_Norm_Stage { Forward, Backward };
using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>; using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>;
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
"""Transformer Engine bindings for JAX""" """Transformer Engine bindings for JAX.
This module provides JAX bindings for NVIDIA's Transformer Engine, enabling
high-performance transformer operations with mixed precision and quantization
support. It includes implementations of key transformer components like attention,
linear layers, and layer normalization, optimized for NVIDIA GPUs.
The module exports various transformer operations and utilities:
- Attention mechanisms (self-attention, cross-attention)
- Linear transformations with optional quantization
- Layer normalization operations
- Activation functions
- Softmax operations
- Sharding utilities for distributed training
All operations are designed to work seamlessly with JAX's functional programming
model and support automatic differentiation.
"""
# pylint: disable=wrong-import-position,wrong-import-order # pylint: disable=wrong-import-position,wrong-import-order
import sys
import logging import logging
import importlib import importlib
import importlib.util import importlib.util
import ctypes
from importlib.metadata import version from importlib.metadata import version
import sys
from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension from transformer_engine.common import _get_sys_extension
_logger = logging.getLogger(__name__)
def _load_library(): def _load_library():
"""Load shared library with Transformer Engine C extensions""" """Load shared library with Transformer Engine C extensions"""
...@@ -41,7 +55,7 @@ def _load_library(): ...@@ -41,7 +55,7 @@ def _load_library():
if is_package_installed("transformer-engine-cu12"): if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name): if not is_package_installed(module_name):
_logger.info( logging.info(
"Could not find package %s. Install transformer-engine using " "Could not find package %s. Install transformer-engine using "
"'pip3 install transformer-engine[jax]==VERSION'", "'pip3 install transformer-engine[jax]==VERSION'",
module_name, module_name,
...@@ -67,8 +81,10 @@ def _load_library(): ...@@ -67,8 +81,10 @@ def _load_library():
_load_library() _load_library()
from . import flax from . import flax
from .fp8 import fp8_autocast, update_collections, get_delayed_scaling from . import quantize
from .fp8 import NVTE_FP8_COLLECTION_NAME
from .quantize import fp8_autocast
from .sharding import MeshResource from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType from .sharding import MajorShardingType, ShardingResource, ShardingType
...@@ -85,10 +101,7 @@ ShardingResource = deprecate_wrapper( ...@@ -85,10 +101,7 @@ ShardingResource = deprecate_wrapper(
) )
__all__ = [ __all__ = [
"NVTE_FP8_COLLECTION_NAME",
"fp8_autocast", "fp8_autocast",
"update_collections",
"get_delayed_scaling",
"MeshResource", "MeshResource",
"MajorShardingType", "MajorShardingType",
"ShardingResource", "ShardingResource",
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Activation functions for Transformer Engine in JAX.
This module provides optimized activation functions with quantization support.
"""
from typing import Sequence, Union, Callable, Optional
from functools import partial
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize.tensor import ScaledTensor
from .quantize.quantizer import Quantizer
def activation(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Apply activation functions to input tensor with optional quantization.

    This function applies a sequence of activation functions to the input tensor.
    It supports string-based activation types (e.g., 'relu', 'gelu', ('gelu', 'linear')).

    Args:
        x: Input tensor; its last dimension must be divisible by the number of
            activation functions (gated activations consume paired chunks).
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer for quantizing the output

    Returns:
        Activated output tensor
    """
    # A bare assert gives no diagnostic on failure; report the offending shapes.
    assert x.shape[-1] % len(activation_type) == 0, (
        f"x.shape[-1] ({x.shape[-1]}) must be divisible by the number of "
        f"activation functions ({len(activation_type)})"
    )
    return _activation(x, activation_type, quantizer)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation(x, activation_type, quantizer):
    """Core activation entry point wired into JAX's custom-VJP machinery.

    Delegates to the forward rule and discards the residuals; the registered
    VJP rules provide the gradient for automatic differentiation.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions (non-differentiable argument)
        quantizer: Optional quantizer

    Returns:
        Activated tensor
    """
    primal_out, _ = _activation_fwd_rule(x, activation_type, quantizer)
    return primal_out
def _activation_fwd_rule(x, activation_type, quantizer):
    """Forward rule for the activation custom VJP.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer

    Returns:
        Tuple of (output, residuals) where residuals feed the backward rule
    """
    result = tex.act_lu(x, activation_type, quantizer)
    # The primal output must be a plain array, so dequantize scaled tensors.
    if isinstance(result, ScaledTensor):
        result = result.dequantize()
    residuals = (x, quantizer)
    return result, residuals
def _activation_bwd_rule(activation_type, ctx, g):
    """Backward rule for the activation custom VJP.

    Args:
        activation_type: Sequence of activation functions (non-differentiable)
        ctx: Residuals saved by the forward rule
        g: Gradient from upstream

    Returns:
        Tuple of gradients: (d_input, None for the quantizer slot)
    """
    x, _ = ctx
    assert x.dtype == g.dtype
    grad_in = tex.dact_lu(g, x, activation_type)
    # Reshape so the gradient matches the forward input exactly.
    grad_in = jnp.reshape(grad_in, x.shape)
    return (grad_in, None)
# Register the forward/backward rules above as the custom VJP for _activation.
_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule)
...@@ -7,4 +7,4 @@ from .attention import * ...@@ -7,4 +7,4 @@ from .attention import *
from .normalization import * from .normalization import *
from .quantization import * from .quantization import *
from .softmax import * from .softmax import *
from .transpose import * from .gemm import *
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom ops for activation""" """JAX/TE custom ops for activation"""
from typing import Tuple, Sequence, Union, Callable from typing import Sequence, Union, Callable, Optional, Tuple
import operator import operator
from functools import reduce, partial from functools import reduce, partial
from packaging import version from packaging import version
...@@ -10,31 +10,38 @@ from packaging import version ...@@ -10,31 +10,38 @@ from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine_jax import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type from transformer_engine_jax import NVTE_Activation_Type
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import ( from .misc import (
check_valid_batch_dims,
jax_dtype_to_te_dtype, jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype, te_dtype_to_jax_dtype,
get_padded_spec, get_padded_spec,
is_ffi_enabled, check_valid_batch_dims,
multidim_transpose,
try_apply_delayed_scaling_2x_war,
should_apply_1x_fused_dbias_war_for_arch_l_100,
NamedSharding,
)
from .quantization import _jax_quantize_dbias, _jax_dbias, quantize_dbias
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
Quantizer,
QuantizeAxis,
DelayedScaleQuantizer,
ScalingMode,
) )
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP
if version.parse(jax.__version__) >= version.parse("0.5.0"): if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports from jax import ffi # pylint: disable=ungrouped-imports
else: else:
from jax.extend import ffi # pylint: disable=ungrouped-imports from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]
__all__ = ["act_lu", "dact_lu", "act_lu_fp8"]
ActivationEnum = { ActivationEnum = {
...@@ -66,448 +73,1053 @@ def _convert_to_activation_function(fn_or_string): ...@@ -66,448 +73,1053 @@ def _convert_to_activation_function(fn_or_string):
raise ValueError(f"Unsupported {fn_or_string} to an activation function") raise ValueError(f"Unsupported {fn_or_string} to an activation function")
def _jax_act_lu(inputs, activation_type):
    """Pure-JAX fallback implementation of the (gated) activation.

    Splits the second-to-last axis into one chunk per activation function,
    applies each activation to its chunk, multiplies the results elementwise,
    and squeezes the now-singleton axis away.
    """
    chunks = jnp.split(inputs, len(activation_type), axis=-2)
    activated = [
        _convert_to_activation_function(act_fn)(chunk)
        for act_fn, chunk in zip(activation_type, chunks)
    ]
    fused = reduce(operator.mul, activated)
    return jnp.squeeze(fused, axis=-2)
class ActLuPrimitive(BasePrimitive): class ActLuPrimitive(BasePrimitive):
""" """
Activation Forward Primitive ActLu Primitive
""" """
name = "te_act_lu" name = "te_act_lu_ffi"
multiple_results = False multiple_results = True
impl_static_args = (
2,
3,
4,
5,
6,
7,
8,
9,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, scale_shapes, is_outer
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
impl_static_args = (1,)
@staticmethod @staticmethod
def abstract(x_aval, *, act_enum): # pylint: disable=unused-argument def abstract(
x_aval,
scale_aval,
*,
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
""" """
act_lu abstract te_act_lu_p abstract
""" """
del act_enum, act_len, scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
out_shape = (
*x_aval.shape[:-2],
1,
x_aval.shape[-1],
)
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(out_shape[:-2] + (out_shape[-1],), is_padded=not is_outer)
x_shape = x_aval.shape if len(rowwise_scale_inv_shape) > 1:
assert x_shape[-2] == 2 or x_shape[-2] == 1 rowwise_scale_inv_shape = (
hidden_size = x_shape[-1] rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
batch_shapes = x_shape[:-2] )
out_aval = x_aval if len(colwise_scale_inv_shape) > 1:
out_shape = (batch_shapes) + (hidden_size,) colwise_scale_inv_shape = (
out_aval = out_aval.update(shape=out_shape, dtype=dtype) colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
if is_2x:
colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
return out_aval return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval
@staticmethod @staticmethod
def lowering(ctx, x, *, act_enum): def lowering(
ctx,
x,
scale,
*,
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
""" """
act_lu lowering rules te_gated_act_lu_p lowering rules
""" """
(x_aval,) = ctx.avals_in del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
if is_ffi_enabled(): assert scale_aval is None or scale_aval.dtype == jnp.float32
name = "te_act_lu_ffi"
out = ffi.ffi_lowering(name)(ctx, x, act_enum=act_enum)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]
out_types = [
ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
]
operands = [x]
operand_shapes = [ir_x_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
hidden_size = ir_x_shape[-1] out = ffi.ffi_lowering(ActLuPrimitive.name)(
batch_size = reduce(operator.mul, ir_x_shape[:-2]) ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x
in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor(
(batch_size, hidden_size), in_dtype, in_dtype, act_enum
) )
out = custom_caller(ActLuPrimitive.name, args, opaque, False)
return out return out
@staticmethod @staticmethod
def impl(x, act_enum): def impl(
x,
scale,
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
to describe implementation
"""
del is_outer
assert ActLuPrimitive.inner_primitive is not None assert ActLuPrimitive.inner_primitive is not None
out = ActLuPrimitive.inner_primitive.bind(x, act_enum=act_enum)
return out out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
ActLuPrimitive.inner_primitive.bind(
x,
scale,
out_dtype=out_dtype,
act_enum=act_enum,
act_len=act_len,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=False,
)
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(out.shape[:-2] + (out.shape[-1],), is_padded=False)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
rowwise_scale_inv_shape = (
rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
)
if is_2x:
colwise_scale_inv_shape = (
colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if is_2x:
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax
@staticmethod @staticmethod
def batcher(batched_args, batch_dims, *, act_enum): def batcher(
batched_args,
batch_dims,
*,
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
""" """
act_lu batcher to describe batch rules for vmap
""" """
del act_len, is_outer
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert ActLuPrimitive.outer_primitive is not None assert ActLuPrimitive.outer_primitive is not None
(inputs,) = batched_args x, scale = batched_args
(inputs_bdim,) = batch_dims x_bdim, scale_bdim = batch_dims
amax_bdim = scale_bdim
out_bdims = inputs_bdim out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_enum), out_bdims return (
ActLuPrimitive.outer_primitive.bind(
x,
scale,
out_dtype=out_dtype,
act_enum=act_enum,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
),
out_bdims,
)
@staticmethod @staticmethod
def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos): def infer_sharding_from_operands(
""" out_dtype,
act_lu infer_sharding_from_operands act_enum,
""" act_len,
del result_infos, act_enum # Unused. scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
result_infos,
):
del (
out_dtype,
result_infos,
act_enum,
scale_dtype,
scale_shapes,
act_len,
is_outer,
) # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) out_spec = (*x_spec[:-2], None, x_spec[-2])
return out_sharding out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec)
else:
colwise_out_spec = out_spec
else:
colwise_out_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
)
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
)
return (
out_sharding,
colwise_out_sharding,
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
)
@staticmethod @staticmethod
def partition(act_enum, mesh, arg_infos, result_infos): def partition(
""" out_dtype,
act_lu partitioning act_enum,
""" act_len,
del result_infos scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
result_infos,
):
del result_infos, is_outer # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_spec = (*x_spec[:-1], x_spec[-1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) if act_len == 2 and x_spec[-1] is None:
# Ensure last axis is partitioned and not the gating axis
out_spec = (*x_spec[:-2], None, x_spec[-2])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec)
else:
colwise_out_spec = out_spec
else:
colwise_out_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
)
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")
def sharded_impl(x): if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
return ActLuPrimitive.impl(x, act_enum=act_enum) scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
)
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec))
arg_shardings = tuple(arg_shardings)
out_shardings = (
out_sharding,
colwise_out_sharding,
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
)
return mesh, sharded_impl, out_sharding, arg_shardings def sharded_impl(x, scale):
local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
ActLuPrimitive.impl(
x,
scale,
out_dtype=out_dtype,
act_enum=act_enum,
act_len=act_len,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=True,
)
)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
register_primitive(ActLuPrimitive) return (
local_x,
local_colwise_x,
local_scale_inv,
local_colwise_scale_inv,
global_updated_amax,
)
return mesh, sharded_impl, out_shardings, arg_shardings
def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
"""
act_lu wrapper
Return act_lu(inputs)
Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
"""
if not ActLuPrimitive.enabled():
return _jax_act_lu(inputs, activation_type)
act_type_id = ActivationEnum[activation_type].value register_primitive(ActLuPrimitive)
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)
class DActLuPrimitive(BasePrimitive): class DActLuDBiasQuantizePrimitive(BasePrimitive):
""" """
Dgated ActLu Primitive DActLu DBias Cast Transpose Primitive
""" """
name = "te_dact_lu" name = "te_dact_dbias_quantize_ffi"
multiple_results = False multiple_results = True
# out_dtype, scaling_mode, is_2x, scale_dtype, scale_shapes, is_dbias, act_enum, act_len, is_outer
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
impl_static_args = (2,)
@staticmethod @staticmethod
def abstract(dz_aval, x_aval, *, act_enum): # pylint: disable=unused-argument def abstract(
dz_aval,
x_aval,
scale_aval,
*,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
):
""" """
dact_lu abstract te_dact_dbias_quantize_p abstract
""" """
del act_enum, scale_shapes
dtype = dtypes.canonicalize_dtype(dz_aval.dtype) dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype assert x_aval.dtype == dtype
for axis in range(len(dz_aval.shape) - 1): assert scale_aval.dtype == jnp.float32
assert dz_aval.shape[axis] == x_aval.shape[axis] ir_hidden_size = dz_aval.shape[-1]
assert x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1 gi_hidden_size = x_aval.shape[-1]
assert act_len * ir_hidden_size == gi_hidden_size
out_shape = x_aval.shape
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
i_hidden_size = dz_aval.shape[-1] rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
g_hidden_size = x_aval.shape[-1] scaling_mode
assert i_hidden_size == g_hidden_size ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
out_aval = x_aval
return out_aval scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
@staticmethod colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
def lowering(ctx, dz, x, *, act_enum): colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
"""
dact_lu lowering rules dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
""" wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
in_aval, gi_aval = ctx.avals_in if is_2x:
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] # Don't transpose output for MXFP8
assert gi_aval.dtype == in_aval.dtype if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
if is_ffi_enabled(): t_shape = out_shape
name = "te_dact_lu_ffi"
out = ffi.ffi_lowering(name)(ctx, dz, x, act_enum=act_enum)
else: else:
ir_in_type = ir.RankedTensorType(dz.type) t_shape = multidim_transpose(out_shape)
ir_in_shape = ir_in_type.shape colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
gi_type = ir.RankedTensorType(x.type) colwise_scale_inv_aval = jax.core.ShapedArray(
gi_shape = gi_type.shape shape=colwise_scale_inv_shape, dtype=scale_dtype
# assert ir_in_shape == gi_shape )
for axis in range(len(ir_in_shape) - 1):
assert ir_in_shape[axis] == gi_shape[axis]
ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
i_hidden_size = ir_in_shape[-1]
g_hidden_size = gi_shape[-1]
assert i_hidden_size == g_hidden_size
out_dtype = ir_in_type.element_type
out_shape = gi_shape
out_types = [
ir.RankedTensorType.get(out_shape, out_dtype),
]
operands = [dz, x]
operand_shapes = [ir_in_shape, gi_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor(
(ir_batch_size, i_hidden_size), in_dtype, in_dtype, act_enum
)
out = custom_caller(DActLuPrimitive.name, args, opaque, False)
return out if is_dbias:
dbias_shape = gi_hidden_size
dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
(wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
x_aval.size // gi_hidden_size,
gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
scaling_mode,
is_2x,
)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
dbias_aval,
wkspace_aval,
)
@staticmethod @staticmethod
def impl(dz, x, act_enum): def outer_abstract(*args, **kwargs):
""" """
dact_lu implementation te_dact_dbias_quantize_p outer abstract
""" """
assert DActLuPrimitive.inner_primitive is not None (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
dx = DActLuPrimitive.inner_primitive.bind(dz, x, act_enum=act_enum) DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
return dx )
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
@staticmethod @staticmethod
def batcher(batched_args, batch_dims, *, act_enum): def lowering(
ctx,
dz,
x,
scale,
*,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
):
""" """
dact_lu batcher te_dact_dbias_quantize_p lowering rules
""" """
check_valid_batch_dims(batch_dims) del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
assert DActLuPrimitive.outer_primitive is not None dz_aval, x_aval, scale_aval = ctx.avals_in
dz, x = batched_args assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
_, x_bdim = batch_dims assert x_aval.dtype == dz_aval.dtype
assert scale_aval.dtype == jnp.float32
out_bdims = x_bdim return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)(
return DActLuPrimitive.outer_primitive.bind(dz, x, act_enum=act_enum), out_bdims ctx,
dz,
x,
scale,
scaling_mode=scaling_mode,
is_2x=is_2x,
is_dbias=is_dbias,
act_enum=int(act_enum),
)
@staticmethod @staticmethod
def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos): def impl(
dz,
x,
scale,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
):
""" """
dact_lu infer_sharding_from_operands te_dact_dbias_quantize_p impl
""" """
del result_infos, act_enum # Unused. del is_outer
act_lu_out_spec = get_padded_spec(arg_infos[1]) assert DActLuDBiasQuantizePrimitive.inner_primitive is not None
dx_sharding = NamedSharding(mesh, PartitionSpec(*act_lu_out_spec)) (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
return dx_sharding DActLuDBiasQuantizePrimitive.inner_primitive.bind(
dz,
x,
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
is_outer=False,
)
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if is_2x:
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return (
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
) # Exclude wkspace
@staticmethod @staticmethod
def partition(act_enum, mesh, arg_infos, result_infos): def batcher(
batched_args,
batch_dims,
*,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
):
""" """
dact_lu partition to describe batch rules for vmap
""" """
del result_infos del is_outer
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) check_valid_batch_dims(batch_dims)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) assert DActLuDBiasQuantizePrimitive.outer_primitive is not None
out_shardings = dx_sharding dz, x, scale = batched_args
_, x_bdim, scale_bdim = batch_dims
out_bdims = (
x_bdim, # rowwise output
scale_bdim, # rowwise scale_inv
x_bdim, # colwise output
scale_bdim, # colwise scale_inv
scale_bdim, # amax
x_bdim, # dbias
)
return (
DActLuDBiasQuantizePrimitive.outer_primitive.bind(
dz,
x,
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
),
out_bdims,
)
def sharded_impl(dz, x): @staticmethod
return DActLuPrimitive.impl(dz, x, act_enum=act_enum) def infer_sharding_from_operands(
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
mesh,
arg_infos,
result_infos,
):
del out_dtype, result_infos, act_enum
del scale_dtype, scale_shapes, is_dbias, act_len, is_outer
x_spec = get_padded_spec(arg_infos[1])
return mesh, sharded_impl, out_shardings, arg_shardings out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
)
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec)
else:
colwise_x_spec = x_spec
else:
colwise_x_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
)
dbias_shaprding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
desc="DActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
)
return (
out_sharding,
colwise_out_sharding,
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
dbias_shaprding,
)
register_primitive(DActLuPrimitive) @staticmethod
def partition(
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
mesh,
arg_infos,
result_infos,
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec), desc="out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec)
else:
colwise_x_spec = x_spec
else:
colwise_x_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
)
dbias_shaprding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
desc="DActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
)
def dact_lu( arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]] arg_shardings = (
) -> jnp.ndarray: arg_shardings[1],
""" arg_shardings[1],
dact_lu fusion wrapper *arg_shardings[2:],
Return dgated_act_lu(inputs) ) # dz and x are the same
""" out_shardings = (
if not DActLuPrimitive.enabled(): out_sharding,
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs) colwise_out_sharding,
return vjp_func(inputs)[0] scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
dbias_shaprding,
)
act_type_id = ActivationEnum[activation_type].value def sharded_impl(dz, x, scale):
return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id) (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
DActLuDBiasQuantizePrimitive.impl(
dz,
x,
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
is_outer=True,
)
)
if is_dbias:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
else:
global_dbias = local_dbias
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
class ActLuFp8Primitive(BasePrimitive): return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias
"""
ActLu FP8 Primitive
"""
name = "te_act_lu_fp8" return mesh, sharded_impl, out_shardings, arg_shardings
multiple_results = True
impl_static_args = (4, 5) # out_dtype, act_enum
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract( register_primitive(DActLuDBiasQuantizePrimitive)
x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, act_enum
): # pylint: disable=unused-argument
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
""" """
te_act_lu_p abstract JAX native activation implementation
""" """
dtype = dtypes.canonicalize_dtype(x_aval.dtype) x = jnp.split(inputs, len(activation_type), axis=-1)
# Currently only support casting to E4M3 only in C side. acts = []
assert out_dtype == jnp.float8_e4m3fn for idx, act_fn in enumerate(activation_type):
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] x_i = _convert_to_activation_function(act_fn)(x[idx])
assert amax_aval.dtype == jnp.float32 acts.append(x_i)
assert scale_aval.dtype == jnp.float32 x = reduce(operator.mul, acts)
assert scale_inv_aval.dtype == jnp.float32 if quantizer:
return quantizer.quantize(x)
assert x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2 return x
hidden_size = x_aval.shape[-1]
batch_shape = x_aval.shape[:-2]
out_shape = (batch_shape) + (hidden_size,)
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
return out_aval, updated_amax_aval
@staticmethod def _jax_quantize_dact_dbias(
def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum): dz: jnp.ndarray,
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
is_dbias: bool = True,
quantizer: Optional[Quantizer] = None,
):
""" """
te_gated_act_lu_p lowering rules JAX implementation of dact_lu and dbias with optional quantization
""" """
x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in _, vjp_func = jax.vjp(
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_act_lu_fp8_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
ctx, x, amax, scale, scale_inv, act_enum=act_enum
) )
(dx,) = vjp_func(dz.astype(jnp.float32))
dbias = None
if is_dbias:
dbias = _jax_dbias(dx).astype(x.dtype)
if quantizer is not None:
dx = quantizer.quantize(dx, dq_dtype=x.dtype)
else: else:
ir_x_type = ir.RankedTensorType(x.type) dx = dx.astype(x.dtype)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
hidden_size = ir_x_shape[-1]
batch_shape = ir_x_shape[:-2]
batch_size = reduce(operator.mul, batch_shape)
out_shape = batch_shape + [hidden_size]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor(
(batch_size, hidden_size),
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum,
)
out = custom_caller( return dx, dbias
ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
)
return out
@staticmethod def act_lu(
def impl(x, amax, scale, scale_inv, out_dtype, act_enum): x: jnp.ndarray,
""" activation_type: Sequence[Union[str, Callable]],
to describe implementation quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization.
Args:
x: Input tensor to be processed.
activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output.
Returns:
If quantizer is None:
The activated input tensor with the same dtype as input.
If quantizer is provided:
A ScaledTensor containing the quantized activated input.
""" """
assert ActLuFp8Primitive.inner_primitive is not None act_type_id = ActivationEnum[activation_type].value
out, updated_amax = ActLuFp8Primitive.inner_primitive.bind(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum if not ActLuPrimitive.enabled():
return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support 2x quantization for DelayedScaling yet
war_output = try_apply_delayed_scaling_2x_war(
f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
)
if war_output is not None:
return war_output
scale = jnp.empty((1,), jnp.float32)
output_shape = (*x.shape[:-1], x.shape[-1] // len(activation_type))
if quantizer is None:
x = x.reshape((-1, len(activation_type), x.shape[-1] // len(activation_type)))
out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
x,
scale,
out_dtype=x.dtype,
act_enum=act_type_id,
act_len=len(activation_type),
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((), ()),
is_outer=True,
) )
return out, updated_amax out = out.reshape(output_shape)
return out
@staticmethod if isinstance(quantizer, DelayedScaleQuantizer):
def batcher(batched_args, batch_dims, *, out_dtype, act_enum): scale = quantizer.scale
"""
to describe batch rules for vmap x = x.reshape((*x.shape[:-1], len(activation_type), x.shape[-1] // len(activation_type)))
""" (
check_valid_batch_dims(batch_dims) rowwise_casted_output,
assert ActLuFp8Primitive.outer_primitive is not None colwise_casted_output,
x, amax, scale, scale_inv = batched_args rowwise_scale_inv,
x_bdim, amax_bdim, _, _ = batch_dims colwise_scale_inv,
updated_amax,
) = ActLuPrimitive.outer_primitive.bind(
x,
scale,
out_dtype=quantizer.q_dtype,
act_enum=act_type_id,
act_len=len(activation_type),
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(output_shape),
is_outer=True,
)
out_bdims = x_bdim, amax_bdim rowwise_casted_output = rowwise_casted_output.reshape(output_shape)
return ( if len(rowwise_scale_inv.shape) > 1:
ActLuFp8Primitive.outer_primitive.bind( rowwise_scale_inv = jnp.squeeze(rowwise_scale_inv, axis=-2) # Remove act axis
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum if quantizer.q_axis in (QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE):
), colwise_output_shape = output_shape
out_bdims, if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
colwise_output_shape = multidim_transpose(output_shape)
colwise_casted_output = colwise_casted_output.reshape(colwise_output_shape)
if len(colwise_scale_inv.shape) > 1:
colwise_scale_inv = jnp.squeeze(colwise_scale_inv, axis=-2) # Remove act axis
quantizer.update(updated_amax)
return ScaledTensorFactory.create(
data=rowwise_casted_output,
scale_inv=rowwise_scale_inv,
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
) )
@staticmethod
def infer_sharding_from_operands(out_dtype, act_enum, mesh, arg_infos, result_infos):
del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
return (out_sharding, amax_sharding)
@staticmethod def quantize_dact_dbias(
def partition(out_dtype, act_enum, mesh, arg_infos, result_infos): dz: jnp.ndarray,
del result_infos x: jnp.ndarray,
x_spec = get_padded_spec(arg_infos[0]) activation_type: Sequence[Union[str, Callable]] = ("gelu",),
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) is_dbias: bool = True,
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) quantizer: Optional[Quantizer] = None,
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) ) -> Tuple[ScaledTensor, jnp.ndarray]:
out_shardings = (out_sharding, amax_sharding) """Compute gradients of activation and bias with optional quantization.
Args:
dz: Gradient of the output with respect to the activation output.
x: Input tensor that was processed by the forward pass.
Shape: (..., ACT_DIM * K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output.
Returns:
Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
- The gradient of the activation with respect to the input.
- The gradient of the activation with respect to the bias.
"""
def sharded_impl(x, amax, scale, scale_inv): if not DActLuDBiasQuantizePrimitive.enabled():
local_x, local_amax = ActLuFp8Primitive.impl( return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_x, global_updated_amax # TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
return mesh, sharded_impl, out_shardings, arg_shardings # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = quantize_dact_dbias(
dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=None
)
return quantize_dbias(out, is_dbias=True, quantizer=quantizer)
is_gated = len(activation_type) == 2
# TE/common does not support DelayedScaling2x for gated-act yet
if is_gated:
war_output = try_apply_delayed_scaling_2x_war(
f=quantize_dact_dbias,
dz=dz,
x=x,
activation_type=activation_type,
is_dbias=is_dbias,
quantizer=quantizer,
)
if war_output is not None:
return war_output
scale = jnp.empty((), jnp.float32)
act_type_id = ActivationEnum[activation_type]
if quantizer is None:
output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
dz,
x,
scale,
# outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
scale_shapes=((), ()), # unused
is_dbias=False,
act_enum=act_type_id,
act_len=len(activation_type),
is_outer=True,
)
dbias = None
if is_dbias:
dbias = _jax_dbias(output).astype(x.dtype)
return output.astype(x.dtype), dbias
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
# TE/common dact_dbias_quantize does not support gated act yet
if is_dbias and is_gated:
dgated = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
)
# TODO(Jeremy): Debug - TE's quantize_dbias produced nans in this case for distributed layernorm_mlp tests
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
out, dbias = _jax_quantize_dbias(dgated, quantizer=quantizer, dq_dtype=x.dtype)
else:
out, dbias = quantize_dbias(
dgated,
quantizer=quantizer,
is_dbias=True,
dq_dtype=x.dtype,
)
return out, dbias
out_shape = x.shape
(
rowwise_casted_output,
colwise_casted_output,
rowwise_scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
) = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
dz,
x,
scale,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(out_shape),
is_dbias=is_dbias,
act_enum=act_type_id,
act_len=len(activation_type),
is_outer=True,
)
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
out = ScaledTensorFactory.create(
data=rowwise_casted_output,
scale_inv=rowwise_scale_inv,
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
)
register_primitive(ActLuFp8Primitive) return out, dbias
def act_lu_fp8( def dact_lu(
dz: jnp.ndarray,
x: jnp.ndarray, x: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: jnp.dtype,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: quantizer: Optional[Quantizer] = None,
""" ) -> Union[jnp.ndarray, ScaledTensor]:
act wrapper
Return FP8(act_lu(x))
Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
""" """
if not ActLuFp8Primitive.enabled(): Backward pass for activation with optional quantization.
act_lu_output = _jax_act_lu(x, activation_type)
casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype)
return casted_output, updated_amax
act_type_id = ActivationEnum[activation_type].value Args:
return ActLuFp8Primitive.outer_primitive.bind( dz: Gradient tensor from upstream.
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id x: Input tensor that was used in forward pass.
activation_type: Type of activation function that was applied.
quantizer: Optional quantizer for FP8 quantization of the output gradient.
Returns:
The gradient of the activation with respect to the input.
"""
output, _ = quantize_dact_dbias(
dz=dz,
x=x,
activation_type=activation_type,
is_dbias=False,
quantizer=quantizer,
) )
return output
...@@ -13,8 +13,6 @@ from packaging import version ...@@ -13,8 +13,6 @@ from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes, lax from jax import dtypes, lax
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine_jax import transformer_engine_jax
...@@ -29,14 +27,12 @@ from transformer_engine.jax.attention import ( ...@@ -29,14 +27,12 @@ from transformer_engine.jax.attention import (
) )
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import ( from .misc import (
check_valid_batch_dims, check_valid_batch_dims,
jax_dtype_to_te_dtype, jax_dtype_to_te_dtype,
te_dtype_to_jax_dtype, te_dtype_to_jax_dtype,
get_padded_spec, get_padded_spec,
get_cudnn_version, get_cudnn_version,
is_ffi_enabled,
) )
from ..sharding import ( from ..sharding import (
global_mesh_resource, global_mesh_resource,
...@@ -227,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -227,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
Fused Attention Forward Primitive Fused Attention Forward Primitive
""" """
name = "te_fused_attn_forward" name = "te_fused_attn_forward_ffi"
multiple_results = True multiple_results = True
impl_static_args = (13,) impl_static_args = (13,)
inner_primitive = None inner_primitive = None
...@@ -400,9 +396,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -400,9 +396,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape) bias_batch = reduce(operator.mul, bias_batch_shape)
if is_ffi_enabled(): return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
ctx, ctx,
q, q,
k, k,
...@@ -436,54 +430,6 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -436,54 +430,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
window_size_left=config.window_size[0], window_size_left=config.window_size[0],
window_size_right=config.window_size[1], window_size_right=config.window_size[1],
) )
else:
operands = [
q,
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod @staticmethod
def impl( def impl(
...@@ -681,7 +627,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -681,7 +627,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
Fused Attention Backward Primitive Fused Attention Backward Primitive
""" """
name = "te_fused_attn_backward" name = "te_fused_attn_backward_ffi"
multiple_results = True multiple_results = True
impl_static_args = (16,) impl_static_args = (16,)
inner_primitive = None inner_primitive = None
...@@ -813,9 +759,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -813,9 +759,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape) bias_batch = reduce(operator.mul, bias_batch_shape)
if is_ffi_enabled(): return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
name = "te_fused_attn_backward_ffi"
out = ffi.ffi_lowering(name)(
ctx, ctx,
q, q,
k, k,
...@@ -852,57 +796,6 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -852,57 +796,6 @@ class FusedAttnBwdPrimitive(BasePrimitive):
window_size_left=config.window_size[0], window_size_left=config.window_size[0],
window_size_right=config.window_size[1], window_size_right=config.window_size[1],
) )
else:
operands = [
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod @staticmethod
def impl( def impl(
......
...@@ -6,6 +6,7 @@ import os ...@@ -6,6 +6,7 @@ import os
import re import re
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from functools import partial from functools import partial
from packaging import version
from jax.extend import core from jax.extend import core
from jax.interpreters import xla, mlir from jax.interpreters import xla, mlir
...@@ -13,6 +14,14 @@ from jax.experimental.custom_partitioning import custom_partitioning ...@@ -13,6 +14,14 @@ from jax.experimental.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src import dispatch from jax._src import dispatch
import jax
import transformer_engine_jax
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
class BasePrimitive(metaclass=ABCMeta): class BasePrimitive(metaclass=ABCMeta):
""" """
...@@ -120,3 +129,7 @@ def register_primitive(cls): ...@@ -120,3 +129,7 @@ def register_primitive(cls):
outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results) outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
) )
cls.outer_primitive = outer_p cls.outer_primitive = outer_p
for _name, _value in transformer_engine_jax.registrations().items():
ffi.register_ffi_target(_name, _value, platform="CUDA")
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom call"""
from dataclasses import dataclass
from enum import IntEnum
from packaging import version
import jax
from jax.interpreters import mlir
import transformer_engine_jax
from .misc import is_ffi_enabled
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
try:
from jaxlib.hlo_helpers import custom_call
except ImportError:
# Newer JAX changed its API. But we want to support a few JAX
# version, so we still need this import.
pass
class CustomCallAPIVersion(IntEnum):
    """Enum for selecting between old and new custom call registration API"""

    # Legacy descriptor-based custom calls: configuration is passed as an
    # opaque byte string packed on the host side.
    OPAQUE = 0
    # XLA typed FFI custom calls: configuration is passed as typed attributes.
    FFI = 1
for _name, _value in transformer_engine_jax.registrations().items():
if _name.endswith("_ffi"):
if is_ffi_enabled():
ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
)
else:
ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
)
@dataclass
class CustomCallArgsWrapper:
    """Bundles the operands, result types, and buffer layouts of one XLA custom call.

    Layouts default to row-major (minor-to-major order) per buffer unless an
    explicit override is supplied for that buffer's index.
    """

    def __init__(
        self,
        output_types,
        operands,
        operand_shapes,
        operand_specific_layouts=None,
        output_specific_layouts=None,
    ):
        self.output_types = output_types
        self.operands = operands
        # Operand layouts come straight from the caller-provided shapes.
        self.operand_layouts = CustomCallArgsWrapper.generate_layouts(
            operand_shapes, operand_specific_layouts
        )
        # Result layouts are derived from the IR result types' shapes.
        result_shapes = [t.shape for t in output_types]
        self.output_layouts = CustomCallArgsWrapper.generate_layouts(
            result_shapes, output_specific_layouts
        )

    @staticmethod
    def generate_layouts(shapes, specific_layouts):
        """Return one layout per shape: the override if given, row-major otherwise."""

        def row_major(shape):
            # Minor-to-major dimension order, i.e. XLA's default layout.
            return range(len(shape) - 1, -1, -1)

        overrides = {} if specific_layouts is None else specific_layouts
        return [
            overrides[idx] if idx in overrides else row_major(shape)
            for idx, shape in enumerate(shapes)
        ]
def custom_caller(name, args, opaque, has_side_effect, **kwargs):
    """Emit an XLA custom call through whichever helper this JAX version provides.

    Args:
        name: registered custom-call target name.
        args: CustomCallArgsWrapper holding operands, result types, and layouts.
        opaque: packed descriptor bytes passed as the backend config.
        has_side_effect: whether XLA must treat the call as effectful.
        **kwargs: forwarded verbatim to the underlying helper.
    """
    shared_kwargs = dict(
        operands=args.operands,
        operand_layouts=args.operand_layouts,
        result_layouts=args.output_layouts,
        backend_config=opaque,
        has_side_effect=has_side_effect,
    )
    if hasattr(mlir, "custom_call"):
        return mlir.custom_call(
            name,
            result_types=args.output_types,
            **shared_kwargs,
            **kwargs,
        ).results
    # Older JAX: fall back to the jaxlib helper, whose second positional
    # parameter has been renamed across releases — hence the pylint waiver.
    return custom_call(  # pylint: disable=too-many-function-args
        name,
        args.output_types,
        **shared_kwargs,
        **kwargs,
    )
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""
from typing import Tuple, Sequence, Union, Dict, List
from functools import partial, reduce
import operator
from transformer_engine_jax import get_device_compute_capability
import jax
import jax.numpy as jnp
from .base import BasePrimitive, register_primitive
from ..quantize import (
ScaledTensor,
ScalingMode,
Quantizer,
QuantizeConfig,
noop_quantizer_set,
)
__all__ = ["gemm", "grouped_gemm"]
num_cublas_streams = 4
def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS workspace size in bytes for the current device.

    Returns 32 MiB on Hopper (compute capability >= 9.0), where cuBLAS
    recommends the larger workspace, and 4 MiB for all other architectures.
    (The original annotation said ``-> None`` but the function returns an int.)
    """
    # Device 0's compute capability decides the size for the whole process.
    if get_device_compute_capability(0) >= 90:
        return 33_554_432  # 32 MiB
    return 4_194_304  # 4 MiB
class GroupedGemmPrimitive(BasePrimitive):
    """
    Primitive for grouped GEMM

    Binds the ``te_grouped_gemm_ffi`` XLA custom call, which performs one GEMM
    per group over contiguous (flattened and concatenated) operand buffers.
    """
    name = "te_grouped_gemm_ffi"
    multiple_results = True
    # Keyword args treated as compile-time constants:
    # (num_gemms, scaling_mode, out_dtype, out_flat_size)
    impl_static_args = (6, 7, 8, 9)
    inner_primitive = None
    outer_primitive = None
    @staticmethod
    def abstract(
        lhs_contig_aval,
        lhs_scale_contig_aval,
        rhs_contig_aval,
        rhs_scale_contig_aval,
        bias_contig_aval,
        dim_list_aval,
        *,
        num_gemms,
        scaling_mode,
        out_dtype,
        out_flat_size,
    ):
        """Inner-primitive abstract eval: flattened output plus cuBLAS workspace."""
        del lhs_contig_aval, lhs_scale_contig_aval
        del rhs_contig_aval, rhs_scale_contig_aval
        del bias_contig_aval, dim_list_aval
        del num_gemms, scaling_mode
        out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype)
        # One workspace region per concurrent cuBLAS stream
        wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
        wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8)
        return (out_flat_aval, wkspace_aval)
    @staticmethod
    def outer_abstract(*args, **kwargs):
        """Outer-primitive abstract eval: same as inner, with the workspace dropped."""
        (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs)
        return out_aval
    @staticmethod
    def lowering(
        ctx,
        lhs_contig,
        lhs_scale_inv_contig,
        rhs_contig,
        rhs_scale_inv_contig,
        bias_contig,
        dim_list,
        *,
        num_gemms,
        scaling_mode,
        out_dtype,
        out_flat_size,
    ) -> jnp.ndarray:
        """Lower to the ``te_grouped_gemm_ffi`` XLA custom call."""
        del out_dtype, out_flat_size
        return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
            ctx,
            lhs_contig,
            lhs_scale_inv_contig,
            rhs_contig,
            rhs_scale_inv_contig,
            bias_contig,
            dim_list,
            num_gemms=num_gemms,
            # enum -> plain int for the FFI attribute
            scaling_mode=int(scaling_mode),
        )
    @staticmethod
    def impl(
        lhs_contig,
        lhs_scale_inv_contig,
        rhs_contig,
        rhs_scale_inv_contig,
        bias_contig,
        dim_list,
        num_gemms,
        scaling_mode,
        out_dtype,
        out_flat_size,
    ) -> jnp.ndarray:
        """Bind the inner primitive and return only the GEMM output."""
        assert GroupedGemmPrimitive.inner_primitive is not None
        out = GroupedGemmPrimitive.inner_primitive.bind(
            lhs_contig,
            lhs_scale_inv_contig,
            rhs_contig,
            rhs_scale_inv_contig,
            bias_contig,
            dim_list,
            num_gemms=num_gemms,
            scaling_mode=scaling_mode.value,
            out_dtype=out_dtype,
            out_flat_size=out_flat_size,
        )
        return out[0]  # out is [out_flat, wkspace], only return out_flat
# Register the primitive with TE's JAX primitive registry (see base.register_primitive)
register_primitive(GroupedGemmPrimitive)
def _shape_normalization(x, dimension_numbers, already_transposed: bool = False):
orig_order = list(range(x.ndim))
contracting_dims, batch_dims = dimension_numbers
contracting_order = [d for d in orig_order if d in contracting_dims]
batch_order = [d for d in orig_order if d in batch_dims]
non_contracting_order = [
d for d in orig_order if d not in contracting_dims and d not in batch_dims
]
batch_shape = [x.shape[d] for d in batch_order]
rows_shape = [x.shape[d] for d in non_contracting_order]
cols_shape = [x.shape[d] for d in contracting_order]
new_order = batch_order + non_contracting_order + contracting_order
rows, cols, batches = (
reduce(operator.mul, rows_shape, 1),
reduce(operator.mul, cols_shape, 1),
reduce(operator.mul, batch_shape, 1),
)
# Remove this transpose when non-TN dot is supported
if not already_transposed:
t = jnp.transpose(x, new_order)
else:
t = x
return jnp.reshape(t, (batches, rows, cols))
def _calculate_remaining_shape(shape, contracting_dims):
return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)
def _dequantize(x, scale_inv, dq_dtype):
return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
# Apply jit to guarantee correctness of FP8 GEMM.
@partial(
    jax.jit,
    static_argnums=(
        2,
        3,
        4,
    ),
)
def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
    """Dequantize-then-matmul under jit so XLA can fuse QDQ into an FP8 GEMM.

    ``lhs``/``rhs`` are traced ScaledTensor inputs; ``lhs_dn``/``rhs_dn``
    (dimension numbers) and ``precision`` are static arguments.
    """
    # Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
    lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
    rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)
    # Reshape + Transpose
    # [..., M, K] -> [B, M, K]
    # [..., K, M] -> [B, M, K]
    lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N")
    rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T")
    # _shape_normalization ensures contracting_dims=2 and batch_dims=0
    dim_nums = (((2,), (2,)), ((0,), (0,)))
    out_3d = jax.lax.dot_general(
        lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
    )
    return out_3d
def _jax_gemm_delayed_scaling_fp8(
    lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
    """FP8 GEMM for XLA pattern match

    Dequantizes delayed-scaling ScaledTensors and runs the matmul under jit so
    XLA can rewrite the QDQ + dot pattern into a native FP8 GEMM.
    """
    assert (
        rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
    ), "rhs does not have delayed tensor scaling mode"
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
    # For "T" (transposed) storage the given contracting dims refer to the
    # logical tensor; remap them onto the stored (transposed) layout.
    if lhs.layout == "T":
        lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
    if rhs.layout == "T":
        rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
    lhs_dn = (lhs_contract, lhs_batch)
    rhs_dn = (rhs_contract, rhs_batch)
    # Non-contracting dims determine the output shape
    lhs_remain_shape = _calculate_remaining_shape(lhs.data.shape, lhs_contract)
    rhs_remain_shape = _calculate_remaining_shape(rhs.data.shape, rhs_contract)
    # Precision selected from the FP8 accumulation config flag
    precision = (
        jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
    )
    out_3d = __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
    # Reshape [B, M, N] -> [..., M, N]
    out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
    return out
def _jax_gemm_mxfp8_1d(
    lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
    """
    JAX GEMM for MXFP8 via scaled_matmul
    """
    assert (
        rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING
    ), "rhs does not have MXFP8 1D scaling mode"
    # Deferred import: private JAX API, only needed on the MXFP8 path
    from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
    # Colwise quantization is expected exactly when the last contracting dim
    # is not the innermost axis
    expected_lhs_is_colwise = lhs_contract[-1] != lhs.data.ndim - 1
    expected_rhs_is_colwise = rhs_contract[-1] != rhs.data.ndim - 1
    assert lhs.is_colwise is expected_lhs_is_colwise, (
        f"LHS with unexpected quantize dimension.\nExpect is_colwise={expected_lhs_is_colwise}, got"
        f" {lhs.is_colwise}"
    )
    assert rhs.is_colwise is expected_rhs_is_colwise, (
        f"RHS with unexpected quantize dimension.\nExpect is_colwise={expected_rhs_is_colwise}, got"
        f" {rhs.is_colwise}"
    )
    # Reshape + Transpose (if needed)
    # [..., M, K] -> [1, reduce(..., M), K]
    # [..., K, M] -> [1, reduce(..., M), K]
    lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch))
    rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch))
    lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch))
    rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch))
    # Slice out the padding as scaled_matmul does not support padded scales yet
    # (one scale per 32-element block along K)
    lhs_scale_3d = jnp.asarray(lhs_scale_3d[:, : lhs_3d.shape[1], : int(lhs_3d.shape[2] / 32)])
    rhs_scale_3d = jnp.asarray(rhs_scale_3d[:, : rhs_3d.shape[1], : int(rhs_3d.shape[2] / 32)])
    # JAX scaled_matmul only supports NT now (TN-gemm)
    # * Expected shape:
    # * lhs_data (B, M, K) * rhs_data (B, N, K)
    # * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block)
    out_3d = scaled_matmul_wrapper(
        lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype
    )
    # Reshape [1, reduce(..., M), N] -> [..., M, N]
    lhs_remain_shape = tuple(
        lhs.data.shape[dim] for dim in range(len(lhs.data.shape)) if dim not in lhs_contract
    )
    rhs_remain_shape = tuple(
        rhs.data.shape[dim] for dim in range(len(rhs.data.shape)) if dim not in rhs_contract
    )
    out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
    return out
def _jax_gemm(
    lhs: Union[jnp.ndarray, ScaledTensor],
    rhs: Union[jnp.ndarray, ScaledTensor],
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
) -> jnp.ndarray:
    """
    FP8 GEMM via JAX

    Dispatches on operand types:
      * both ScaledTensor                 -> scaling-mode specific FP8 impl
      * both jnp.ndarray + quantizer_set  -> quantize in JAX, then FP8 impl
      * both jnp.ndarray, noop quantizers -> plain jax.lax.dot_general
    Mixing a ScaledTensor with a plain array is not supported.
    """
    dim_nums = (contracting_dims, ((), ()))
    def _jax_gemm_fp8_impl(lhs, rhs):
        # Dispatch on the lhs scaling mode (rhs is validated inside each impl)
        if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums)
        if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
            return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
        # Fixed: was a plain string literal, so the mode was never interpolated
        raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}")
    if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
        return _jax_gemm_fp8_impl(lhs, rhs)
    if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor):
        if quantizer_set != noop_quantizer_set:
            assert type(quantizer_set.x) is type(quantizer_set.kernel)
            (((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums
            # Rowwise quantization when contracting over the innermost axis
            lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
            rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
            # Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm)
            lhs_q = quantizer_set.x.quantize(
                lhs,
                is_rowwise=lhs_is_rowwise,
                is_colwise=not lhs_is_rowwise,
            )
            rhs_q = quantizer_set.kernel.quantize(
                rhs,
                is_rowwise=rhs_is_rowwise,
                is_colwise=not rhs_is_rowwise,
            )
            return _jax_gemm_fp8_impl(lhs_q, rhs_q)
    if (
        isinstance(lhs, jnp.ndarray)
        and isinstance(rhs, jnp.ndarray)
        and quantizer_set == noop_quantizer_set
    ):
        return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype)
    raise NotImplementedError("Not supporting multiplication of ScaledTensor and jnp.array")
def gemm(
    lhs: Union[jnp.ndarray, ScaledTensor],
    rhs: Union[jnp.ndarray, ScaledTensor],
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
) -> jnp.ndarray:
    """General matrix multiplication with optional quantization.

    Thin public wrapper around the JAX GEMM implementation.

    Args:
        lhs: First input matrix.
        rhs: Second input matrix.
        contracting_dims: Pair of sequences giving the contracting dimensions
            of ``lhs`` and ``rhs`` respectively.
        quantizer_set: Set of quantizers used to quantize the inputs for an
            FP8 GEMM; the default no-op set performs an unquantized matmul.

    Returns:
        The matrix multiplication result of shape (M, N); a regular array for
        the no-op quantizer set, otherwise the FP8-path result in the
        quantizers' dequantize dtype.
    """
    return _jax_gemm(
        lhs=lhs,
        rhs=rhs,
        contracting_dims=contracting_dims,
        quantizer_set=quantizer_set,
    )
def swizzled_scale(scales):
    """Swizzle a 2D block-scale tensor into the interleaved layout for FP8 GEMM.

    Requires rows to be a multiple of 128 and cols a multiple of 4; returns a
    5D tensor (callers flatten it before use).
    """
    assert scales.ndim == 2
    num_rows, num_cols = scales.shape
    blocked = scales.reshape(num_rows // 128, 4, 32, num_cols // 4, 4)
    return jnp.transpose(blocked, (0, 3, 2, 1, 4))
def grouped_gemm(
    lhs_list: List[Union[jnp.ndarray, ScaledTensor]],
    rhs_list: List[Union[jnp.ndarray, ScaledTensor]],
    contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]],
    bias_list: List[jnp.ndarray] = None,
) -> List[jnp.ndarray]:
    """Grouped GEMM for multiple pairs of tensors.

    Each pair ``(lhs_list[i], rhs_list[i])`` is contracted according to
    ``contracting_dims_list[i]``; all problems are flattened, concatenated
    into single contiguous buffers, and dispatched in one grouped cuBLAS call.

    Fixes vs. previous revision: removed a duplicated ``num_gemms`` assignment,
    the "Unsupported ScalingMode" message now interpolates the mode (missing
    ``f`` prefix), and ``bias_list`` length is validated.

    Args:
        lhs_list: Left operands (all jnp.ndarray, or all ScaledTensor).
        rhs_list: Right operands, same length as ``lhs_list``.
        contracting_dims_list: Per-pair contracting dimensions.
        bias_list: Optional per-pair bias tensors.

    Returns:
        List of output tensors, one per input pair.
    """
    assert (
        len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
    ), "lhs_list, rhs_list, contracting_dims_list must have the same length"
    num_gemms = len(lhs_list)
    if bias_list is not None:
        assert len(bias_list) == num_gemms, "bias_list must have the same length as lhs_list"
    # Flatten inputs and save their shapes
    out_flat_size = 0
    dims = []
    lhs_contig_ = []
    rhs_contig_ = []
    lhs_scale_inv_contig_ = []
    rhs_scale_inv_contig_ = []
    bias_contig_ = []
    out_offsets = []
    remain_shape_list = []
    for i in range(num_gemms):
        lhs = lhs_list[i]
        rhs = rhs_list[i]
        contracting_dims = contracting_dims_list[i]
        dim_nums = (contracting_dims, ((), ()))
        if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
            scaling_mode = lhs.scaling_mode
            lhs_shape = lhs.data.shape
            rhs_shape = rhs.data.shape
            out_dtype = lhs.dq_dtype
            # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout
            if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
                assert not (
                    lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
                ), "FP8 GEMM does not support E5M2 * E5M2"
                ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
                # Data is stored transposed for "T" layout; remap the
                # contracting dim onto the stored layout.
                if lhs.layout == "T":
                    lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
                if rhs.layout == "T":
                    rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
                dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
        else:
            # For jnp.ndarray, only consider contracting_dims, layout is always NN
            scaling_mode = ScalingMode.NVTE_NO_SCALING
            lhs_shape = lhs.shape
            rhs_shape = rhs.shape
            out_dtype = lhs.dtype
        (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
        lhs_dn = (lhs_contract, lhs_batch)
        rhs_dn = (rhs_contract, rhs_batch)
        lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
        rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
        # Normalize operands (and MXFP8 scales) to 3D [B, M, K] form
        if scaling_mode == ScalingMode.NVTE_NO_SCALING:
            lhs_3d = _shape_normalization(lhs, lhs_dn)
            rhs_3d = _shape_normalization(rhs, rhs_dn)
        elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N")
            rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T")
        elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
            lhs_3d = _shape_normalization(lhs.data, lhs_dn)
            rhs_3d = _shape_normalization(rhs.data, rhs_dn)
            lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
            rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
            # cuBLAS expects the interleaved (swizzled) MXFP8 scale layout
            lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
            rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
        else:
            raise NotImplementedError(f"Unsupported ScalingMode: {scaling_mode}")
        # Note: if _shape_normalization() is updated to support non-TN, need to update here
        # already_transposed doesn't matter for the output shape
        # x.shape = [B, D1, D2]
        # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
        # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
        # x.shape = [D1, D2]
        # contracting_dims = (1, ) --> output.shape = [1, D1, D2]
        # contracting_dims = (0, ) --> output.shape = [1, D2, D1]
        bm = lhs_remain_shape[0]
        bn = rhs_remain_shape[0]
        kl = lhs_3d.shape[-1]
        kr = rhs_3d.shape[-1]
        remain_shape_list.append(((bm,), (bn,)))
        assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})"
        k = kl
        if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0):
            print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ")
            print(
                f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples"
                " of 16"
            )
            assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0
        dims.append((bm, bn, k))
        lhs_contig_.append(lhs_3d.reshape(-1))
        rhs_contig_.append(rhs_3d.reshape(-1))
        if scaling_mode == ScalingMode.NVTE_NO_SCALING:
            # Placeholder unit scales keep the FFI call signature uniform
            lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
            rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
        if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
            rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
            lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1))
            rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1))
        if bias_list is not None:
            bias_contig_.append(bias_list[i].reshape(-1))
        out_flat_size += bm * bn
        out_offsets.append(out_flat_size)
    lhs_contig = jnp.concatenate(lhs_contig_)
    rhs_contig = jnp.concatenate(rhs_contig_)
    lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_)
    rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_)
    bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_)
    dim_list = jnp.array(dims, dtype=jnp.int32)
    # Perform batched GEMM on flattened inputs.
    # NOTE(review): scaling_mode/out_dtype come from the last loop iteration —
    # all pairs are assumed to share the same scaling mode and output dtype.
    out_contig = GroupedGemmPrimitive.outer_primitive.bind(
        lhs_contig,
        lhs_scale_inv_contig,
        rhs_contig,
        rhs_scale_inv_contig,
        bias_contig,
        dim_list,
        num_gemms=num_gemms,
        scaling_mode=scaling_mode,
        out_dtype=out_dtype,
        out_flat_size=out_flat_size,
    )
    # Split the output back into tensors
    out_offsets = jnp.array(out_offsets)
    out_flat_list = jnp.split(out_contig, out_offsets[:-1])
    out_tensors = []
    for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list):
        out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape))
    return out_tensors
...@@ -11,14 +11,17 @@ from packaging.version import Version as PkgVersion ...@@ -11,14 +11,17 @@ from packaging.version import Version as PkgVersion
import numpy as np import numpy as np
import jax.numpy as jnp import jax
from jax import dtypes from jax import dtypes
import jax.numpy as jnp
from jax.interpreters.mlir import dtype_to_ir_type from jax.interpreters.mlir import dtype_to_ir_type
from transformer_engine_jax import DType as TEDType
import transformer_engine_jax import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis
TEDType = transformer_engine_jax.DType
def te_dtype_to_jax_dtype(te_dtype): def te_dtype_to_jax_dtype(te_dtype):
...@@ -104,7 +107,7 @@ def normalize_axis_boundary(axis, ndim): ...@@ -104,7 +107,7 @@ def normalize_axis_boundary(axis, ndim):
return axis if axis >= 0 else ndim + axis return axis if axis >= 0 else ndim + axis
def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary): def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1):
""" """
te_cast_transpose_p multi-dims transpose te_cast_transpose_p multi-dims transpose
...@@ -158,17 +161,6 @@ def jax_version_meet_requirement(version: str): ...@@ -158,17 +161,6 @@ def jax_version_meet_requirement(version: str):
return jax_version >= jax_version_required return jax_version >= jax_version_required
def is_ffi_enabled():
"""
Helper function checking if XLA Custom Call with FFI is enabled
"""
is_supported = jax_version_meet_requirement("0.4.35")
# New APIs with FFI are enabled by default
is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
return is_supported and is_enabled
def get_xla_flag(flag: str, default=None, cast=str): def get_xla_flag(flag: str, default=None, cast=str):
""" """
Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value. Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value.
...@@ -189,3 +181,86 @@ def get_xla_flag(flag: str, default=None, cast=str): ...@@ -189,3 +181,86 @@ def get_xla_flag(flag: str, default=None, cast=str):
if name == flag: if name == flag:
return True return True
return default return default
def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None):
"""
Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
calculate dbias separately. This function checks if the workaround should be applied.
"""
arch_l_100 = False
for local_gpu_id in range(len(jax.local_devices())):
if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100:
arch_l_100 = True
break
return (
quantizer is not None
and quantizer.q_axis == QuantizeAxis.ROWWISE
and arch_l_100
and is_dbias
)
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
"""
Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling.
It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result.
If 'f' returns a tuple, the first output must be the only ScaledTensor output.
@param f: function to call
@param args: positional arguments to pass to 'f'
@param quantizer: quantizer to use
@param kwargs: keyword arguments to pass to 'f'
@return: the output of 'f' with the colwise output calculated
"""
should_apply_war = (
quantizer is not None
and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
and quantizer.is_2x2x()
)
if not should_apply_war:
return None
# 2x is not supported by TE kernels for delayed scaling
# so revert to 1x and transpose in JAX
quantizer.q_axis = QuantizeAxis.ROWWISE
rowwise = f(*args, **kwargs, quantizer=quantizer)
other_outputs = None
if isinstance(rowwise, tuple):
other_outputs = rowwise[1:]
rowwise = rowwise[0]
quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE
colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1)))
output_2x = ScaledTensorFactory.create(
data=rowwise.data,
scale_inv=rowwise.scale_inv,
colwise_data=colwise_data,
colwise_scale_inv=rowwise.scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=rowwise.dq_dtype,
q_axis=QuantizeAxis.ROWWISE_COLWISE,
layout=quantizer.get_layout(),
)
if other_outputs is not None:
return (output_2x,) + other_outputs
return output_2x
class NamedSharding(jax.sharding.NamedSharding):
"""
Wrapper around jax.sharding.NamedSharding that adds a string description field as metadata for easier debugging.
"""
def __init__(self, *args, desc: str = None, **kwargs):
super().__init__(*args, **kwargs)
self.desc = desc
def __repr__(self):
return f"NamedSharding({self.mesh}, {self.spec}, desc={self.desc})"
def duplicate_with_new_description(self, desc: str):
"""
Create a new NamedSharding with the same mesh and spec but with a new description.
"""
return NamedSharding(self.mesh, self.spec, desc=desc)
...@@ -2,33 +2,38 @@ ...@@ -2,33 +2,38 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom ops for normalization""" """JAX/TE custom ops for normalization"""
import operator
import os import os
import warnings import warnings
from functools import partial, reduce, cache import operator
from functools import partial, cache, reduce
from typing import Optional, Union
from packaging import version from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters import mlir
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec
import transformer_engine_jax import transformer_engine_jax
from transformer_engine_jax import NVTE_Norm_Type
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import ( from .misc import (
get_padded_spec, get_padded_spec,
check_valid_batch_dims, check_valid_batch_dims,
jax_dtype_to_te_dtype, jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype, te_dtype_to_jax_dtype,
is_ffi_enabled, NamedSharding,
) )
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
Quantizer,
QuantizeAxis,
DelayedScaleQuantizer,
ScalingMode,
)
if version.parse(jax.__version__) >= version.parse("0.5.0"): if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports from jax import ffi # pylint: disable=ungrouped-imports
...@@ -41,8 +46,8 @@ __all__ = [ ...@@ -41,8 +46,8 @@ __all__ = [
"layernorm_bwd", "layernorm_bwd",
"rmsnorm_fwd", "rmsnorm_fwd",
"rmsnorm_bwd", "rmsnorm_bwd",
"layernorm_fwd_fp8", "normalization_fwd",
"rmsnorm_fwd_fp8", "normalization_bwd",
] ]
...@@ -58,325 +63,520 @@ def get_backward_sm_margin(): ...@@ -58,325 +63,520 @@ def get_backward_sm_margin():
return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
class LayerNormFwdPrimitive(BasePrimitive): class NormFwdPrimitive(BasePrimitive):
""" """
Layer Normalization Forward Primitive Layer Normalization Forward FP8 Primitive
""" """
name = "te_layernorm_forward" name = "te_norm_forward_ffi"
multiple_results = True multiple_results = True
impl_static_args = (3, 4) # zero_centered_gamma, epsilon impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract(x_aval, gamma_aval, beta_aval, **kwargs): def abstract(
x_aval,
scale_aval,
gamma_aval,
beta_aval,
*,
norm_type,
zero_centered_gamma,
epsilon,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
""" """
LayerNorm fwd inner primitive abstract LayerNorm fwd inner primitive abstract
""" """
del scale_shapes
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
mu_rsigama_dtype = jnp.float32 mu_rsigama_dtype = jnp.float32
out_aval = x_aval if norm_type == NVTE_Norm_Type.LayerNorm:
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
assert gamma_aval.size == beta_aval.size assert gamma_aval.size == beta_aval.size
hidden_size = gamma_aval.size
assert x_aval.size % hidden_size == 0
(wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size x_aval.size // gamma_aval.size, # batch size
hidden_size, gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype jax_dtype_to_te_dtype(x_aval.dtype), # itype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16) jax_dtype_to_te_dtype(out_dtype),
True, norm_type,
kwargs["zero_centered_gamma"], scaling_mode.value,
kwargs["epsilon"], zero_centered_gamma,
epsilon,
get_forward_sm_margin(), get_forward_sm_margin(),
is_2x,
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_aval = mu_aval.update(shape=(1,))
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x_aval.shape, is_padded=not is_outer
) )
wkspace_aval = out_aval.update(
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
colwise_out_aval = jax.core.ShapedArray(
shape=x_aval.shape if is_2x else (1,), dtype=out_dtype
)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
) )
return out_aval, mu_aval, rsigma_aval, wkspace_aval return (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
wkspace_aval,
)
@staticmethod @staticmethod
def outer_abstract(*args, **kwargs): def outer_abstract(*args, **kwargs):
""" """
LayerNorm fwd outer primitive abstract LayerNorm fwd outer primitive abstract
""" """
out_aval, mu_aval, rsigma_aval, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs) (
return out_aval, mu_aval, rsigma_aval out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
_,
) = NormFwdPrimitive.abstract(*args, **kwargs)
return (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
)
@staticmethod @staticmethod
def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): def lowering(
ctx,
x,
scale,
gamma,
beta,
*,
norm_type,
zero_centered_gamma,
epsilon,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
""" """
LayerNorm fwd lowering rules LayerNorm fwd lowering rules
""" """
x_aval, gamma_aval, beta_aval = ctx.avals_in del out_dtype, scale_dtype, scale_shapes, is_outer
assert gamma_aval.dtype == beta_aval.dtype x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
g_type = ir.RankedTensorType(gamma.type) g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape g_shape = g_type.shape
if norm_type == NVTE_Norm_Type.LayerNorm:
assert gamma_aval.dtype == beta_aval.dtype
b_type = ir.RankedTensorType(beta.type) b_type = ir.RankedTensorType(beta.type)
b_shape = b_type.shape b_shape = b_type.shape
assert g_type == b_type assert g_type == b_type
assert g_shape == b_shape assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_ffi"
sm_margin = get_forward_sm_margin() sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name)( return ffi.ffi_lowering(NormFwdPrimitive.name)(
ctx, ctx,
x, x,
scale,
gamma, gamma,
beta, beta,
norm_type=norm_type.value,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
eps=epsilon, epsilon=epsilon,
sm_margin=sm_margin, sm_margin=sm_margin,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
) )
else:
# Output shape is same as the input shape, but the output type is same as the weight type.
# See ln_api.cpp
output_type = g_type.element_type
ir_mu_dtype = ir.F32Type.get()
ir_rsigma_dtype = ir.F32Type.get()
out_shape = x_shape
hidden_size = reduce(operator.mul, g_shape)
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(out_shape, output_type),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [x, gamma, beta]
operand_shapes = [x_shape, g_shape, b_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = get_forward_sm_margin()
opaque = transformer_engine_jax.pack_norm_descriptor( @staticmethod
batch_size, def impl(
hidden_size, x,
wkspace_aval.size, scale,
jax_dtype_to_te_dtype(x_aval.dtype), gamma,
jax_dtype_to_te_dtype(gamma_aval.dtype), beta,
jax_dtype_to_te_dtype(wkspace_aval.dtype), norm_type,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
sm_margin, out_dtype,
) scaling_mode,
is_2x,
out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False) scale_dtype,
scale_shapes,
return out is_outer,
):
@staticmethod
def impl(x, gamma, beta, zero_centered_gamma, epsilon):
""" """
to describe implementation to describe implementation
""" """
assert LayerNormFwdPrimitive.inner_primitive is not None del is_outer
out, mu, rsigma, _ = LayerNormFwdPrimitive.inner_primitive.bind( assert NormFwdPrimitive.inner_primitive is not None
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon (
) out,
return out, mu, rsigma colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
mu,
rsigma,
_,
) = NormFwdPrimitive.inner_primitive.bind(
x,
scale,
gamma,
beta,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=False,
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x.shape, is_padded=False
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv = scale_inv.flatten()[
: reduce(operator.mul, rowwise_scale_inv_shape)
].reshape(rowwise_scale_inv_shape)
if is_2x:
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape)
].reshape(colwise_scale_inv_shape)
return (
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
mu,
rsigma,
) # Exclude wkspace
@staticmethod @staticmethod
def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon): def batcher(
batched_args,
batch_dims,
*,
norm_type,
zero_centered_gamma,
epsilon,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
""" """
to describe batch rules for vmap to describe batch rules for vmap
""" """
del is_outer
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert LayerNormFwdPrimitive.outer_primitive is not None assert NormFwdPrimitive.outer_primitive is not None
x, gamma, beta = batched_args x, scale, gamma, beta = batched_args
x_bdim, _, _ = batch_dims x_bdim, scale_bdim, _, _ = batch_dims
out_bdims = x_bdim, x_bdim, x_bdim out_bdims = (
x_bdim, # rowwise output
scale_bdim, # rowwise scale_inv
x_bdim, # colwise output
scale_bdim, # colwise scale_inv
scale_bdim, # amax
x_bdim, # mu
x_bdim, # rsigma
)
return ( return (
LayerNormFwdPrimitive.outer_primitive.bind( NormFwdPrimitive.outer_primitive.bind(
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon scale,
x,
gamma,
beta,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
), ),
out_bdims, out_bdims,
) )
@staticmethod @staticmethod
def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): def infer_sharding_from_operands(
del zero_centered_gamma, epsilon, result_infos norm_type,
zero_centered_gamma,
epsilon,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
result_infos,
):
del zero_centered_gamma, epsilon, out_dtype, result_infos
del scale_dtype, scale_shapes, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
if x_spec[-1] is not None: if x_spec[-1] is not None:
warnings.warn( warnings.warn(
f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, " "Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance." "and hurt performance."
) )
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) out_sharding = NamedSharding(
return (out_sharding, mu_sharding, rsigma_sharding) mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out"
)
if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
)
else:
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
)
rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma"
)
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu")
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
)
amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
output = (
out_sharding,
colwise_out_sharding,
scale_inv_sharding, # rowwise
scale_inv_sharding, # colwise
amax_sharding,
mu_sharding,
rsigma_sharding,
)
return output
@staticmethod @staticmethod
def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): def partition(
del result_infos norm_type,
x_spec, g_spec, b_spec = map(get_padded_spec, arg_infos) zero_centered_gamma,
epsilon,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
result_infos,
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0])
g_spec = get_padded_spec(arg_infos[2])
b_spec = get_padded_spec(arg_infos[3])
if x_spec[-1] is not None: if x_spec[-1] is not None:
warnings.warn( warnings.warn(
f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, " "Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance." "and hurt performance."
) )
if g_spec[-1] is not None: if g_spec[-1] is not None:
warnings.warn( warnings.warn(
f"{LayerNormFwdPrimitive.name} does not support sharding of parameter gamma " f"{NormFwdPrimitive.name} does not support sharding of parameter gamma "
"Enforcing no sharding of parameters hidden dim! " "Enforcing no sharding of parameters hidden dim! "
) )
if b_spec[-1] is not None: if b_spec[-1] is not None:
warnings.warn( warnings.warn(
f"{LayerNormFwdPrimitive.name} does not support sharding of parameter beta " f"{NormFwdPrimitive.name} does not support sharding of parameter beta "
"Enforcing no sharding of parameters hidden dim! " "Enforcing no sharding of parameters hidden dim! "
) )
x_sharding = NamedSharding(
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x"
g_sharding = NamedSharding(mesh, PartitionSpec(None))
b_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
arg_shardings = (x_sharding, g_sharding, b_sharding)
out_shardings = (out_sharding, mu_sharding, rsigma_sharding)
impl = partial(
LayerNormFwdPrimitive.impl, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
) )
return mesh, impl, out_shardings, arg_shardings g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma")
b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta")
out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out")
register_primitive(LayerNormFwdPrimitive) if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
def _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps): )
""" else:
JAX native layernorm implementation colwise_out_sharding = NamedSharding(
""" mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
x_ = jnp.asarray(x, jnp.float32) )
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) rsigma_sharding = NamedSharding(
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps) mesh,
if zero_centered_gamma: PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]),
gamma += 1.0 desc="NormFwdPrimitive.rsigma",
return jnp.asarray(normed_input * gamma + beta).astype(x.dtype) )
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu")
if norm_type == NVTE_Norm_Type.RMSNorm:
def _jax_rmsnorm(x, gamma, zero_centered_gamma, eps): mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
"""
JAX native rmsnorm implementation scale_sharding = NamedSharding(
""" mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale"
x_ = jnp.asarray(x, jnp.float32) )
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) scale_inv_sharding = scale_sharding.duplicate_with_new_description(
normed_input = x_ * jax.lax.rsqrt(var + eps) "NormFwdPrimitive.scale_inv"
if zero_centered_gamma: )
gamma += 1.0 amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
return jnp.asarray(normed_input * gamma).astype(x.dtype) if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
def _jax_layernorm_fp8(x, gamma, beta, scale, amax, out_dtype, zero_centered_gamma, eps): )
"""
JAX native layernorm fp8 implementation arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding)
""" out_shardings = (
x_ = jnp.asarray(x, jnp.float32) out_sharding,
mean = jnp.mean(x_, axis=-1, keepdims=True) colwise_out_sharding,
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) scale_inv_sharding, # rowwise
rsigma = jax.lax.rsqrt(var + eps) scale_inv_sharding, # colwise
normed_input = (x_ - mean) * rsigma amax_sharding,
if zero_centered_gamma: mu_sharding,
gamma += 1.0 rsigma_sharding,
output = normed_input * gamma + beta )
casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype)
return casted_output, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1), updated_amax def sharded_impl(x, scale, gamma, beta):
# expect tp and dp giving same shape, or tp being same shape as global
(
local_x,
local_colwise_x,
local_scale_inv,
local_colwise_scale_inv,
local_amax,
local_mu,
local_rsigma,
) = NormFwdPrimitive.impl(
x,
scale,
gamma,
beta,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=True,
)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
return (
local_x,
local_colwise_x,
local_scale_inv,
local_colwise_scale_inv,
global_updated_amax,
local_mu,
local_rsigma,
)
def _jax_rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps): return mesh, sharded_impl, out_shardings, arg_shardings
"""
JAX native rmsnorm fp8 implementation
"""
x_ = jnp.asarray(x, jnp.float32)
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + eps)
normed_input = x_ * rsigma
if zero_centered_gamma:
gamma += 1.0
output = normed_input * gamma
casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype)
return casted_output, jnp.squeeze(rsigma, axis=-1), updated_amax
def layernorm_fwd( register_primitive(NormFwdPrimitive)
x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float
):
"""
Wrapper for TE layernorm fwd
"""
if not LayerNormFwdPrimitive.enabled():
x_ = jnp.asarray(x, jnp.float32)
mu = jnp.mean(x_, axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_ - mu), axis=-1, keepdims=True) + epsilon)
return (
_jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon),
jnp.squeeze(mu, axis=-1),
jnp.squeeze(rsigma, axis=-1),
)
return LayerNormFwdPrimitive.outer_primitive.bind(
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
class LayerNormBwdPrimitive(BasePrimitive): class NormBwdPrimitive(BasePrimitive):
""" """
Layer Normalization Backward Primitive Layer Normalization Backward Primitive
""" """
name = "te_layernorm_backward" name = "te_norm_backward_ffi"
multiple_results = True multiple_results = True
impl_static_args = (5, 6) # zero_centered_gamma, epsilon impl_static_args = (5, 6) # norm_type, zero_centered_gamma
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_centered_gamma):
""" """
Layernorm bwd inner primitive abstract bwd inner primitive abstract
""" """
w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype)
rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)
assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
assert dz_aval.shape == x_aval.shape assert dz_aval.shape == x_aval.shape
if norm_type == NVTE_Norm_Type.LayerNorm:
mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype)
assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1] assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
assert mu_dtype == rsigma_dtype == jnp.float32 assert mu_dtype == rsigma_dtype == jnp.float32
dx_aval = dz_aval dx_aval = dz_aval
dgamma_aval = dbeta_aval = gamma_aval dgamma_aval = dbeta_aval = gamma_aval
if norm_type != NVTE_Norm_Type.LayerNorm:
dbeta_aval = dbeta_aval.update(shape=(1,))
(wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( (wkspace_info,) = transformer_engine_jax.get_norm_bwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
True, norm_type,
kwargs["zero_centered_gamma"], zero_centered_gamma,
kwargs["epsilon"],
get_backward_sm_margin(), get_backward_sm_margin(),
) )
wkspace_aval = dx_aval.update( wkspace_aval = dx_aval.update(
...@@ -395,17 +595,14 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -395,17 +595,14 @@ class LayerNormBwdPrimitive(BasePrimitive):
""" """
LayerNorm bwd outer primitive abstract LayerNorm bwd outer primitive abstract
""" """
dx_aval, dgamma_aval, dbeta_aval, _ = LayerNormBwdPrimitive.abstract(*args, **kwargs) dx_aval, dgamma_aval, dbeta_aval, _ = NormBwdPrimitive.abstract(*args, **kwargs)
return dx_aval, dgamma_aval, dbeta_aval return dx_aval, dgamma_aval, dbeta_aval
@staticmethod @staticmethod
def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma):
""" """
Layernorm bwd lowering rules bwd lowering rules
""" """
_, x_aval, _, _, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
g_type = ir.RankedTensorType(gamma.type) g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape g_shape = g_type.shape
b_type = ir.RankedTensorType(gamma.type) b_type = ir.RankedTensorType(gamma.type)
...@@ -413,1124 +610,644 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -413,1124 +610,644 @@ class LayerNormBwdPrimitive(BasePrimitive):
assert g_type == b_type assert g_type == b_type
assert g_shape == b_shape assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_backward_ffi"
sm_margin = get_backward_sm_margin() sm_margin = get_backward_sm_margin()
out = ffi.ffi_lowering(name)( return ffi.ffi_lowering(NormBwdPrimitive.name)(
ctx, ctx,
dz, dz,
x, x,
mu, mu,
rsigma, rsigma,
gamma, gamma,
norm_type=norm_type.value,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin, sm_margin=sm_margin,
) )
else:
dz_shape = ir.RankedTensorType(dz.type).shape
mu_shape = ir.RankedTensorType(mu.type).shape
rsigma_shape = ir.RankedTensorType(rsigma.type).shape
hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
operands = [dz, mu, rsigma, x, gamma]
operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = get_backward_sm_margin()
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
zero_centered_gamma,
epsilon,
sm_margin,
)
out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False)
return out
@staticmethod @staticmethod
def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon): def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma):
assert LayerNormBwdPrimitive.inner_primitive is not None assert NormBwdPrimitive.inner_primitive is not None
dx, dgamma, dbeta, _ = LayerNormBwdPrimitive.inner_primitive.bind( dx, dgamma, dbeta, _ = NormBwdPrimitive.inner_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon dz, x, mu, rsigma, gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma
) )
return dx, dgamma, dbeta return dx, dgamma, dbeta
@staticmethod @staticmethod
def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon): def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma):
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert LayerNormBwdPrimitive.outer_primitive is not None assert NormBwdPrimitive.outer_primitive is not None
dz, x, mu, rsigma, gamma = batched_args dz, x, mu, rsigma, gamma = batched_args
_, x_bdim, _, _, gamma_bdim = batch_dims _, x_bdim, _, _, gamma_bdim = batch_dims
out_bdims = x_bdim, gamma_bdim, gamma_bdim out_bdims = x_bdim, gamma_bdim, gamma_bdim
return ( return (
LayerNormBwdPrimitive.outer_primitive.bind( NormBwdPrimitive.outer_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon dz,
x,
mu,
rsigma,
gamma,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
), ),
out_bdims, out_bdims,
) )
@staticmethod @staticmethod
def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): def infer_sharding_from_operands(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos):
del zero_centered_gamma, epsilon, result_infos del norm_type, zero_centered_gamma, result_infos
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None: if x_spec[-1] is not None:
warnings.warn( warnings.warn(
f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " f"Does not support to shard hidden dim in {NormBwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, " "Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance." "and hurt performance."
) )
g_b_spec = get_padded_spec(arg_infos[4]) g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None: if g_b_spec[-1] is not None:
warnings.warn( warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " f"{NormBwdPrimitive.name} does not support sharding of gradients "
"of gamma and beta of Layernorm " "of gamma and beta of "
"Enforcing no sharding of parameters hidden dim! " "Enforcing no sharding of parameters hidden dim! "
) )
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) dx_sharding = NamedSharding(
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None)) mesh, PartitionSpec(*x_spec[:-1], None), desc="NormBwdPrimitive.dx"
)
dgamma_sharding = dbeta_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormBwdPrimitive.dgamma"
)
return dx_sharding, dgamma_sharding, dbeta_sharding return dx_sharding, dgamma_sharding, dbeta_sharding
@staticmethod @staticmethod
def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): def partition(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos):
del result_infos del result_infos
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None: if x_spec[-1] is not None:
warnings.warn( warnings.warn(
f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " f"Does not support to shard hidden dim in {NormBwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, " "Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance." "and hurt performance."
) )
g_b_spec = get_padded_spec(arg_infos[4]) g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None: if g_b_spec[-1] is not None:
warnings.warn( warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " f"{NormBwdPrimitive.name} does not support sharding of gradients "
"of gamma and beta of Layernorm " "of gamma and beta of "
"Enforcing no sharding of parameters hidden dim! " "Enforcing no sharding of parameters hidden dim! "
) )
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) dx_sharding = NamedSharding(
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None)) mesh, PartitionSpec(*x_spec[:-1], None), desc="NormBwdPrimitive.dx"
)
dgamma_sharding = dbeta_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormBwdPrimitive.dgamma"
)
out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding. x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding.
mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2
arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(None))) rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1]), desc="NormBwdPrimitive.rsigma"
)
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormBwdPrimitive.mu")
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormBwdPrimitive.mu")
arg_shardings = (
*x_shardings,
mu_sharding,
rsigma_sharding,
NamedSharding(mesh, PartitionSpec(None), desc="NormBwdPrimitive.gamma"),
)
def sharded_impl(dz, x, mu, rsigma, gamma): def sharded_impl(dz, x, mu, rsigma, gamma):
local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl( local_dx, local_dgamma, local_dbeta = NormBwdPrimitive.impl(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon dz,
x,
mu,
rsigma,
gamma,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
) )
global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh) global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
if norm_type == NVTE_Norm_Type.LayerNorm:
global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh) global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh)
else:
global_dbeta = local_dbeta
return local_dx, global_dgamma, global_dbeta return local_dx, global_dgamma, global_dbeta
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(LayerNormBwdPrimitive) register_primitive(NormBwdPrimitive)
def layernorm_bwd( def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None):
dz: jnp.ndarray,
x: jnp.ndarray,
mu: jnp.ndarray,
rsigma: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
zero_centered_gamma: bool,
epsilon: float,
):
""" """
Wrapper for TE layernorm bwd JAX native layernorm implementation
""" """
if not LayerNormBwdPrimitive.enabled(): x_ = jnp.asarray(x, jnp.float32)
_, vjp_func = jax.vjp( mean = jnp.mean(x_, axis=-1, keepdims=True)
partial(_jax_layernorm, zero_centered_gamma=zero_centered_gamma, eps=epsilon), var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
x, rsigma = jax.lax.rsqrt(var + epsilon)
gamma, normed_input = (x_ - mean) * rsigma
beta, if zero_centered_gamma:
) gamma += 1.0
return vjp_func(dz) output = normed_input * gamma + beta
return LayerNormBwdPrimitive.outer_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
if quantizer:
ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
else:
ln_out = jnp.asarray(output).astype(x.dtype)
class RmsNormFwdPrimitive(BasePrimitive): return ln_out, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1)
"""
RMS Normalization Forward Primitive
"""
name = "te_rmsnorm_forward"
multiple_results = True
impl_static_args = (2,) # epsilon
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(x_aval, gamma_aval, **kwargs):
"""
RMSNorm fwd inner primitive abstract
"""
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
rsigama_dtype = jnp.float32
out_aval = x_aval
rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
hidden_size = gamma_aval.size
assert x_aval.size % hidden_size == 0
(wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
False,
False,
kwargs["epsilon"],
get_forward_sm_margin(),
)
wkspace_aval = out_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return out_aval, rsigma_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm fwd outer primitive abstract
"""
out_aval, rsigma_aval, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs)
return out_aval, rsigma_aval
@staticmethod
def lowering(ctx, x, gamma, *, epsilon):
"""
RMSNorm fwd lowering rules
"""
if is_ffi_enabled():
name = "te_rmsnorm_forward_ffi"
sm_margin = get_forward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
x,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
x_aval, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape
rsigma_element_type = ir.F32Type.get()
out_shape = x_shape
hidden_size = reduce(operator.mul, g_shape)
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(out_shape, x_type.element_type),
ir.RankedTensorType.get(batch_shape, rsigma_element_type),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [x, gamma]
operand_shapes = [x_shape, g_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = get_forward_sm_margin()
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
)
out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)
return out
@staticmethod
def impl(x, gamma, epsilon):
"""
to describe implementation
"""
assert RmsNormFwdPrimitive.inner_primitive is not None
out, rsigma, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
return out, rsigma
@staticmethod
def batcher(batched_args, batch_dims, *, epsilon):
"""
to describe batch rules for vmap
"""
check_valid_batch_dims(batch_dims)
assert RmsNormFwdPrimitive.outer_primitive is not None
x, gamma = batched_args
x_bdim, _ = batch_dims
out_bdims = x_bdim, x_bdim
return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon), out_bdims
@staticmethod
def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
del epsilon, result_infos
x_spec = get_padded_spec(arg_infos[0])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance."
)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
return (out_sharding, rsigma_sharding)
@staticmethod
def partition(epsilon, mesh, arg_infos, result_infos):
del result_infos
x_spec, g_spec = map(get_padded_spec, arg_infos)
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormFwdPrimitive.name} does not support sharding of parameter gamma "
"Enforcing no sharding of parameters hidden dim! "
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
arg_shardings = (x_sharding, g_sharding)
out_shardings = (out_sharding, rsigma_sharding)
impl = partial(RmsNormFwdPrimitive.impl, epsilon=epsilon)
return mesh, impl, out_shardings, arg_shardings
register_primitive(RmsNormFwdPrimitive)
def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
"""
Wrapper for TE rmsnorm fwd
"""
if not RmsNormFwdPrimitive.enabled():
x_ = jnp.asarray(x, jnp.float32)
rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + epsilon)
return _jax_rmsnorm(x, gamma, zero_centered_gamma=False, eps=epsilon), jnp.squeeze(
rsigma, axis=-1
)
return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon)
class RmsNormBwdPrimitive(BasePrimitive):
"""
RMS Normalization Backward Primitive
"""
name = "te_rmsnorm_backward"
multiple_results = True
impl_static_args = (4,) # epsilon
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs):
"""
RMSNorm bwd inner primitive abstract
"""
w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)
assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
assert dz_aval.shape == x_aval.shape
assert rsigma_aval.shape == x_aval.shape[:-1]
assert rsigma_dtype == jnp.float32
dx_aval = dz_aval
dgamma_aval = gamma_aval
(wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
False,
False,
kwargs["epsilon"],
get_backward_sm_margin(),
)
wkspace_aval = dx_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return dx_aval, dgamma_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm bwd outer primitive abstract
"""
dx_aval, dgamma_aval, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs)
return dx_aval, dgamma_aval
@staticmethod
def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
"""
RMSNorm bwd lowering rules
"""
if is_ffi_enabled():
name = "te_rmsnorm_backward_ffi"
sm_margin = get_backward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
dz,
x,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
_, x_aval, _, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape
dz_shape = ir.RankedTensorType(dz.type).shape
rsigma_shape = ir.RankedTensorType(rsigma.type).shape
hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(g_shape, g_type.element_type),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [dz, rsigma, x, gamma]
operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = get_backward_sm_margin()
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
)
out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)
return out
@staticmethod
def impl(dz, x, rsigma, gamma, epsilon):
assert RmsNormBwdPrimitive.inner_primitive is not None
dx, dgamma, _ = RmsNormBwdPrimitive.inner_primitive.bind(
dz, x, rsigma, gamma, epsilon=epsilon
)
return dx, dgamma
@staticmethod
def batcher(batched_args, batch_dims, *, epsilon):
check_valid_batch_dims(batch_dims)
assert RmsNormBwdPrimitive.outer_primitive is not None
dz, x, rsigma, gamma = batched_args
_, x_bdim, _, gamma_bdim = batch_dims
out_bdims = x_bdim, gamma_bdim
return (
RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
del epsilon, result_infos
x_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance."
)
g_spec = get_padded_spec(arg_infos[3])
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma "
"Enforcing no sharding of parameters hidden dim! "
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
return dx_sharding, dgamma_sharding
@staticmethod
def partition(epsilon, mesh, arg_infos, result_infos):
    """
    custom_partitioning rule: returns (mesh, sharded_impl, out_shardings,
    arg_shardings). Outputs mirror infer_sharding_from_operands; dgamma is
    all-reduced across DP/FSDP axes since each shard holds a partial sum.
    """
    del result_infos
    # arg order is (dz, x, rsigma, gamma); specs taken from x and gamma.
    x_spec = get_padded_spec(arg_infos[1])
    if x_spec[-1] is not None:
        warnings.warn(
            f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! "
            "Force to not shard the hidden dim, which might introduce extra collective ops, "
            "and hurt performance."
        )
    g_spec = get_padded_spec(arg_infos[3])
    if g_spec[-1] is not None:
        warnings.warn(
            f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma "
            "Enforcing no sharding of parameters hidden dim! "
        )
    # dx keeps x's batch sharding, hidden dim forced unsharded; dgamma replicated.
    dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
    dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
    out_shardings = dx_sharding, dgamma_sharding
    x_shardings = (dx_sharding,) * 2  # dz and x should have the same sharding.
    # rsigma has one fewer dim than x (no hidden dim), so drop the last spec entry.
    rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
    arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(None)))
    def sharded_impl(dz, x, rsigma, gamma):
        local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon)
        # Each device computes a partial dgamma over its local batch shard;
        # sum across data-parallel/FSDP axes to obtain the global gradient.
        global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
        return local_dx, global_dgamma
    return mesh, sharded_impl, out_shardings, arg_shardings
# Register the RMSNorm-backward custom primitive with JAX.
register_primitive(RmsNormBwdPrimitive)
def rmsnorm_bwd(
    dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray, epsilon: float
):
    """
    Wrapper for TE rmsnorm bwd; returns the gradients (dx, dgamma).

    When the TE primitive is disabled, falls back to differentiating the
    JAX-native _jax_rmsnorm via jax.vjp (rsigma is recomputed there, so the
    forward-saved rsigma argument is unused on that path).
    """
    if not RmsNormBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_rmsnorm, zero_centered_gamma=False, eps=epsilon), x, gamma
        )
        return vjp_func(dz)
    return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
class LayerNormFwdFp8Primitive(BasePrimitive):
    """
    Layer Normalization Forward FP8 Primitive.

    Fuses LayerNorm with an FP8 cast of the output; also produces the
    per-row mean (mu), reciprocal stddev (rsigma), and an updated amax.
    """
    name = "te_layernorm_forward_fp8"
    multiple_results = True
    # Positions in impl(x, gamma, beta, amax, scale, scale_inv, out_dtype,
    # zero_centered_gamma, epsilon) that are compile-time static.
    impl_static_args = (6, 7, 8)  # out_dtype, zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None
@staticmethod
def abstract(
    x_aval,
    gamma_aval,
    beta_aval,
    amax_aval,
    scale_aval,
    scale_inv_aval,
    *,
    out_dtype,
    zero_centered_gamma,
    epsilon,
):
    """
    LayerNorm fwd (fp8 out) inner primitive abstract.

    Returns avals for (out, mu, rsigma, updated_amax, workspace). The
    workspace shape/dtype is queried from the TE C++ extension so XLA can
    pre-allocate it as an extra output.
    """
    x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
    assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
    # FP8 scaling metadata is always float32.
    assert amax_aval.dtype == jnp.float32
    assert scale_aval.dtype == jnp.float32
    assert scale_inv_aval.dtype == jnp.float32
    mu_rsigama_dtype = jnp.float32
    assert gamma_aval.size == beta_aval.size
    (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
        x_aval.size // gamma_aval.size,  # batch size
        gamma_aval.size,  # hidden size
        jax_dtype_to_te_dtype(x_aval.dtype),  # in type
        jax_dtype_to_te_dtype(gamma_aval.dtype),  # weight type
        jax_dtype_to_te_dtype(out_dtype),
        True,  # presumably the "is LayerNorm" flag — confirm against the C++ binding
        zero_centered_gamma,
        epsilon,
        get_forward_sm_margin(),
    )
    out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
    # mu/rsigma are per-row statistics: same shape as x minus the hidden dim.
    mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
    updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
    # wkspace_info is (shape, te_dtype) reported by the extension.
    wkspace_aval = x_aval.update(
        shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
    )
    return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
    """
    LayerNorm fwd (fp8 out) outer primitive abstract: identical to the
    inner abstract except the trailing workspace aval is dropped.
    """
    return LayerNormFwdFp8Primitive.abstract(*args, **kwargs)[:-1]
@staticmethod
def lowering(
ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_centered_gamma, epsilon
):
"""
LayerNorm fwd (fp8 out) lowering rules
"""
x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
# Currently only support casting to E4M3 only in C side.
assert out_dtype == jnp.float8_e4m3fn
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gamma_aval.dtype == beta_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape
b_type = ir.RankedTensorType(beta.type)
b_shape = b_type.shape
assert g_type == b_type
assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_fp8_ffi"
sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name, operand_output_aliases={3: 3})(
ctx,
x,
gamma,
beta,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_mu_dtype = ir.F32Type.get()
ir_rsigma_dtype = ir.F32Type.get()
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
out_shape = x_shape
hidden_size = reduce(operator.mul, g_shape)
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [
x_shape,
g_shape,
b_shape,
ir_amax_shape,
ir_scale_shape,
ir_scale_inv_shape,
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = get_forward_sm_margin()
opaque = transformer_engine_jax.pack_norm_descriptor( def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
batch_size, """
hidden_size, JAX native rmsnorm implementation
wkspace_aval.size, """
jax_dtype_to_te_dtype(x_aval.dtype), x_ = jnp.asarray(x, jnp.float32)
jax_dtype_to_te_dtype(gamma_aval.dtype), var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
jax_dtype_to_te_dtype(wkspace_aval.dtype), rsigma = jax.lax.rsqrt(var + epsilon)
zero_centered_gamma, normed_input = x_ * rsigma
epsilon, if zero_centered_gamma:
sm_margin, gamma += 1.0
) output = normed_input * gamma
out = custom_caller( if quantizer:
LayerNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={3: 3} ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
) else:
ln_out = jnp.asarray(output).astype(x.dtype)
return out return ln_out, jnp.squeeze(rsigma, axis=-1)
@staticmethod
def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, epsilon): def layernorm_fwd(
""" x: jnp.ndarray,
to describe implementation gamma: jnp.ndarray,
""" beta: jnp.ndarray,
assert LayerNormFwdFp8Primitive.inner_primitive is not None zero_centered_gamma: bool,
out, mu, rsigma, updated_amax, _ = LayerNormFwdFp8Primitive.inner_primitive.bind( epsilon: float,
quantizer: Optional[Quantizer],
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]:
"""Layer normalization forward pass with optional quantization.
Args:
x: Input tensor to be normalized.
Shape: (..., K) where K is the hidden size.
gamma: Scale parameter for normalization.
Shape: (K,)
beta: Bias parameter for normalization.
Shape: (K,)
zero_centered_gamma: If True, gamma is zero-centered.
epsilon: Small constant for numerical stability.
quantizer: Optional quantizer for FP8 quantization of the output.
Returns:
A tuple containing:
- If quantizer is None:
The normalized input tensor. Shape: (..., K)
If quantizer is provided:
A ScaledTensor containing the quantized normalized input.
- Mean of the input tensor. Shape: (..., 1)
- Reciprocal of the standard deviation of the input tensor. Shape: (..., 1)
"""
if not NormFwdPrimitive.enabled():
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
scale = (
quantizer.scale
if isinstance(quantizer, DelayedScaleQuantizer)
else jnp.ones((1,), dtype=jnp.float32)
)
if quantizer is None:
output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind(
x, x,
scale,
gamma, gamma,
beta, beta,
amax, norm_type=NVTE_Norm_Type.LayerNorm,
scale,
scale_inv,
out_dtype=out_dtype,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
) out_dtype=x.dtype,
return out, mu, rsigma, updated_amax scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
is_2x=False,
@staticmethod scale_dtype=jnp.float32,
def batcher(batched_args, batch_dims, *, out_dtype, zero_centered_gamma, epsilon): scale_shapes=((1,), (1,)),
""" is_outer=True,
to describe batch rules for vmap )
""" return output, mu, rsigma
check_valid_batch_dims(batch_dims)
assert LayerNormFwdFp8Primitive.outer_primitive is not None is_2x2x = quantizer.is_2x2x()
x, gamma, beta, amax, scale, scale_inv = batched_args # TE/common normalization doesn't support 2x delayed scaling
x_bdim, _, _, amax_bdim, _, _ = batch_dims if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
is_2x2x = False
out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim (
return ( rowwise_casted_output,
LayerNormFwdFp8Primitive.outer_primitive.bind( colwise_casted_output,
rowwise_scale_inv,
colwise_scale_inv,
updated_amax,
mu,
rsigma,
) = NormFwdPrimitive.outer_primitive.bind(
x, x,
scale,
gamma, gamma,
beta, beta,
amax, norm_type=NVTE_Norm_Type.LayerNorm,
scale,
scale_inv,
out_dtype=out_dtype,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
), out_dtype=quantizer.q_dtype,
out_bdims, scaling_mode=quantizer.scaling_mode,
) is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(),
@staticmethod scale_shapes=quantizer.get_scale_shapes(x.shape),
def infer_sharding_from_operands( is_outer=True,
out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos )
): quantizer.update(updated_amax)
del out_dtype, zero_centered_gamma, epsilon, result_infos
x_spec = get_padded_spec(arg_infos[0]) # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if x_spec[-1] is not None: if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
warnings.warn( colwise_casted_output = jnp.transpose(
f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
"Force to not shard the hidden dim, which might introduce extra collective ops, " )
"and hurt performance." colwise_scale_inv = rowwise_scale_inv
)
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
x.shape, is_padded=False
)
rowwise_scale_inv = rowwise_scale_inv.flatten()[
: reduce(operator.mul, rowwise_unpadded_shape)
].reshape(rowwise_unpadded_shape)
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_unpadded_shape)
].reshape(colwise_unpadded_shape)
scaled_tensor = ScaledTensorFactory.create(
data=rowwise_casted_output,
scale_inv=rowwise_scale_inv,
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
)
return scaled_tensor, mu, rsigma
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
return (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)
@staticmethod def layernorm_bwd(
def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): dz: jnp.ndarray,
del result_infos x: jnp.ndarray,
x_spec = get_padded_spec(arg_infos[0]) mu: jnp.ndarray,
g_spec = get_padded_spec(arg_infos[1]) rsigma: jnp.ndarray,
b_spec = get_padded_spec(arg_infos[2]) gamma: jnp.ndarray,
if x_spec[-1] is not None: beta: jnp.ndarray,
warnings.warn( zero_centered_gamma: bool,
f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " epsilon: float,
"Force to not shard the hidden dim, which might introduce extra collective ops, " ):
"and hurt performance." """Layer normalization backward pass.
)
if g_spec[-1] is not None: Args:
warnings.warn( dz: Gradient of the output with respect to the normalized output.
f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter gamma " Shape: (..., K) where K is the hidden size.
"Enforcing no sharding of parameters hidden dim! " x: Input tensor that was normalized in the forward pass.
) Shape: (..., K)
if b_spec[-1] is not None: mu: Mean of the input tensor from the forward pass.
warnings.warn( Shape: (..., 1)
f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter beta " rsigma: Reciprocal of the standard deviation from the forward pass.
"Enforcing no sharding of parameters hidden dim! " Shape: (..., 1)
) gamma: Scale parameter for normalization.
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) Shape: (K,)
g_sharding = NamedSharding(mesh, PartitionSpec(None)) beta: Bias parameter for normalization.
b_sharding = NamedSharding(mesh, PartitionSpec(None)) Shape: (K,)
out_sharding = x_sharding zero_centered_gamma: If True, gamma is zero-centered.
mu_sharding = rsigma_sharding = NamedSharding( epsilon: Small constant for numerical stability.
mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1])
) Returns:
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3]))) A tuple containing:
fp8_meta_sharding = amax_sharding - Gradient of the input tensor.
arg_shardings = (x_sharding, g_sharding, b_sharding) + (fp8_meta_sharding,) * 3 Shape: (..., K)
out_shardings = (out_sharding, mu_sharding, rsigma_sharding, amax_sharding) - Gradient of the scale parameter (gamma).
Shape: (K,)
def sharded_impl(x, gamma, beta, amax, scale, scale_inv): - Gradient of the bias parameter (beta).
local_x, local_mu, local_rsigma, local_amax = LayerNormFwdFp8Primitive.impl( Shape: (K,)
"""
if not NormBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_layernorm, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon),
x, x,
gamma, gamma,
beta, beta,
amax, )
scale, mu_empty = jnp.zeros(mu.shape, mu.dtype)
scale_inv, rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
out_dtype=out_dtype, return vjp_func((dz, mu_empty, rsigma_empty))
return NormBwdPrimitive.outer_primitive.bind(
dz,
x,
mu,
rsigma,
gamma,
norm_type=NVTE_Norm_Type.LayerNorm,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
) )
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_x, local_mu, local_rsigma, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings def rmsnorm_fwd(
register_primitive(LayerNormFwdFp8Primitive)
def layernorm_fwd_fp8(
x: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
beta: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: jnp.dtype,
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
): quantizer: Optional[Quantizer],
""" ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]:
Wrapper for TE layernorm fwd (fp8 out) """Root mean square normalization forward pass with optional quantization.
"""
if not LayerNormFwdFp8Primitive.enabled(): Args:
return _jax_layernorm_fp8( x: Input tensor to be normalized.
Shape: (..., K) where K is the hidden size.
gamma: Scale parameter for normalization.
Shape: (K,)
zero_centered_gamma: If True, gamma is zero-centered.
epsilon: Small constant for numerical stability.
quantizer: Optional quantizer for FP8 quantization of the output.
Returns:
A tuple containing:
- If quantizer is None:
The normalized input tensor.
Shape: (..., K)
If quantizer is provided:
A ScaledTensor containing the quantized normalized input.
- Reciprocal of the root mean square of the input tensor.
Shape: (..., 1)
"""
if not NormFwdPrimitive.enabled():
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
scale = (
quantizer.scale
if isinstance(quantizer, DelayedScaleQuantizer)
else jnp.ones((1,), dtype=jnp.float32)
)
beta = jnp.ones((1,), dtype=jnp.float32)
if quantizer is None:
output, _, _, _, _, _, rsigma = NormFwdPrimitive.outer_primitive.bind(
x, x,
scale,
gamma, gamma,
beta, beta,
scale, norm_type=NVTE_Norm_Type.RMSNorm,
amax,
out_dtype=out_dtype,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
eps=epsilon, epsilon=epsilon,
) out_dtype=x.dtype,
return LayerNormFwdFp8Primitive.outer_primitive.bind( scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((), ()),
is_outer=True,
)
return output, rsigma
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
is_2x2x = False
(
rowwise_casted_output,
colwise_casted_output,
rowwise_scale_inv,
colwise_scale_inv,
updated_amax,
_,
rsigma,
) = NormFwdPrimitive.outer_primitive.bind(
x, x,
scale,
gamma, gamma,
beta, beta,
amax, norm_type=NVTE_Norm_Type.RMSNorm,
scale,
scale_inv,
out_dtype=out_dtype,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
) out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode,
is_2x=is_2x2x,
class RmsNormFwdFp8Primitive(BasePrimitive): scale_dtype=quantizer.get_scale_dtype(),
""" scale_shapes=quantizer.get_scale_shapes(x.shape),
RMS Normalization Forward FP8 Primitive is_outer=True,
""" )
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
colwise_scale_inv = rowwise_scale_inv
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
x.shape, is_padded=False
)
rowwise_scale_inv = rowwise_scale_inv.flatten()[
: reduce(operator.mul, rowwise_unpadded_shape)
].reshape(rowwise_unpadded_shape)
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_unpadded_shape)
].reshape(colwise_unpadded_shape)
scaled_tensor = ScaledTensorFactory.create(
data=rowwise_casted_output,
scale_inv=rowwise_scale_inv,
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
)
return scaled_tensor, rsigma
name = "te_rmsnorm_forward_fp8"
multiple_results = True
impl_static_args = (5, 6) # out_dtype, epsilon
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon):
"""
RMSNorm fwd (fp8 out) inner primitive abstract
"""
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
hidden_size = gamma_aval.size
assert x_aval.size % hidden_size == 0
rsigama_dtype = jnp.float32
(wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch_size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(out_dtype), # out te_dtype
False,
False,
epsilon,
get_forward_sm_margin(),
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) def rmsnorm_bwd(
rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype) dz: jnp.ndarray,
amax_aval = out_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) x: jnp.ndarray,
wkspace_aval = x_aval.update( rsigma: jnp.ndarray,
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) gamma: jnp.ndarray,
zero_centered_gamma: bool,
epsilon: float,
):
"""Root mean square normalization backward pass.
Args:
dz: Gradient of the output with respect to the normalized output.
Shape: (..., K) where K is the hidden size.
x: Input tensor that was normalized in the forward pass.
Shape: (..., K)
rsigma: Reciprocal of the root mean square from the forward pass.
Shape: (..., 1)
gamma: Scale parameter for normalization.
Shape: (K,)
zero_centered_gamma: If True, gamma is zero-centered.
epsilon: Small constant for numerical stability.
Returns:
A tuple containing:
- Gradient of the input tensor.
Shape: (..., K)
- Gradient of the scale parameter (gamma).
Shape: (K,)
"""
if not NormBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_rmsnorm, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon),
x,
gamma,
) )
rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
return out_aval, rsigma_aval, amax_aval, wkspace_aval return vjp_func((dz, rsigma_empty))
mu = jnp.empty(())
@staticmethod dx, dgamma, _ = NormBwdPrimitive.outer_primitive.bind(
def outer_abstract(*args, **kwargs): dz,
"""
RMSNorm fwd (fp8 out) outer primitive abstract
"""
out_aval, rsigma_aval, amax_aval, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs)
return out_aval, rsigma_aval, amax_aval
@staticmethod
def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
"""
RMSNorm fwd (fp8 out) lowering rules
"""
# Currently only support casting to E4M3 only in C side.
assert out_dtype == jnp.float8_e4m3fn
if is_ffi_enabled():
name = "te_rmsnorm_forward_fp8_ffi"
sm_margin = get_forward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
ctx,
x, x,
mu,
rsigma,
gamma, gamma,
amax, norm_type=NVTE_Norm_Type.RMSNorm,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_rsigma_dtype = ir.F32Type.get()
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
out_shape = x_shape
hidden_size = reduce(operator.mul, g_shape)
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [x, gamma, amax, scale, scale_inv]
operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = get_forward_sm_margin()
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
) )
return (dx, dgamma)
out = custom_caller(
RmsNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={2: 2}
)
return out
@staticmethod
def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon):
"""
to describe implementation
"""
assert RmsNormFwdFp8Primitive.inner_primitive is not None
out, rsigma, amax, _ = RmsNormFwdFp8Primitive.inner_primitive.bind(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
)
return out, rsigma, amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, epsilon):
"""
to describe batch rules for vmap
"""
check_valid_batch_dims(batch_dims)
assert RmsNormFwdFp8Primitive.outer_primitive is not None
x, gamma, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, x_bdim, amax_bdim
return (
RmsNormFwdFp8Primitive.outer_primitive.bind(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_infos):
del out_dtype, epsilon, result_infos
x_spec = get_padded_spec(arg_infos[0])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance."
)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
return (out_sharding, rsigma_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
g_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormFwdFp8Primitive.name} does not support sharding of parameter gamma "
"Enforcing no sharding of parameters hidden dim! "
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
fp8_meta_sharding = amax_sharding
arg_shardings = (x_sharding, g_sharding) + (fp8_meta_sharding,) * 3
out_shardings = (out_sharding, rsigma_sharding, amax_sharding)
def sharded_impl(x, gamma, amax, scale, scale_inv):
local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_x, local_rsigma, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
def normalization_fwd(
x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
zero_centered_gamma: bool,
epsilon: float,
norm_type: str,
quantizer: Optional[Quantizer],
):
"""Common wrapper for normalization forward pass.
Args:
x: Input tensor to be normalized.
Shape: (..., K) where K is the hidden size.
gamma: Scale parameter for normalization.
Shape: (K,)
beta: Bias parameter for normalization.
Shape: (K,)
zero_centered_gamma: If True, gamma is zero-centered.
epsilon: Small constant for numerical stability.
norm_type: Type of normalization to apply. Must be one of:
- 'layernorm': Layer normalization
- 'rmsnorm': Root mean square normalization
quantizer: Optional quantizer for FP8 quantization of the output.
Returns:
A tuple containing:
- If quantizer is None:
The normalized input tensor.
Shape: (..., K)
If quantizer is provided:
A ScaledTensor containing the quantized normalized input.
- Mean of the input tensor (None for RMSNorm).
Shape: (..., 1)
- Reciprocal of the standard deviation (or root mean square for RMSNorm).
Shape: (..., 1)
Note:
zero_centered_gamma is not supported if norm_type is 'rmsnorm'.
"""
if norm_type == "layernorm":
output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
elif norm_type == "rmsnorm":
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
output, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer)
mu = None
else:
raise ValueError(f"{norm_type=} is not supported.")
register_primitive(RmsNormFwdFp8Primitive) return output, mu, rsigma
def rmsnorm_fwd_fp8( def normalization_bwd(
dz: jnp.ndarray,
x: jnp.ndarray, x: jnp.ndarray,
mu: jnp.ndarray,
rsigma: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
amax: jnp.ndarray, beta: jnp.ndarray,
scale: jnp.ndarray, zero_centered_gamma: bool,
scale_inv: jnp.ndarray,
out_dtype: jnp.dtype,
epsilon: float, epsilon: float,
norm_type: str,
): ):
""" """Common wrapper for normalization backward pass.
Wrapper for TE rmsnorm fwd (fp8 out)
""" Args:
if not RmsNormFwdFp8Primitive.enabled(): dz: Gradient of the output with respect to the normalized output.
return _jax_rmsnorm_fp8( Shape: (..., K) where K is the hidden size.
x, gamma, scale, amax, out_dtype=out_dtype, zero_centered_gamma=False, eps=epsilon x: Input tensor that was normalized in the forward pass.
) Shape: (..., K)
return RmsNormFwdFp8Primitive.outer_primitive.bind( mu: Mean of the input tensor from the forward pass (None for RMSNorm).
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon Shape: (..., 1)
) rsigma: Reciprocal of the standard deviation (or root mean square) from the forward pass.
Shape: (..., 1)
gamma: Scale parameter for normalization.
Shape: (K,)
beta: Bias parameter for normalization.
Shape: (K,)
zero_centered_gamma: If True, gamma is zero-centered.
epsilon: Small constant for numerical stability.
norm_type: Type of normalization used in the forward pass. Must be one of:
- 'layernorm': Layer normalization
- 'rmsnorm': Root mean square normalization
Returns:
A tuple containing:
- Gradient of the input tensor.
Shape: (..., K)
- Gradient of the scale parameter (gamma).
Shape: (K,)
- Gradient of the bias parameter (beta) (None for RMSNorm).
Shape: (K,)
Note:
zero_centered_gamma is not supported if norm_type is 'rmsnorm'.
"""
if norm_type == "layernorm":
dx, dgamma, dbeta = layernorm_bwd(
dz, x, mu, rsigma, gamma, beta, zero_centered_gamma, epsilon
)
elif norm_type == "rmsnorm":
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, zero_centered_gamma, epsilon)
dbeta = None
else:
raise ValueError(f"{norm_type=} is not supported.")
return dx, dgamma, dbeta
...@@ -2,28 +2,29 @@ ...@@ -2,28 +2,29 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom ops for quantization""" """JAX/TE custom ops for quantization"""
from typing import Tuple from typing import Tuple, Optional
from packaging import version from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine_jax import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import ( from .misc import (
get_padded_spec, get_padded_spec,
check_valid_batch_dims, check_valid_batch_dims,
te_dtype_to_jax_dtype,
jax_dtype_to_te_dtype, jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype, multidim_transpose,
is_ffi_enabled, should_apply_1x_fused_dbias_war_for_arch_l_100,
NamedSharding,
) )
from ..sharding import all_reduce_max_along_all_axes_except_PP from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeAxis, DelayedScaleQuantizer, ScalingMode
if version.parse(jax.__version__) >= version.parse("0.5.0"): if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports from jax import ffi # pylint: disable=ungrouped-imports
...@@ -31,166 +32,591 @@ else: ...@@ -31,166 +32,591 @@ else:
from jax.extend import ffi # pylint: disable=ungrouped-imports from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["cast_fp8"] __all__ = ["quantize", "quantize_dbias"]
def _jax_quantize(x, scale, q_dtype): class DBiasQuantizePrimitive(BasePrimitive):
""" """
Quantize with scale Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
""" """
compute_dtype = scale.dtype
dtype_max = (jnp.finfo(q_dtype).max).astype(compute_dtype)
scaled_x = x.astype(compute_dtype) * scale
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
return clipped_scaled_x.astype(q_dtype)
name = "te_dbias_quantize_ffi"
def _jax_cast_fp8(inputs, scale, amax, out_dtype):
"""
JAX native fp8 casting implementation
"""
casted_output = _jax_quantize(inputs, scale, q_dtype=out_dtype)
updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))
return casted_output, updated_amax
class CastFP8Primitive(BasePrimitive):
"""
Cast Primitive
"""
name = "te_quantize"
multiple_results = True multiple_results = True
impl_static_args = (4,) impl_static_args = (
2,
3,
4,
5,
6,
7,
8,
) # out_dtype, scaling_mode, q_axis, scale_dtype, scale_shapes, is_dbias, is_outer
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype): def abstract(
x_aval,
scale_aval,
*,
out_dtype,
scaling_mode,
q_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
""" """
te_cast abstract te_dbias_quantize_p abstract
""" """
del scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) rowwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
return casted_x_aval, updated_amax_aval if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
rowwise_out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
@staticmethod updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
""" rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
te_cast lowering rules scaling_mode
""" ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_quantize_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
ctx, x, amax, scale, scale_inv
)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
out_types = [ scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor( colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype) colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
t_shape = multidim_transpose(x_aval.shape)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
# Don't transpose output for MXFP8
t_shape = x_aval.shape
colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
) )
out = custom_caller( if is_dbias:
CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1} gi_hidden_size = x_aval.shape[-1]
dbias_shape = (gi_hidden_size,)
dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
(wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
x_aval.size // gi_hidden_size,
gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
) )
return out return (
rowwise_out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
dbias_aval,
wkspace_aval,
)
@staticmethod @staticmethod
def impl(x, amax, scale, scale_inv, out_dtype): def outer_abstract(*args, **kwargs):
""" """
te_cast implementation te_dbias_quantize_p outer primitive abstract
""" """
assert CastFP8Primitive.inner_primitive is not None (
casted_x, updated_amax = CastFP8Primitive.inner_primitive.bind( out,
x, amax, scale, scale_inv, out_dtype=out_dtype colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
_,
) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
@staticmethod
def lowering(
ctx,
x,
scale,
*,
out_dtype,
scaling_mode,
q_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
"""
te_dbias_quantize_p lowering rules
"""
del out_dtype, scale_dtype, scale_shapes, is_outer
x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == jnp.float32
return ffi.ffi_lowering(DBiasQuantizePrimitive.name)(
ctx,
x,
scale,
scaling_mode=scaling_mode,
q_axis=q_axis,
is_dbias=is_dbias,
) )
return casted_x, updated_amax
@staticmethod @staticmethod
def batcher(batched_args, batch_dims, *, out_dtype): def impl(
check_valid_batch_dims(batch_dims) x,
assert CastFP8Primitive.outer_primitive is not None scale,
out_dtype,
scaling_mode,
q_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
"""
te_dbias_quantize_p implementation
"""
del is_outer
assert DBiasQuantizePrimitive.inner_primitive is not None
(
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
_,
) = DBiasQuantizePrimitive.inner_primitive.bind(
x,
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_axis=q_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
is_outer=False,
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return (
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
) # Exclude wkspace
x, amax, scale, scale_inv = batched_args @staticmethod
x_bdim, amax_bdim, *_ = batch_dims def batcher(
batched_args,
batch_dims,
*,
out_dtype,
scaling_mode,
q_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
"""
to describe batch rules for vmap
"""
del is_outer
check_valid_batch_dims(batch_dims)
assert DBiasQuantizePrimitive.outer_primitive is not None
x, scale = batched_args
x_bdim, scale_bdim = batch_dims
amax_bdim = scale_bdim
out_bdims = x_bdim, amax_bdim out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
return ( return (
CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype), DBiasQuantizePrimitive.outer_primitive.bind(
x,
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_axis=q_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
),
out_bdims, out_bdims,
) )
@staticmethod @staticmethod
def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos): def infer_sharding_from_operands(
del out_dtype, result_infos out_dtype,
scaling_mode,
q_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
mesh,
arg_infos,
result_infos,
):
del (out_dtype, result_infos, scale_dtype, scale_shapes, is_dbias, is_outer) # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) out_sharding = NamedSharding(
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) mesh,
return (casted_x_sharding, amax_sharding) PartitionSpec(*x_spec[:-1], x_spec[-1]),
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec)
else:
colwise_out_spec = x_spec
else:
colwise_out_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding",
)
scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(*get_padded_spec(arg_infos[1])),
desc="DBiasQuantizePrimitive.scale_inv",
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description(
desc="DBiasQuantizePrimitive.amax_sharding"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DBiasQuantizePrimitive.colwise_scale_inv"
)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
desc="DBiasQuantizePrimitive.dbias_sharding",
)
return (
out_sharding,
colwise_out_sharding,
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
dbias_sharding,
)
@staticmethod @staticmethod
def partition(out_dtype, mesh, arg_infos, result_infos): def partition(
del result_infos out_dtype,
scaling_mode,
q_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
mesh,
arg_infos,
result_infos,
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) out_sharding = NamedSharding(
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) mesh,
PartitionSpec(*x_spec[:-1], x_spec[-1]),
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec)
else:
colwise_out_spec = x_spec
else:
colwise_out_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding",
)
scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(*get_padded_spec(arg_infos[1])),
desc="DBiasQuantizePrimitive.scale_inv",
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description(
desc="DBiasQuantizePrimitive.amax_sharding"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DBiasQuantizePrimitive.colwise_scale_inv"
)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
desc="DBiasQuantizePrimitive.dbias_sharding",
)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, amax_sharding) out_shardings = (
out_sharding,
colwise_out_sharding,
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
dbias_sharding,
)
def sharded_impl(x, amax, scale, scale_inv): def sharded_impl(x, scale):
local_cx, local_updated_amax = CastFP8Primitive.impl( (
x, amax, scale, scale_inv, out_dtype=out_dtype local_x,
local_colwise_x,
local_scale_inv,
local_colwise_scale_inv,
local_amax,
local_dbias,
) = DBiasQuantizePrimitive.impl(
x,
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_axis=q_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
is_outer=True,
) )
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)
return local_cx, global_updated_amax if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
if is_dbias:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
else:
global_dbias = local_dbias
return (
local_x,
local_colwise_x,
local_scale_inv,
local_colwise_scale_inv,
global_updated_amax,
global_dbias,
)
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(CastFP8Primitive) register_primitive(DBiasQuantizePrimitive)
def cast_fp8( def _jax_quantize(x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None):
if quantizer is None:
return x
return quantizer.quantize(x, dq_dtype=dq_dtype)
def _jax_dbias(dx: jnp.ndarray):
dbias = jnp.sum(
dx,
axis=tuple(range(dx.ndim - 1)),
keepdims=False,
)
dbias = dbias.ravel() # C++ function returns an 1D array for dbias
return dbias
def _jax_quantize_dbias(
x,
quantizer: Quantizer = None,
dq_dtype: Optional[jnp.dtype] = None,
):
if quantizer is None:
return x, None
return quantizer.quantize(x, dq_dtype=dq_dtype), _jax_dbias(x)
def _jax_dbias(
dx: jnp.ndarray,
):
dbias = jnp.sum(
dx.astype(jnp.float32),
axis=tuple(range(dx.ndim - 1)),
keepdims=False,
)
dbias = dbias.ravel() # C++ function returns an 1D array for dbias
return dbias.astype(dx.dtype)
def _quantize_impl(
x: jnp.ndarray, x: jnp.ndarray,
amax: jnp.ndarray, quantizer: Quantizer,
scale: jnp.ndarray, is_dbias: bool = False,
scale_inv: jnp.ndarray, dq_dtype: Optional[jnp.dtype] = None,
out_dtype: TEDType, ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
) -> Tuple[jnp.ndarray, jnp.ndarray]:
""" """
Cast wrapper Cast wrapper
Return FP8 tensor Return FP8 tensor
""" """
if not CastFP8Primitive.enabled(): assert (dq_dtype is None) or (
return _jax_cast_fp8(x, scale, amax, out_dtype=out_dtype) quantizer is not None
return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype) ), "quantizer must be provided if dq_dtype is provided"
if not DBiasQuantizePrimitive.enabled():
if is_dbias:
return _jax_quantize_dbias(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
)
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None
# TE/common doesn't support colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if is_dbias:
return _jax_quantize_dbias(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
)
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None
scale = jnp.empty((), jnp.float32)
# TE/common dbias_quantize does not support 1x on arch < 100
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = _quantize_impl(
x=x,
is_dbias=False,
quantizer=quantizer,
dq_dtype=dq_dtype,
)
dbias = _jax_dbias(x)
return out, dbias
if quantizer is None:
if is_dbias:
return x, _jax_dbias(x)
return x, None
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
(
rowwise_casted_output,
colwise_casted_output,
rowwise_scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
) = DBiasQuantizePrimitive.outer_primitive.bind(
x,
scale,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_axis=quantizer.q_axis.value,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
is_dbias=is_dbias,
is_outer=True,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
out = ScaledTensorFactory.create(
data=rowwise_casted_output,
scale_inv=rowwise_scale_inv,
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=dq_dtype if dq_dtype is not None else x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
)
return out, dbias
# TODO(Phuong): do not expose dq_dtype to users
def quantize(
x: jnp.ndarray,
quantizer: Quantizer,
dq_dtype: Optional[jnp.dtype] = None,
) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer.
Args:
x: Input tensor to be quantized.
Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output.
dq_dtype: Optional dtype for dequantization.
If None, uses the same dtype as the input tensor.
Returns:
A ScaledTensor containing the quantized input tensor.
"""
out, _ = _quantize_impl(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
)
return out
# TODO(Phuong): do not expose dq_dtype to users
def quantize_dbias(
dz: jnp.ndarray,
quantizer: Quantizer,
is_dbias: bool = True,
dq_dtype: Optional[jnp.dtype] = None,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient.
Args:
dz: Input tensor to be quantized and used for bias gradient computation.
Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output.
is_dbias: If True, compute bias gradient. Defaults to True.
dq_dtype: Optional dtype for dequantization.
If None, uses the same dtype as the input tensor.
Returns:
A tuple containing:
- A ScaledTensor containing the quantized input tensor.
The ScaledTensor includes both the quantized data and scaling factors.
- The bias gradient tensor.
Shape: (K,) or empty if is_dbias is False.
"""
return _quantize_impl(
dz,
quantizer=quantizer,
is_dbias=is_dbias,
dq_dtype=dq_dtype,
)
...@@ -11,14 +11,10 @@ from packaging import version ...@@ -11,14 +11,10 @@ from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine_jax
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper from .misc import get_padded_spec, check_valid_batch_dims
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from ..softmax import SoftmaxType from ..softmax import SoftmaxType
if version.parse(jax.__version__) >= version.parse("0.5.0"): if version.parse(jax.__version__) >= version.parse("0.5.0"):
...@@ -38,30 +34,6 @@ __all__ = [ ...@@ -38,30 +34,6 @@ __all__ = [
] ]
def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
return jax.nn.softmax(scale_factor * logits)
def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
if mask is not None:
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
def is_softmax_kernel_available( def is_softmax_kernel_available(
softmax_type: SoftmaxType, softmax_type: SoftmaxType,
batch: int, batch: int,
...@@ -139,38 +111,7 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -139,38 +111,7 @@ class SoftmaxPrimitive(BasePrimitive):
""" """
softmax_forward lowering rules softmax_forward lowering rules
""" """
if is_ffi_enabled(): return ffi.ffi_lowering(name)(ctx, logits, scale_factor=scale_factor)
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor)
else:
(i_aval,) = ctx.avals_in
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, i_shape[:-3])
pad_batch = batch
heads = i_shape[-3]
q_seqlen = i_shape[-2]
k_seqlen = i_shape[-1]
out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
operands = [logits]
operand_shapes = [i_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(
batch,
pad_batch,
heads,
q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(i_aval.dtype),
scale_factor,
)
out = custom_caller(name, args, opaque, False)
return out
@staticmethod @staticmethod
def forward_impl(primitive, logits, scale_factor): def forward_impl(primitive, logits, scale_factor):
...@@ -250,43 +191,7 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -250,43 +191,7 @@ class SoftmaxPrimitive(BasePrimitive):
""" """
softmax_backward lowering rules softmax_backward lowering rules
""" """
if is_ffi_enabled(): return ffi.ffi_lowering(name)(ctx, dz, softmax_out, scale_factor=scale_factor)
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor)
else:
dz_aval, _ = ctx.avals_in
dz_type = ir.RankedTensorType(dz.type)
dz_shape = dz_type.shape
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, dz_shape[:-3])
pad_batch = batch # unused
heads = dz_shape[-3]
q_seqlen = dz_shape[-2]
k_seqlen = dz_shape[-1]
softmax_out_type = ir.RankedTensorType(softmax_out.type)
softmax_out_shape = softmax_out_type.shape
out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
operands = [dz, softmax_out]
operand_shapes = [dz_shape, softmax_out_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(
batch,
pad_batch,
heads,
q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(dz_aval.dtype),
scale_factor,
)
out = custom_caller(name, args, opaque, False)
return out
@staticmethod @staticmethod
def backward_impl(primitive, dz, softmax_out, scale_factor): def backward_impl(primitive, dz, softmax_out, scale_factor):
...@@ -356,7 +261,7 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -356,7 +261,7 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
Scaled Softmax Fwd Primitive Scaled Softmax Fwd Primitive
""" """
name = "te_scaled_softmax_forward" name = "te_scaled_softmax_forward_ffi"
multiple_results = False multiple_results = False
impl_static_args = (1,) # scale_factor impl_static_args = (1,) # scale_factor
inner_primitive = None inner_primitive = None
...@@ -429,22 +334,12 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -429,22 +334,12 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledSoftmaxFwdPrimitive) register_primitive(ScaledSoftmaxFwdPrimitive)
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
"""
scaled_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledSoftmaxFwdPrimitive.enabled():
return _jax_scaled_softmax(logits, scale_factor)
return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive): class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
""" """
Scaled Softmax Bwd Primitive Scaled Softmax Bwd Primitive
""" """
name = "te_scaled_softmax_backward" name = "te_scaled_softmax_backward_ffi"
multiple_results = False multiple_results = False
impl_static_args = (2,) # scale_factor impl_static_args = (2,) # scale_factor
inner_primitive = None inner_primitive = None
...@@ -530,7 +425,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -530,7 +425,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
Scaled Masked Softmax Fwd Primitive Scaled Masked Softmax Fwd Primitive
""" """
name = "te_scaled_masked_softmax_forward" name = "te_scaled_masked_softmax_forward_ffi"
multiple_results = False multiple_results = False
impl_static_args = (2,) # scale_factor impl_static_args = (2,) # scale_factor
inner_primitive = None inner_primitive = None
...@@ -591,42 +486,10 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -591,42 +486,10 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
""" """
te_scaled_masked_softmax_forward lowering rules te_scaled_masked_softmax_forward lowering rules
""" """
if is_ffi_enabled(): return ffi.ffi_lowering(ScaledMaskedSoftmaxFwdPrimitive.name)(
ffi_name = "te_scaled_masked_softmax_forward_ffi" ctx, logits, mask, scale_factor=scale_factor
out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, scale_factor=scale_factor)
else:
logits_aval, _ = ctx.avals_in
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, i_shape[:-3])
heads = i_shape[-3]
q_seqlen = i_shape[-2]
k_seqlen = i_shape[-1]
mask_type = ir.RankedTensorType(mask.type)
mask_shape = mask_type.shape
pad_batch = reduce(operator.mul, mask_shape[:-3])
out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
operands = [logits, mask]
operand_shapes = [i_shape, mask_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(
batch,
pad_batch,
heads,
q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(logits_aval.dtype),
scale_factor,
) )
out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)
return out
@staticmethod @staticmethod
def impl(logits, mask, scale_factor): def impl(logits, mask, scale_factor):
assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
...@@ -666,26 +529,12 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -666,26 +529,12 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledMaskedSoftmaxFwdPrimitive) register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
def scaled_masked_softmax_fwd(
logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
"""
scaled_masked_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_masked_softmax(logits, mask, scale_factor)
return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, mask, scale_factor=scale_factor
)
class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
""" """
Scaled Masked Softmax Bwd Primitive Scaled Masked Softmax Bwd Primitive
""" """
name = "te_scaled_masked_softmax_backward" name = "te_scaled_masked_softmax_backward_ffi"
multiple_results = False multiple_results = False
impl_static_args = (2,) # scale_factor impl_static_args = (2,) # scale_factor
inner_primitive = None inner_primitive = None
...@@ -712,12 +561,10 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -712,12 +561,10 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
""" """
te_scaled_upper_triang_masked_backward lowering rules te_scaled_upper_triang_masked_backward lowering rules
""" """
out = SoftmaxPrimitive.backward_lowering( return SoftmaxPrimitive.backward_lowering(
ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
) )
return out
@staticmethod @staticmethod
def impl(dz, softmax_out, scale_factor): def impl(dz, softmax_out, scale_factor):
return SoftmaxPrimitive.backward_impl( return SoftmaxPrimitive.backward_impl(
...@@ -753,33 +600,12 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -753,33 +600,12 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledMaskedSoftmaxBwdPrimitive) register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
def scaled_masked_softmax_bwd(
dz: jnp.ndarray,
softmax_out: jnp.ndarray,
logits: jnp.ndarray,
mask: jnp.ndarray,
scale_factor: float,
) -> jnp.ndarray:
"""
scaled_masked_backward wrapper
Return FP16/BF16 tensor
"""
if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
)
return vjp_func(dz)[0]
return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
dz, softmax_out, scale_factor=scale_factor
)
class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
""" """
Scaled Upper Triang Masked Softmax Fwd Primitive Scaled Upper Triang Masked Softmax Fwd Primitive
""" """
name = "te_scaled_upper_triang_masked_softmax_forward" name = "te_scaled_upper_triang_masked_softmax_forward_ffi"
multiple_results = False multiple_results = False
impl_static_args = (1,) # scale_factor impl_static_args = (1,) # scale_factor
inner_primitive = None inner_primitive = None
...@@ -860,24 +686,12 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -860,24 +686,12 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive) register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
"""
scaled_upper_triang_masked_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, scale_factor=scale_factor
)
class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
""" """
Scaled Upper Triang Masked Softmax Bwd Primitive Scaled Upper Triang Masked Softmax Bwd Primitive
""" """
name = "te_scaled_upper_triang_masked_softmax_backward" name = "te_scaled_upper_triang_masked_softmax_backward_ffi"
multiple_results = False multiple_results = False
impl_static_args = (2,) # scale_factor impl_static_args = (2,) # scale_factor
inner_primitive = None inner_primitive = None
...@@ -904,7 +718,7 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -904,7 +718,7 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
""" """
te_scaled_upper_triang_masked_backward lowering rules te_scaled_upper_triang_masked_backward lowering rules
""" """
out = SoftmaxPrimitive.backward_lowering( return SoftmaxPrimitive.backward_lowering(
ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name, ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
ctx, ctx,
dz, dz,
...@@ -912,8 +726,6 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -912,8 +726,6 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
scale_factor=scale_factor, scale_factor=scale_factor,
) )
return out
@staticmethod @staticmethod
def impl(dz, softmax_out, scale_factor): def impl(dz, softmax_out, scale_factor):
return SoftmaxPrimitive.backward_impl( return SoftmaxPrimitive.backward_impl(
...@@ -953,6 +765,87 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -953,6 +765,87 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
return jax.nn.softmax(scale_factor * logits)
def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
if mask is not None:
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if ScaledSoftmaxFwdPrimitive.enabled():
        return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
    # Pure-JAX fallback when the TE custom call is disabled.
    return _jax_scaled_softmax(logits, scale_factor)
def scaled_masked_softmax_fwd(
    logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if ScaledMaskedSoftmaxFwdPrimitive.enabled():
        return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
            logits, mask, scale_factor=scale_factor
        )
    # Pure-JAX fallback when the TE custom call is disabled.
    return _jax_scaled_masked_softmax(logits, mask, scale_factor)
def scaled_masked_softmax_bwd(
    dz: jnp.ndarray,
    softmax_out: jnp.ndarray,
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
) -> jnp.ndarray:
    """
    scaled_masked_backward wrapper
    Return FP16/BF16 tensor

    dz: cotangent w.r.t. the softmax output.
    softmax_out: forward-pass result, consumed by the TE custom-call path.
    logits, mask: forward-pass inputs, consumed only by the pure-JAX VJP fallback.
    """
    if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
        # Fallback: differentiate the pure-JAX forward implementation w.r.t. logits.
        _, vjp_func = jax.vjp(
            partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
        )
        return vjp_func(dz)[0]
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )
def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
        # Pure-JAX fallback when the TE custom call is disabled.
        return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor
    )
def scaled_upper_triang_masked_softmax_bwd( def scaled_upper_triang_masked_softmax_bwd(
dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray: ) -> jnp.ndarray:
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for transpose"""
import operator
from functools import partial, reduce
from typing import Tuple, Sequence, Union, Callable
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
check_valid_batch_dims,
jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype,
get_padded_spec,
multidim_transpose,
normalize_axis_boundary,
is_ffi_enabled,
)
from .activation import ActivationEnum
from .activation import _jax_act_lu
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [
"transpose",
"cast_transpose",
"dbias_cast_transpose",
"dact_lu_dbias_cast_transpose",
"dgated_act_lu_cast_transpose",
]
def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary):
    """
    JAX native transpose implementation
    """
    # Build the permutation with the shared multidim_transpose helper, then apply it.
    permutation = multidim_transpose(
        range(inputs.ndim), static_axis_boundary, transpose_axis_boundary
    )
    return jnp.transpose(inputs, axes=permutation)
def _jax_cast_transpose(
    inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
):
    """
    JAX native cast_transpose implementation
    """
    # Cast to FP8 first, then lay out the transposed copy of the casted tensor.
    casted, new_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype)
    casted_t = _jax_transpose(casted, static_axis_boundary, transpose_axis_boundary)
    return casted, casted_t, new_amax
def _jax_dbias_cast_transpose(
    dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
):
    """
    JAX native dbias_cast_transpose implementation
    """
    casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose(
        dz,
        scale,
        amax,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
    # dbias reduces over every axis in front of the transpose boundary.
    num_leading_axes = (
        transpose_axis_boundary
        if transpose_axis_boundary > 0
        else transpose_axis_boundary + dz.ndim
    )
    dbias = jnp.sum(dz, axis=tuple(range(num_leading_axes)), keepdims=False)
    dbias = dbias.ravel()  # C++ function returns an 1D array for dbias
    return casted_dz, cast_transposed_dz, dbias, updated_amax
class TransposePrimitive(BasePrimitive):
    """
    Transpose Primitive

    Wraps the TE ``te_transpose`` custom call (FFI variant when enabled) and
    supplies abstract-eval, batching, and sharding rules so the op composes
    with vmap and partitioned execution.
    """
    name = "te_transpose"
    multiple_results = False
    impl_static_args = (1, 2)  # (static_axis_boundary, transpose_axis_boundary)
    inner_primitive = None
    outer_primitive = None
    @staticmethod
    def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose abstract: output has the transposed shape, same dtype.
        """
        transposed_x_shape = multidim_transpose(
            x_aval.shape, static_axis_boundary, transpose_axis_boundary
        )
        xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype)
        return xt_aval
    @staticmethod
    def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose cuda lowering
        """
        x_aval = ctx.avals_in[0]
        assert x_aval.dtype in [
            jnp.float32,
            jnp.float16,
            jnp.bfloat16,
            jnp.float8_e4m3fn,
            jnp.float8_e5m2,
        ]
        if is_ffi_enabled():
            # FFI path: only the transpose axis is forwarded; shapes travel via ctx.
            name = "te_transpose_ffi"
            out = ffi.ffi_lowering(name)(ctx, x, transpose_axis=transpose_axis_boundary)
        else:
            # Legacy custom-call path: build MLIR types and a packed opaque descriptor.
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
            if static_axis_boundary >= 0:
                # Static (vmap-introduced) axes must be size-1 placeholders.
                for i in range(static_axis_boundary + 1):
                    assert ir_x_shape[i] == 1
            transposed_x_shape = multidim_transpose(
                ir_x_shape, static_axis_boundary, transpose_axis_boundary
            )
            out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)]
            operands = [x]
            operand_shapes = [ir_x_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
            # The kernel sees a 2D view: (prod(leading dims), prod(trailing dims)).
            contracted_x_shape = (
                reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]),
            )
            opaque = transformer_engine_jax.pack_common_descriptor(
                contracted_x_shape, te_dtype, te_dtype
            )
            out = custom_caller(TransposePrimitive.name, args, opaque, False)
        return out
    @staticmethod
    def impl(x, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose implementation: binds the inner primitive.
        """
        assert TransposePrimitive.inner_primitive is not None
        transposed_x = TransposePrimitive.inner_primitive.bind(
            x,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return transposed_x
    @staticmethod
    def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
        # vmap rule: the batch axis becomes the static axis of the inner call.
        check_valid_batch_dims(batch_dims)
        assert TransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0
        (x,) = batched_args
        (x_bdim,) = batch_dims
        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim
        out_bdims = x_bdim
        return (
            TransposePrimitive.outer_primitive.bind(
                x, static_axis_boundary=x_bdim, transpose_axis_boundary=transpose_axis_boundary
            ),
            out_bdims,
        )
    @staticmethod
    def infer_sharding_from_operands(
        static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        # Output sharding spec is the input spec permuted the same way as the data.
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        return transposed_x_sharding
    @staticmethod
    def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
        # Transpose is element-local per shard; no cross-device reduction needed.
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = transposed_x_sharding
        impl = partial(
            TransposePrimitive.impl,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return mesh, impl, out_shardings, arg_shardings
register_primitive(TransposePrimitive)
def transpose(
    x: jnp.ndarray, static_axis_boundary: int, transpose_axis_boundary: int
) -> jnp.ndarray:
    """
    transpose wrapper
    """
    if TransposePrimitive.enabled():
        return TransposePrimitive.outer_primitive.bind(
            x,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
    # Pure-JAX fallback when the TE custom call is disabled.
    return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary)
class CastTransposePrimitive(BasePrimitive):
    """
    Cast Transpose Primitive

    Fused FP8 cast + transpose via the TE ``te_cast_transpose`` custom call.
    Returns (casted, casted_transposed, updated_amax); amax input is aliased
    to the amax output.
    """
    name = "te_cast_transpose"
    multiple_results = True
    impl_static_args = (4, 5, 6)  # (out_dtype, static_axis_boundary, transpose_axis_boundary)
    inner_primitive = None
    outer_primitive = None
    @staticmethod
    def abstract(
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary
    ):
        """
        te_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        transposed_x_shape = multidim_transpose(
            x_aval.shape, static_axis_boundary, transpose_axis_boundary
        )
        # Both outputs take the requested FP8 out_dtype; amax keeps its shape/dtype.
        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return casted_x_aval, casted_xt_aval, updated_amax_aval
    @staticmethod
    def lowering(
        ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_cast_transpose_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            # FFI path; amax (operand 1) aliases the amax output (result 2).
            name = "te_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={1: 2})(
                ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
            )
        else:
            # Legacy custom-call path: build MLIR types and a packed opaque descriptor.
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            if static_axis_boundary >= 0:
                # Static (vmap-introduced) axes must be size-1 placeholders.
                for i in range(static_axis_boundary + 1):
                    assert ir_x_shape[i] == 1
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            transposed_x_shape = multidim_transpose(
                ir_x_shape, static_axis_boundary, transpose_axis_boundary
            )
            out_types = [
                ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [x, amax, scale, scale_inv]
            operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            # The kernel sees a 2D view: (prod(leading dims), prod(trailing dims)).
            contracted_x_shape = (
                reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]),
            )
            opaque = transformer_engine_jax.pack_common_descriptor(
                contracted_x_shape,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
            )
            out = custom_caller(
                CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2}
            )
        return out
    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_cast_transpose implementation: binds the inner primitive.
        """
        assert CastTransposePrimitive.inner_primitive is not None
        casted_x, casted_transposed_x, updated_amax = CastTransposePrimitive.inner_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return casted_x, casted_transposed_x, updated_amax
    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        # vmap rule: the batch axis becomes the static axis of the inner call.
        check_valid_batch_dims(batch_dims)
        assert CastTransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims
        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim
        out_bdims = x_bdim, x_bdim, amax_bdim
        return (
            CastTransposePrimitive.outer_primitive.bind(
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=x_bdim,
                transpose_axis_boundary=transpose_axis_boundary,
            ),
            out_bdims,
        )
    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        # Transposed output gets the permuted input spec.
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)
    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)
        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_cxt, local_updated_amax = CastTransposePrimitive.impl(
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            # amax must be the global max across all shards (except pipeline axes).
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)
            return local_cx, local_cxt, global_updated_amax
        return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(CastTransposePrimitive)
def cast_transpose(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    static_axis_boundary: int,
    transpose_axis_boundary: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose wrapper
    Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale`
    """
    if CastTransposePrimitive.enabled():
        return CastTransposePrimitive.outer_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
    # Pure-JAX fallback when the TE custom call is disabled.
    return _jax_cast_transpose(
        x, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
    )
class DBiasCastTransposePrimitive(BasePrimitive):
    """
    DBias Cast Transpose Primitive

    Fused bias-gradient reduction + FP8 cast + transpose via the TE
    ``te_dbias_cast_transpose`` custom call. Returns
    (casted dz, casted transposed dz, dbias, updated amax); the inner
    primitive additionally produces a workspace buffer that is dropped.
    """
    name = "te_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None
    @staticmethod
    def abstract(
        dz_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        # Flattened size of the axes behind the transpose boundary.
        gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:])
        t_shape = multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        # dbias keeps the static axes and the flattened hidden size; dz dtype.
        dbias_shape = (*dz_aval.shape[: static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        # Query the C++ side for the cuBLAS/reduction workspace requirement.
        (wkspace_info,) = transformer_engine_jax.get_dbias_ct_workspace_sizes(
            dz_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        wkspace_aval = dz_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )
        return out, t_out, dbias, updated_amax_aval, wkspace_aval
    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_cast_transpose_p outer abstract: hides the workspace output.
        """
        out, t_out, dbias, updated_amax_aval, _ = DBiasCastTransposePrimitive.abstract(
            *args, **kwargs
        )
        return out, t_out, dbias, updated_amax_aval
    @staticmethod
    def lowering(
        ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose_p lowering rules
        """
        dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            # FFI path; amax (operand 1) aliases the amax output (result 3).
            name = "te_dbias_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={1: 3})(
                ctx, dz, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
            )
        else:
            # Legacy custom-call path: build MLIR types and a packed opaque descriptor.
            ir_dz_type = ir.RankedTensorType(dz.type)
            ir_dz_shape = ir_dz_type.shape
            batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
            ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:])
            contracted_dz_shape = (batch_size, ir_hidden_size)
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            transposed_dz_shape = multidim_transpose(
                ir_dz_shape, static_axis_boundary, transpose_axis_boundary
            )
            dbias_shape = (*ir_dz_shape[: static_axis_boundary + 1], ir_hidden_size)
            wkspace_aval = ctx.avals_out[-1]
            out_types = [
                ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
                ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
            ]
            operands = [dz, amax, scale, scale_inv]
            operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            opaque = transformer_engine_jax.pack_common_wk_descriptor(
                contracted_dz_shape,
                wkspace_aval.shape,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
            )
            out = custom_caller(
                DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3}
            )
        return out
    @staticmethod
    def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        Binds the inner primitive and drops the trailing workspace output.
        """
        assert DBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return out, t_out, dbias, updated_amax
    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        vmap batch rule: the batch axis becomes the static axis of the inner call.
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DBiasCastTransposePrimitive.outer_primitive is not None
        dz, amax, scale, scale_inv = batched_args
        dz_bdim, amax_bdim, _, _ = batch_dims
        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim
        out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
        return (
            DBiasCastTransposePrimitive.outer_primitive.bind(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=dz_bdim,
                transpose_axis_boundary=transpose_axis_boundary,
            ),
            out_bdims,
        )
    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        # dbias keeps the static axes' spec plus the last (hidden) axis spec.
        dbias_shaprding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding)
    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_shaprding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            casted_x_sharding,
            casted_transposed_x_sharding,
            dbias_shaprding,
            amax_sharding,
        )
        def sharded_impl(dz, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            # dbias sums over data-parallel shards; amax takes the global max.
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_dbias, global_updated_amax
        return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DBiasCastTransposePrimitive)
def dbias_cast_transpose(
    dz: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    transpose_axis_boundary: int = -1,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dbias partial fusion wrapper
    Return FP8(inputs), dbias
    """
    # Any negative boundary means "no static axes"; normalize to -1.
    static_axis_boundary = max(static_axis_boundary, -1)
    if DBiasCastTransposePrimitive.enabled():
        return DBiasCastTransposePrimitive.outer_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
    # Pure-JAX fallback when the TE custom call is disabled.
    return _jax_dbias_cast_transpose(
        dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
    )
class DActLuDBiasCastTransposePrimitive(BasePrimitive):
"""
DActLu DBias Cast Transpose Primitive
"""
name = "te_dact_lu_dbias_cast_transpose"
multiple_results = True
# out_dtype, static_axis_boundary, act_enum
impl_static_args = (5, 6, 7)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
dz_aval,
x_aval,
amax_aval,
scale_aval,
scale_inv_aval,
*,
out_dtype,
static_axis_boundary,
act_enum
): # pylint: disable=unused-argument
"""
te_dact_lu_dbais_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_hidden_szie = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
assert ir_hidden_szie == gi_hidden_size
t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
dbias_shape = (*x_aval.shape[: static_axis_boundary + 1], gi_hidden_size)
dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
(wkspace_info,) = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes(
x_aval.size // gi_hidden_size,
gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return out, t_out, dbias, updated_amax_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
te_dact_lu_dbais_cast_transpose_p outer abstract
"""
out, t_out, dbias, updated_amax_aval, _ = DActLuDBiasCastTransposePrimitive.abstract(
*args, **kwargs
)
return out, t_out, dbias, updated_amax_aval
@staticmethod
def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
"""
te_dgated_act_lu_cast_transpose_p lowering rules
"""
dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_dact_lu_dbias_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={2: 3})(
ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
)
else:
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_szie == x_batch_size
ir_hidden_szie = ir_dz_shape[-1]
contracted_x_shape = (x_batch_size, ir_hidden_szie)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_szie)
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [dz, x, amax, scale, scale_inv]
operand_shapes = [
ir_dz_shape,
x_shape,
ir_amax_shape,
ir_scale_shape,
ir_scale_inv_shape,
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_x_shape,
wkspace_aval.shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
act_enum,
)
out = custom_caller(
DActLuDBiasCastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={2: 3},
)
return out
@staticmethod
def impl(
dz,
x,
amax,
scale,
scale_inv,
out_dtype,
static_axis_boundary,
act_enum,
):
"""
to describe implementation
"""
assert DActLuDBiasCastTransposePrimitive.inner_primitive is not None
out, t_out, dbias, updated_amax, _ = DActLuDBiasCastTransposePrimitive.inner_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
act_enum=act_enum,
)
return out, t_out, dbias, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
"""
to describe batch rules for vmap
"""
del static_axis_boundary
check_valid_batch_dims(batch_dims)
assert DActLuDBiasCastTransposePrimitive.outer_primitive is not None
dz, x, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
return (
DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=x_bdim,
act_enum=act_enum,
),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(
out_dtype,
static_axis_boundary,
act_enum,
mesh,
arg_infos,
result_infos,
):
del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_shaprding = NamedSharding(
mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
)
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding)
@staticmethod
def partition(
out_dtype,
static_axis_boundary,
act_enum,
mesh,
arg_infos,
result_infos,
):
del result_infos
x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_shaprding = NamedSharding(
mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
)
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
casted_x_sharding,
casted_transposed_x_sharding,
dbias_shaprding,
amax_sharding,
)
def sharded_impl(dz, x, amax, scale, scale_inv):
local_out, local_t_out, local_dbias, local_amax = (
DActLuDBiasCastTransposePrimitive.impl(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
act_enum=act_enum,
)
)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_out, local_t_out, global_dbias, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
# Register abstract/lowering/batching/partitioning rules for the fused
# dact + dbias + cast-transpose primitive.
register_primitive(DActLuDBiasCastTransposePrimitive)
def dact_lu_dbias_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Fused cast-transpose of dact_lu plus dbias reduction.

    Returns ``(casted_dx, casted_transposed_dx, dbias, updated_amax)``:
    FP8(dact_lu(dz, x)) in both natural and transposed layouts, the bias
    gradient, and the updated amax.  (The primitive emits four outputs; the
    previous 3-tuple annotation was incorrect.)
    ONLY supports non-gated activation types.

    Args:
        dz: incoming gradient.
        x: forward-pass activation input.
        amax, scale, scale_inv: FP8 meta tensors (float32).
        out_dtype: FP8 output dtype.
        static_axis_boundary: index of the last static (batch) axis; a
            negative value means there are no static axes.
        activation_type: non-gated activation descriptor.
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    if not DActLuDBiasCastTransposePrimitive.enabled():
        # Pure-JAX fallback: differentiate the activation via VJP, then apply
        # the cast-transpose + dbias reference implementation.
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
        (dx,) = vjp_func(dz)
        transpose_axis_boundary = -2
        return _jax_dbias_cast_transpose(
            dx, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
        )

    act_type_id = ActivationEnum[activation_type]
    return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        act_enum=act_type_id,
    )
class DgatedActLuCastTransposePrimitive(BasePrimitive):
    """
    Dgated ActLu Cast Transpose Primitive

    Fused backward pass of a gated activation: computes the gradient of the
    gated activation, casts it to FP8, and emits both the natural and the
    transposed layout plus an updated amax.
    """

    name = "te_dgated_act_lu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6, 7)  # out_dtype, static_axis_boundary, act_enum
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        act_enum
    ):  # pylint: disable=unused-argument
        """
        te_dgated_act_lu_cast_transpose_p abstract

        Validates operand dtypes/shapes and declares the three outputs:
        (casted dgrad, transposed casted dgrad, updated amax).
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert x_aval.shape[-2] == 2  # Linear + GeLU
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        # dz and x must agree on the hidden dimension.
        ir_hidden_szie = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_szie == gi_hidden_size
        t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        # amax is aliased in/out (see lowering); shape and dtype are unchanged.
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
        """
        te_dgated_act_lu_cast_transpose_p lowering rules

        Uses the XLA FFI handler when available, otherwise the legacy
        opaque-descriptor custom call.  In both paths operand 2 (amax) is
        aliased to output 2 so it is updated in place.
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_dgated_act_lu_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
                ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
            )
        else:
            ir_dz_type = ir.RankedTensorType(dz.type)
            ir_dz_shape = ir_dz_type.shape
            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            # Collapse leading dims; dz and x must describe the same batch.
            dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
            x_batch_size = reduce(operator.mul, x_shape[:-2])
            assert dz_batch_szie == x_batch_size
            assert x_shape[-2] == 2  # Linear + GeLU
            ir_hidden_szie = ir_dz_shape[-1]
            gi_hidden_size = x_shape[-1]
            assert ir_hidden_szie == gi_hidden_size
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            # scale and scale_inv share amax's (scalar-like) shape.
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
            out_types = [
                ir.RankedTensorType.get(x_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [dz, x, amax, scale, scale_inv]
            operand_shapes = [
                ir_dz_shape,
                x_shape,
                ir_amax_shape,
                ir_scale_shape,
                ir_scale_inv_shape,
            ]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            # The C++ kernel sees the input as a 2D (batch, hidden) problem.
            contracted_x_shape = (x_batch_size, x_shape[-1])
            opaque = transformer_engine_jax.pack_common_descriptor(
                contracted_x_shape,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                act_enum,
            )
            out = custom_caller(
                DgatedActLuCastTransposePrimitive.name,
                args,
                opaque,
                False,
                operand_output_aliases={2: 2},
            )
        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, act_enum):
        """
        to describe implementation

        Binds the inner primitive and forwards its three outputs.
        """
        assert DgatedActLuCastTransposePrimitive.inner_primitive is not None
        out, t_out, updated_amax = DgatedActLuCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            act_enum=act_enum,
        )
        return out, t_out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
        """
        to describe batch rules for vmap

        Rebinds the outer primitive with x's batch axis as the static-axis
        boundary; out/t_out follow x's batch dim, amax follows amax's.
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DgatedActLuCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, amax_bdim
        return (
            DgatedActLuCastTransposePrimitive.outer_primitive.bind(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=x_bdim,
                act_enum=act_enum,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos
    ):
        # Output shardings follow x (operand 1) and amax (operand 2).
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, tranposed_out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos):
        # Shard outputs like x / amax; amax is maxed across all mesh axes
        # except PP inside the sharded implementation.
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_amax = DgatedActLuCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                act_enum=act_enum,
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
# Register abstract/lowering/batching/partitioning rules for the fused
# dgated-act + cast-transpose primitive.
register_primitive(DgatedActLuCastTransposePrimitive)
def dgated_act_lu_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Fused cast-transpose of the gated-activation backward pass.

    Returns ``(casted_dx, casted_transposed_dx, updated_amax)``:
    FP8(dgated_act_lu(dz, x)) in both natural and transposed layouts plus the
    updated amax.

    Args:
        dz: incoming gradient w.r.t. the gated-activation output.
        x: forward-pass activation input (gate dimension of size 2 at axis -2).
        amax, scale, scale_inv: FP8 meta tensors (float32).
        out_dtype: FP8 output dtype.
        static_axis_boundary: index of the last static (batch) axis.
        activation_type: gated activation descriptor (key into ActivationEnum).
    """
    if not DgatedActLuCastTransposePrimitive.enabled():
        # Pure-JAX fallback: differentiate the activation via VJP, then apply
        # the cast-transpose reference implementation.
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
        (dx,) = vjp_func(dz)
        return _jax_cast_transpose(
            dx,
            scale,
            amax,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=-2,
        )

    # Resolve the activation id only when the custom primitive is actually
    # used (consistent with dact_lu_dbias_cast_transpose); the lookup is
    # unnecessary on the fallback path.
    act_type_id = ActivationEnum[activation_type]
    return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        act_enum=act_type_id,
    )
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <cudnn.h> #include <cudnn.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include <cassert> #include <cassert>
...@@ -33,226 +34,42 @@ ...@@ -33,226 +34,42 @@
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
// Phuong: These 3 functions need to stay in the header file for compilation purpose
// 1.
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
// 2.
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
return pybind11::bytes(str);
}
// 3.
template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
if (opaque_len != sizeof(T)) {
throw std::runtime_error("Invalid opaque object size");
}
return reinterpret_cast<const T *>(opaque);
}
// Packing
struct CustomCallCommonDescriptor {
Shape shape;
DType in_dtype;
DType out_dtype;
size_t act_enum;
};
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype, size_t act_enum = 0);
struct CustomCallCommonWkDescriptor {
Shape shape;
Shape wkshape;
DType in_dtype;
DType out_dtype;
DType wk_dtype;
size_t act_enum;
};
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype,
size_t act_enum = 0);
struct CustomCallNormDescriptor {
size_t batch_size;
size_t hidden_size;
size_t wkspace_size;
DType x_dtype;
DType w_dtype;
DType wkspace_dtype;
bool zero_centered_gamma;
float eps;
int sm_margin;
};
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, DType x_dtype, DType w_dtype,
DType wkspace_dtype, bool zero_centered_gamma,
float eps, int sm_margin);
struct SoftmaxDescriptor {
size_t batch_size;
size_t padding_size;
size_t head_dim;
size_t q_seqlen;
size_t k_seqlen;
DType dtype;
float scale_factor;
};
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
size_t head_dim, size_t q_seqlen, size_t k_seqlen,
DType dtype, float scale_factor);
struct CustomCallFusedAttnDescriptor {
size_t input_batch;
size_t bias_batch;
size_t q_max_seqlen;
size_t kv_max_seqlen;
size_t attn_heads;
size_t num_gqa_groups;
size_t bias_heads;
size_t head_dim;
size_t max_segments_per_seq;
size_t wkspace_size;
float scaling_factor;
float dropout_probability;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
NVTE_QKV_Layout qkv_layout;
DType dtype;
DType wkspace_dtype;
bool is_training;
bool deterministic;
int64_t window_size_left;
int64_t window_size_right;
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic, int64_t window_size_left, int64_t window_size_right);
// Transpose
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasCastTransposeHandler);
// Activation // Activation
size_t get_activation_len(NVTE_Activation_Type activation_enum);
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler);
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler);
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler);
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler);
// Normalization // Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps, int sm_margin);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardHandler);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardFP8Handler);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps, int sm_margin);
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, DType out_dtype,
NVTE_Norm_Type norm_type, int scaling_mode,
bool zero_centered_gamma, float epsilon, int sm_margin,
bool is_training);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler); pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, NVTE_Norm_Type norm_type,
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); bool zero_centered_gamma, int sm_margin);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardHandler);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardFP8Handler);
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormBackwardHandler);
// Quantization // Quantization
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
// Softmax pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque, XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
std::size_t opaque_len);
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
std::size_t opaque_len); DType in_dtype, DType out_dtype,
int scaling_mode, bool is_2x);
// Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);
...@@ -266,9 +83,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler); ...@@ -266,9 +83,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Attention // Attention
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
// Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
...@@ -285,10 +102,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -285,10 +102,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right); size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
...@@ -297,9 +110,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -297,9 +110,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right); int64_t window_size_right);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); // Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); // Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
// CuBLAS helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment