Commit a207db1d authored by yuguo
parents fbee8990 69365f88
......@@ -201,8 +201,9 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
max_fp8 = Quantized_Limits<DType>::max_norm;);
// Update scale
compute_scale_from_amax_kernel<<<1, 1>>>(reinterpret_cast<const float *>(output.amax.dptr),
reinterpret_cast<float *>(output.scale.dptr), max_fp8,
config.force_pow_2_scales, config.amax_epsilon);
compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float *>(output.amax.dptr),
reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
config.amax_epsilon);
NVTE_CHECK_CUDA(cudaGetLastError());
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
"""Transformer Engine bindings for JAX.
This module provides JAX bindings for NVIDIA's Transformer Engine, enabling
high-performance transformer operations with mixed precision and quantization
support. It includes implementations of key transformer components like attention,
linear layers, and layer normalization, optimized for NVIDIA GPUs.
The module exports various transformer operations and utilities:
- Attention mechanisms (self-attention, cross-attention)
- Linear transformations with optional quantization
- Layer normalization operations
- Activation functions
- Softmax operations
- Sharding utilities for distributed training
All operations are designed to work seamlessly with JAX's functional programming
model and support automatic differentiation.
"""
# pylint: disable=wrong-import-position,wrong-import-order
import sys
import logging
import importlib
import importlib.util
import ctypes
from importlib.metadata import version
from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension
_logger = logging.getLogger(__name__)
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
......@@ -41,7 +55,7 @@ def _load_library():
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
_logger.info(
logging.info(
"Could not find package %s. Install transformer-engine using "
"'pip3 install transformer-engine[jax]==VERSION'",
module_name,
......@@ -67,8 +81,10 @@ def _load_library():
_load_library()
from . import flax
from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME
from . import quantize
from .quantize import fp8_autocast
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType
......@@ -85,10 +101,7 @@ ShardingResource = deprecate_wrapper(
)
__all__ = [
"NVTE_FP8_COLLECTION_NAME",
"fp8_autocast",
"update_collections",
"get_delayed_scaling",
"MeshResource",
"MajorShardingType",
"ShardingResource",
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Activation functions for Transformer Engine in JAX.
This module provides optimized activation functions with quantization support.
"""
from typing import Sequence, Union, Callable, Optional
from functools import partial
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize.tensor import ScaledTensor
from .quantize.quantizer import Quantizer
def activation(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
"""Apply activation functions to input tensor with optional quantization.
This function applies a sequence of activation functions to the input tensor.
It supports string-based activation types (e.g., 'relu', 'gelu', ('gelu', 'linear')).
Args:
x: Input tensor to apply activations to
activation_type: Sequence of activation functions
quantizer: Optional quantizer for quantizing the output
Returns:
Activated output tensor
"""
assert x.shape[-1] % len(activation_type) == 0
output = _activation(x, activation_type, quantizer)
return output
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation(x, activation_type, quantizer):
"""Internal implementation of activation with custom VJP.
This function implements the core activation logic with support for
custom vector-Jacobian product (VJP) for automatic differentiation.
Args:
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
Returns:
Activated tensor
"""
_output, _ = _activation_fwd_rule(x, activation_type, quantizer)
return _output
def _activation_fwd_rule(x, activation_type, quantizer):
"""Forward pass rule for activation function.
Args:
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
Returns:
Tuple of (output, context) for backward pass
"""
fwd_output = tex.act_lu(x, activation_type, quantizer)
if isinstance(fwd_output, ScaledTensor):
fwd_output = fwd_output.dequantize()
return fwd_output, (x, quantizer)
def _activation_bwd_rule(activation_type, ctx, g):
"""Backward pass rule for activation function.
Args:
activation_type: Sequence of activation functions
ctx: Context from forward pass
g: Gradient from upstream
Returns:
Gradient with respect to input
"""
(x, _) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx, None)
_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule)
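# A usage sketch (assumed shapes) of the custom-VJP wrapper above with a
# gated activation pair; the last axis must be divisible by
# len(activation_type).
import jax
import jax.numpy as jnp

_x = jnp.ones((4, 2 * 1024), dtype=jnp.bfloat16)  # gated: last dim = 2 * H
_y, _vjp = jax.vjp(lambda t: activation(t, ("gelu", "linear")), _x)
(_dx,) = _vjp(jnp.ones_like(_y))                  # _dx matches _x's shape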
......@@ -378,6 +378,44 @@ def _mask_to_seqlens_offset(mask, max_segments_per_seq):
return q_seqlen, q_offset, kv_seqlen, kv_offset
def _fast_causal_adjust_seqlen_and_offsets(
segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
):
# The assumption is that, within any segment, tokens respect causal ordering except at the
# ends of the segment. This allows us to tweak the length and offset by looking only at the
# start and end tokens between segments.
is_active_segment = jnp.logical_and(q_len > 0, kv_len > 0)
q_seq_id_start = jnp.take(segment_pos_q, q_offset[..., :-1], fill_value=-1)
kv_seq_id_start = jnp.take(segment_pos_kv, kv_offset[..., :-1], fill_value=-1)
skip_start_token = jnp.logical_and(kv_seq_id_start > q_seq_id_start, is_active_segment).astype(
jnp.int32
)
q_len -= skip_start_token
q_offset += jnp.insert(skip_start_token, skip_start_token.shape[-1], 0, axis=-1)
q_seq_id_end = jnp.take(segment_pos_q, q_offset[..., 1:] - 1, fill_value=-1)
kv_seq_id_end = jnp.take(segment_pos_kv, kv_offset[..., 1:] - 1, fill_value=-1)
skip_end_token = jnp.logical_and(kv_seq_id_end > q_seq_id_end, is_active_segment).astype(
jnp.int32
)
kv_len -= skip_end_token
return q_len, kv_len, q_offset, kv_offset
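# A toy illustration (hypothetical values) of the start-token check above:
# jnp.take with fill_value=-1 reads each segment's first in-segment position,
# and a KV start ahead of the Q start flags one query token to skip.
import jax.numpy as jnp

_pos_q = jnp.array([0, 1, 2, 0, 1, 2])     # two Q segments of length 3
_pos_kv = jnp.array([1, 2, 3, 0, 1, 2])    # first KV segment starts at pos 1
_offsets = jnp.array([0, 3, 6])            # per-segment start offsets
_q_start = jnp.take(_pos_q, _offsets[:-1], fill_value=-1)    # [0, 0]
_kv_start = jnp.take(_pos_kv, _offsets[:-1], fill_value=-1)  # [1, 0]
_skip = (_kv_start > _q_start).astype(jnp.int32)             # [1, 0]
# Segment 0 skips its first query token: q_len -= _skip, q_offset shifts by 1.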
def _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
):
q_len, q_offset = _get_seqlens_and_offsets(segment_ids_q, max_segments_per_seq)
kv_len, kv_offset = _get_seqlens_and_offsets(segment_ids_kv, max_segments_per_seq)
return _fast_causal_adjust_seqlen_and_offsets(
segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
)
def _segment_ids_pos_to_seqlens_offsets(
segment_ids_q,
segment_ids_kv,
......@@ -387,6 +425,25 @@ def _segment_ids_pos_to_seqlens_offsets(
window_size,
max_segments_per_seq,
):
# TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here.
# Computing the full mask is expensive due to quadratic expansion of Q * KV masking.
# Assumptions for cudnn causal mask correctness:
# 1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0]
# 2. No intra-segment padding; only inter-segment padding is allowed
# 3. Only the start or end token within a segment may violate the causal order relationship
# 1 5 9 0 4 8 10 0 4 8
# 0 x x
# 4 x x x x x
# 8 x x x x x x x x
#
# This fast path avoids expanding the mask to a full Q x KV matrix and instead allows us to
# examine only O(Q+KV) elements.
if attn_mask_type.is_causal() and (window_size is None or window_size == (-1, -1)):
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
)
# (1 = attend, 0 = masked)
segment_mask = make_attention_mask(
segment_ids_q,
......
......@@ -7,4 +7,4 @@ from .attention import *
from .normalization import *
from .quantization import *
from .softmax import *
from .transpose import *
from .gemm import *
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Tuple, Sequence, Union, Callable
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from packaging import version
......@@ -10,31 +10,38 @@ from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.sharding import PartitionSpec
import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
check_valid_batch_dims,
jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype,
get_padded_spec,
is_ffi_enabled,
check_valid_batch_dims,
multidim_transpose,
try_apply_delayed_scaling_2x_war,
should_apply_1x_fused_dbias_war_for_arch_l_100,
NamedSharding,
)
from .quantization import _jax_quantize_dbias, _jax_dbias, quantize_dbias
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
Quantizer,
QuantizeAxis,
DelayedScaleQuantizer,
ScalingMode,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["act_lu", "dact_lu", "act_lu_fp8"]
__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]
ActivationEnum = {
......@@ -66,448 +73,1053 @@ def _convert_to_activation_function(fn_or_string):
raise ValueError(f"Unsupported {fn_or_string} to an activation function")
def _jax_act_lu(inputs, activation_type):
"""
JAX native activation implementation
"""
x = jnp.split(inputs, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = reduce(operator.mul, acts)
x = jnp.squeeze(x, axis=-2)
return x
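# A worked example (assumed (N, 2, H) gated layout) for _jax_act_lu above:
# split along axis -2, apply each activation, multiply, then squeeze.
import jax.numpy as jnp

_x = jnp.ones((8, 2, 64), dtype=jnp.float32)
_y = _jax_act_lu(_x, activation_type=("gelu", "linear"))
assert _y.shape == (8, 64)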
class ActLuPrimitive(BasePrimitive):
"""
Activation Forward Primitive
ActLu Primitive
"""
name = "te_act_lu"
multiple_results = False
name = "te_act_lu_ffi"
multiple_results = True
impl_static_args = (
2,
3,
4,
5,
6,
7,
8,
9,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, scale_shapes, is_outer
inner_primitive = None
outer_primitive = None
impl_static_args = (1,)
@staticmethod
def abstract(x_aval, *, act_enum): # pylint: disable=unused-argument
def abstract(
x_aval,
scale_aval,
*,
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
act_lu abstract
te_act_lu_p abstract
"""
del act_enum, act_len, scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
out_shape = (
*x_aval.shape[:-2],
1,
x_aval.shape[-1],
)
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(out_shape[:-2] + (out_shape[-1],), is_padded=not is_outer)
x_shape = x_aval.shape
assert x_shape[-2] == 2 or x_shape[-2] == 1
hidden_size = x_shape[-1]
batch_shapes = x_shape[:-2]
out_aval = x_aval
out_shape = (batch_shapes) + (hidden_size,)
out_aval = out_aval.update(shape=out_shape, dtype=dtype)
if len(rowwise_scale_inv_shape) > 1:
rowwise_scale_inv_shape = (
rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
)
if len(colwise_scale_inv_shape) > 1:
colwise_scale_inv_shape = (
colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
if is_2x:
colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
return out_aval
return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval
@staticmethod
def lowering(ctx, x, *, act_enum):
def lowering(
ctx,
x,
scale,
*,
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
act_lu lowering rules
te_gated_act_lu_p lowering rules
"""
(x_aval,) = ctx.avals_in
del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
if is_ffi_enabled():
name = "te_act_lu_ffi"
out = ffi.ffi_lowering(name)(ctx, x, act_enum=act_enum)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]
out_types = [
ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
]
operands = [x]
operand_shapes = [ir_x_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
assert scale_aval is None or scale_aval.dtype == jnp.float32
hidden_size = ir_x_shape[-1]
batch_size = reduce(operator.mul, ir_x_shape[:-2])
in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor(
(batch_size, hidden_size), in_dtype, in_dtype, act_enum
out = ffi.ffi_lowering(ActLuPrimitive.name)(
ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x
)
out = custom_caller(ActLuPrimitive.name, args, opaque, False)
return out
@staticmethod
def impl(x, act_enum):
def impl(
x,
scale,
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
to describe implementation
"""
del is_outer
assert ActLuPrimitive.inner_primitive is not None
out = ActLuPrimitive.inner_primitive.bind(x, act_enum=act_enum)
return out
out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
ActLuPrimitive.inner_primitive.bind(
x,
scale,
out_dtype=out_dtype,
act_enum=act_enum,
act_len=act_len,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=False,
)
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(out.shape[:-2] + (out.shape[-1],), is_padded=False)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
rowwise_scale_inv_shape = (
rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
)
if is_2x:
colwise_scale_inv_shape = (
colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if is_2x:
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, act_enum):
def batcher(
batched_args,
batch_dims,
*,
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
act_lu batcher
to describe batch rules for vmap
"""
del act_len, is_outer
check_valid_batch_dims(batch_dims)
assert ActLuPrimitive.outer_primitive is not None
(inputs,) = batched_args
(inputs_bdim,) = batch_dims
x, scale = batched_args
x_bdim, scale_bdim = batch_dims
amax_bdim = scale_bdim
out_bdims = inputs_bdim
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_enum), out_bdims
out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
return (
ActLuPrimitive.outer_primitive.bind(
x,
scale,
out_dtype=out_dtype,
act_enum=act_enum,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
"""
act_lu infer_sharding_from_operands
"""
del result_infos, act_enum # Unused.
def infer_sharding_from_operands(
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
result_infos,
):
del (
out_dtype,
result_infos,
act_enum,
scale_dtype,
scale_shapes,
act_len,
is_outer,
) # Unused.
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
return out_sharding
out_spec = (*x_spec[:-2], None, x_spec[-2])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec)
else:
colwise_out_spec = out_spec
else:
colwise_out_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
)
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
)
return (
out_sharding,
colwise_out_sharding,
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
)
@staticmethod
def partition(act_enum, mesh, arg_infos, result_infos):
"""
act_lu partitioning
"""
del result_infos
def partition(
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
result_infos,
):
del result_infos, is_outer # Unused.
x_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
out_spec = (*x_spec[:-1], x_spec[-1])
if act_len == 2 and x_spec[-1] is None:
# Ensure last axis is partitioned and not the gating axis
out_spec = (*x_spec[:-2], None, x_spec[-2])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec)
else:
colwise_out_spec = out_spec
else:
colwise_out_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
)
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")
def sharded_impl(x):
return ActLuPrimitive.impl(x, act_enum=act_enum)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
)
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec))
arg_shardings = tuple(arg_shardings)
out_shardings = (
out_sharding,
colwise_out_sharding,
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
)
return mesh, sharded_impl, out_sharding, arg_shardings
def sharded_impl(x, scale):
local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
ActLuPrimitive.impl(
x,
scale,
out_dtype=out_dtype,
act_enum=act_enum,
act_len=act_len,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=True,
)
)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
register_primitive(ActLuPrimitive)
return (
local_x,
local_colwise_x,
local_scale_inv,
local_colwise_scale_inv,
global_updated_amax,
)
return mesh, sharded_impl, out_shardings, arg_shardings
def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
"""
act_lu wrapper
Return act_lu(inputs)
Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
"""
if not ActLuPrimitive.enabled():
return _jax_act_lu(inputs, activation_type)
act_type_id = ActivationEnum[activation_type].value
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)
register_primitive(ActLuPrimitive)
class DActLuPrimitive(BasePrimitive):
class DActLuDBiasQuantizePrimitive(BasePrimitive):
"""
Dgated ActLu Primitive
DActLu DBias Quantize Primitive
"""
name = "te_dact_lu"
multiple_results = False
name = "te_dact_dbias_quantize_ffi"
multiple_results = True
# out_dtype, scaling_mode, is_2x, scale_dtype, scale_shapes, is_dbias, act_enum, act_len, is_outer
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
inner_primitive = None
outer_primitive = None
impl_static_args = (2,)
@staticmethod
def abstract(dz_aval, x_aval, *, act_enum): # pylint: disable=unused-argument
def abstract(
dz_aval,
x_aval,
scale_aval,
*,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
):
"""
dact_lu abstract
te_dact_dbias_quantize_p abstract
"""
del act_enum, scale_shapes
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
for axis in range(len(dz_aval.shape) - 1):
assert dz_aval.shape[axis] == x_aval.shape[axis]
assert x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1
assert scale_aval.dtype == jnp.float32
ir_hidden_size = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
assert act_len * ir_hidden_size == gi_hidden_size
out_shape = x_aval.shape
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
i_hidden_size = dz_aval.shape[-1]
g_hidden_size = x_aval.shape[-1]
assert i_hidden_size == g_hidden_size
out_aval = x_aval
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
return out_aval
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
@staticmethod
def lowering(ctx, dz, x, *, act_enum):
"""
dact_lu lowering rules
"""
in_aval, gi_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gi_aval.dtype == in_aval.dtype
if is_ffi_enabled():
name = "te_dact_lu_ffi"
out = ffi.ffi_lowering(name)(ctx, dz, x, act_enum=act_enum)
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
if is_2x:
# Don't transpose output for MXFP8
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
t_shape = out_shape
else:
ir_in_type = ir.RankedTensorType(dz.type)
ir_in_shape = ir_in_type.shape
gi_type = ir.RankedTensorType(x.type)
gi_shape = gi_type.shape
# assert ir_in_shape == gi_shape
for axis in range(len(ir_in_shape) - 1):
assert ir_in_shape[axis] == gi_shape[axis]
ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
i_hidden_size = ir_in_shape[-1]
g_hidden_size = gi_shape[-1]
assert i_hidden_size == g_hidden_size
out_dtype = ir_in_type.element_type
out_shape = gi_shape
out_types = [
ir.RankedTensorType.get(out_shape, out_dtype),
]
operands = [dz, x]
operand_shapes = [ir_in_shape, gi_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor(
(ir_batch_size, i_hidden_size), in_dtype, in_dtype, act_enum
)
out = custom_caller(DActLuPrimitive.name, args, opaque, False)
t_shape = multidim_transpose(out_shape)
colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
return out
if is_dbias:
dbias_shape = gi_hidden_size
dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
(wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
x_aval.size // gi_hidden_size,
gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
scaling_mode,
is_2x,
)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
dbias_aval,
wkspace_aval,
)
@staticmethod
def impl(dz, x, act_enum):
def outer_abstract(*args, **kwargs):
"""
dact_lu implementation
te_dact_dbias_quantize_p outer abstract
"""
assert DActLuPrimitive.inner_primitive is not None
dx = DActLuPrimitive.inner_primitive.bind(dz, x, act_enum=act_enum)
return dx
(out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
)
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
@staticmethod
def batcher(batched_args, batch_dims, *, act_enum):
def lowering(
ctx,
dz,
x,
scale,
*,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
):
"""
dact_lu batcher
te_dact_dbias_quantize_p lowering rules
"""
check_valid_batch_dims(batch_dims)
assert DActLuPrimitive.outer_primitive is not None
dz, x = batched_args
_, x_bdim = batch_dims
out_bdims = x_bdim
return DActLuPrimitive.outer_primitive.bind(dz, x, act_enum=act_enum), out_bdims
del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
dz_aval, x_aval, scale_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert scale_aval.dtype == jnp.float32
return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)(
ctx,
dz,
x,
scale,
scaling_mode=scaling_mode,
is_2x=is_2x,
is_dbias=is_dbias,
act_enum=int(act_enum),
)
@staticmethod
def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
def impl(
dz,
x,
scale,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
):
"""
dact_lu infer_sharding_from_operands
te_dact_dbias_quantize_p impl
"""
del result_infos, act_enum # Unused.
act_lu_out_spec = get_padded_spec(arg_infos[1])
dx_sharding = NamedSharding(mesh, PartitionSpec(*act_lu_out_spec))
return dx_sharding
del is_outer
assert DActLuDBiasQuantizePrimitive.inner_primitive is not None
(out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
DActLuDBiasQuantizePrimitive.inner_primitive.bind(
dz,
x,
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
is_outer=False,
)
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if is_2x:
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return (
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
) # Exclude wkspace
@staticmethod
def partition(act_enum, mesh, arg_infos, result_infos):
def batcher(
batched_args,
batch_dims,
*,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
):
"""
dact_lu partition
to describe batch rules for vmap
"""
del result_infos
dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = dx_sharding
del is_outer
check_valid_batch_dims(batch_dims)
assert DActLuDBiasQuantizePrimitive.outer_primitive is not None
dz, x, scale = batched_args
_, x_bdim, scale_bdim = batch_dims
out_bdims = (
x_bdim, # rowwise output
scale_bdim, # rowwise scale_inv
x_bdim, # colwise output
scale_bdim, # colwise scale_inv
scale_bdim, # amax
x_bdim, # dbias
)
return (
DActLuDBiasQuantizePrimitive.outer_primitive.bind(
dz,
x,
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
),
out_bdims,
)
def sharded_impl(dz, x):
return DActLuPrimitive.impl(dz, x, act_enum=act_enum)
@staticmethod
def infer_sharding_from_operands(
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
mesh,
arg_infos,
result_infos,
):
del out_dtype, result_infos, act_enum
del scale_dtype, scale_shapes, is_dbias, act_len, is_outer
x_spec = get_padded_spec(arg_infos[1])
return mesh, sharded_impl, out_shardings, arg_shardings
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
)
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec)
else:
colwise_x_spec = x_spec
else:
colwise_x_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
desc="DActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
)
return (
out_sharding,
colwise_out_sharding,
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
dbias_sharding,
)
register_primitive(DActLuPrimitive)
@staticmethod
def partition(
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
mesh,
arg_infos,
result_infos,
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec), desc="out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec)
else:
colwise_x_spec = x_spec
else:
colwise_x_spec = (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
desc="DActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
)
def dact_lu(
inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]
) -> jnp.ndarray:
"""
dact_lu fusion wrapper
Return dgated_act_lu(inputs)
"""
if not DActLuPrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs)
return vjp_func(inputs)[0]
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = (
arg_shardings[1],
arg_shardings[1],
*arg_shardings[2:],
) # dz and x are the same
out_shardings = (
out_sharding,
colwise_out_sharding,
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
dbias_sharding,
)
act_type_id = ActivationEnum[activation_type].value
return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)
def sharded_impl(dz, x, scale):
(out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
DActLuDBiasQuantizePrimitive.impl(
dz,
x,
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
is_outer=True,
)
)
if is_dbias:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
else:
global_dbias = local_dbias
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
class ActLuFp8Primitive(BasePrimitive):
"""
ActLu FP8 Primitive
"""
return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias
name = "te_act_lu_fp8"
multiple_results = True
impl_static_args = (4, 5) # out_dtype, act_enum
inner_primitive = None
outer_primitive = None
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def abstract(
x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, act_enum
): # pylint: disable=unused-argument
register_primitive(DActLuDBiasQuantizePrimitive)
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
"""
te_act_lu_p abstract
JAX native activation implementation
"""
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
# Currently the C side only supports casting to E4M3.
assert out_dtype == jnp.float8_e4m3fn
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
assert x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2
hidden_size = x_aval.shape[-1]
batch_shape = x_aval.shape[:-2]
out_shape = (batch_shape) + (hidden_size,)
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
x = jnp.split(inputs, len(activation_type), axis=-1)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = reduce(operator.mul, acts)
if quantizer:
return quantizer.quantize(x)
return x
return out_aval, updated_amax_aval
@staticmethod
def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum):
def _jax_quantize_dact_dbias(
dz: jnp.ndarray,
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
is_dbias: bool = True,
quantizer: Optional[Quantizer] = None,
):
"""
te_gated_act_lu_p lowering rules
JAX implementation of dact_lu and dbias with optional quantization
"""
x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_act_lu_fp8_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
ctx, x, amax, scale, scale_inv, act_enum=act_enum
_, vjp_func = jax.vjp(
partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
)
(dx,) = vjp_func(dz.astype(jnp.float32))
dbias = None
if is_dbias:
dbias = _jax_dbias(dx).astype(x.dtype)
if quantizer is not None:
dx = quantizer.quantize(dx, dq_dtype=x.dtype)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
hidden_size = ir_x_shape[-1]
batch_shape = ir_x_shape[:-2]
batch_size = reduce(operator.mul, batch_shape)
out_shape = batch_shape + [hidden_size]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor(
(batch_size, hidden_size),
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum,
)
dx = dx.astype(x.dtype)
out = custom_caller(
ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
)
return dx, dbias
return out
@staticmethod
def impl(x, amax, scale, scale_inv, out_dtype, act_enum):
"""
to describe implementation
def act_lu(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization.
Args:
x: Input tensor to be processed.
activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output.
Returns:
If quantizer is None:
The activated input tensor with the same dtype as input.
If quantizer is provided:
A ScaledTensor containing the quantized activated input.
"""
assert ActLuFp8Primitive.inner_primitive is not None
out, updated_amax = ActLuFp8Primitive.inner_primitive.bind(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
act_type_id = ActivationEnum[activation_type].value
if not ActLuPrimitive.enabled():
return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support 2x quantization for DelayedScaling yet
war_output = try_apply_delayed_scaling_2x_war(
f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
)
if war_output is not None:
return war_output
scale = jnp.empty((1,), jnp.float32)
output_shape = (*x.shape[:-1], x.shape[-1] // len(activation_type))
if quantizer is None:
x = x.reshape((-1, len(activation_type), x.shape[-1] // len(activation_type)))
out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
x,
scale,
out_dtype=x.dtype,
act_enum=act_type_id,
act_len=len(activation_type),
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((), ()),
is_outer=True,
)
return out, updated_amax
out = out.reshape(output_shape)
return out
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, act_enum):
"""
to describe batch rules for vmap
"""
check_valid_batch_dims(batch_dims)
assert ActLuFp8Primitive.outer_primitive is not None
x, amax, scale, scale_inv = batched_args
x_bdim, amax_bdim, _, _ = batch_dims
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
x = x.reshape((*x.shape[:-1], len(activation_type), x.shape[-1] // len(activation_type)))
(
rowwise_casted_output,
colwise_casted_output,
rowwise_scale_inv,
colwise_scale_inv,
updated_amax,
) = ActLuPrimitive.outer_primitive.bind(
x,
scale,
out_dtype=quantizer.q_dtype,
act_enum=act_type_id,
act_len=len(activation_type),
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(output_shape),
is_outer=True,
)
out_bdims = x_bdim, amax_bdim
return (
ActLuFp8Primitive.outer_primitive.bind(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
),
out_bdims,
rowwise_casted_output = rowwise_casted_output.reshape(output_shape)
if len(rowwise_scale_inv.shape) > 1:
rowwise_scale_inv = jnp.squeeze(rowwise_scale_inv, axis=-2) # Remove act axis
if quantizer.q_axis in (QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE):
colwise_output_shape = output_shape
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
colwise_output_shape = multidim_transpose(output_shape)
colwise_casted_output = colwise_casted_output.reshape(colwise_output_shape)
if len(colwise_scale_inv.shape) > 1:
colwise_scale_inv = jnp.squeeze(colwise_scale_inv, axis=-2) # Remove act axis
quantizer.update(updated_amax)
return ScaledTensorFactory.create(
data=rowwise_casted_output,
scale_inv=rowwise_scale_inv,
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
)
@staticmethod
def infer_sharding_from_operands(out_dtype, act_enum, mesh, arg_infos, result_infos):
del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
return (out_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, act_enum, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (out_sharding, amax_sharding)
def quantize_dact_dbias(
dz: jnp.ndarray,
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
is_dbias: bool = True,
quantizer: Optional[Quantizer] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]:
"""Compute gradients of activation and bias with optional quantization.
Args:
dz: Gradient of the output with respect to the activation output.
x: Input tensor that was processed by the forward pass.
Shape: (..., ACT_DIM * K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output.
Returns:
Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
- The gradient of the activation with respect to the input.
- The gradient of the activation with respect to the bias.
"""
def sharded_impl(x, amax, scale, scale_inv):
local_x, local_amax = ActLuFp8Primitive.impl(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
if not DActLuDBiasQuantizePrimitive.enabled():
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
return local_x, global_updated_amax
# TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
return mesh, sharded_impl, out_shardings, arg_shardings
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = quantize_dact_dbias(
dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=None
)
return quantize_dbias(out, is_dbias=True, quantizer=quantizer)
is_gated = len(activation_type) == 2
# TE/common does not support DelayedScaling2x for gated-act yet
if is_gated:
war_output = try_apply_delayed_scaling_2x_war(
f=quantize_dact_dbias,
dz=dz,
x=x,
activation_type=activation_type,
is_dbias=is_dbias,
quantizer=quantizer,
)
if war_output is not None:
return war_output
scale = jnp.empty((), jnp.float32)
act_type_id = ActivationEnum[activation_type]
if quantizer is None:
output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
dz,
x,
scale,
# outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling; TE/common ignores this value when scale is unset
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
scale_shapes=((), ()), # unused
is_dbias=False,
act_enum=act_type_id,
act_len=len(activation_type),
is_outer=True,
)
dbias = None
if is_dbias:
dbias = _jax_dbias(output).astype(x.dtype)
return output.astype(x.dtype), dbias
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
# TE/common dact_dbias_quantize does not support gated act yet
if is_dbias and is_gated:
dgated = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
)
# TODO(Jeremy): Debug - TE's quantize_dbias produced NaNs in this case for distributed layernorm_mlp tests
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
out, dbias = _jax_quantize_dbias(dgated, quantizer=quantizer, dq_dtype=x.dtype)
else:
out, dbias = quantize_dbias(
dgated,
quantizer=quantizer,
is_dbias=True,
dq_dtype=x.dtype,
)
return out, dbias
out_shape = x.shape
(
rowwise_casted_output,
colwise_casted_output,
rowwise_scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
) = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
dz,
x,
scale,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(out_shape),
is_dbias=is_dbias,
act_enum=act_type_id,
act_len=len(activation_type),
is_outer=True,
)
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
out = ScaledTensorFactory.create(
data=rowwise_casted_output,
scale_inv=rowwise_scale_inv,
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
)
register_primitive(ActLuFp8Primitive)
return out, dbias
def act_lu_fp8(
def dact_lu(
dz: jnp.ndarray,
x: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: jnp.dtype,
activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
act wrapper
Return FP8(act_lu(x))
Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
"""
if not ActLuFp8Primitive.enabled():
act_lu_output = _jax_act_lu(x, activation_type)
casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype)
return casted_output, updated_amax
Backward pass for activation with optional quantization.
act_type_id = ActivationEnum[activation_type].value
return ActLuFp8Primitive.outer_primitive.bind(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id
Args:
dz: Gradient tensor from upstream.
x: Input tensor that was used in forward pass.
activation_type: Type of activation function that was applied.
quantizer: Optional quantizer for FP8 quantization of the output gradient.
Returns:
The gradient of the activation with respect to the input.
"""
output, _ = quantize_dact_dbias(
dz=dz,
x=x,
activation_type=activation_type,
is_dbias=False,
quantizer=quantizer,
)
return output
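# A sketch (hypothetical shapes) for dact_lu above, matching the
# (..., ACT_DIM * K) layout documented in quantize_dact_dbias (ACT_DIM = 2
# for gated activations); quantizer=None yields a plain high-precision grad.
import jax.numpy as jnp

_x = jnp.ones((8, 2 * 256), dtype=jnp.bfloat16)   # forward input, gated
_dz = jnp.ones((8, 256), dtype=jnp.bfloat16)      # upstream gradient
_dx = dact_lu(_dz, _x, activation_type=("gelu", "linear"))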
......@@ -13,8 +13,6 @@ from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes, lax
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine_jax
......@@ -29,14 +27,12 @@ from transformer_engine.jax.attention import (
)
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
check_valid_batch_dims,
jax_dtype_to_te_dtype,
te_dtype_to_jax_dtype,
get_padded_spec,
get_cudnn_version,
is_ffi_enabled,
)
from ..sharding import (
global_mesh_resource,
......@@ -227,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
Fused Attention Forward Primitive
"""
name = "te_fused_attn_forward"
name = "te_fused_attn_forward_ffi"
multiple_results = True
impl_static_args = (13,)
inner_primitive = None
......@@ -400,9 +396,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
if is_ffi_enabled():
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
ctx,
q,
k,
......@@ -436,54 +430,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(
......@@ -681,7 +627,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
Fused Attention Backward Primitive
"""
name = "te_fused_attn_backward"
name = "te_fused_attn_backward_ffi"
multiple_results = True
impl_static_args = (16,)
inner_primitive = None
......@@ -813,9 +759,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
if is_ffi_enabled():
name = "te_fused_attn_backward_ffi"
out = ffi.ffi_lowering(name)(
return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
ctx,
q,
k,
......@@ -852,57 +796,6 @@ class FusedAttnBwdPrimitive(BasePrimitive):
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod
def impl(
......
......@@ -6,6 +6,7 @@ import os
import re
from abc import ABCMeta, abstractmethod
from functools import partial
from packaging import version
from jax.extend import core
from jax.interpreters import xla, mlir
......@@ -13,6 +14,14 @@ from jax.experimental.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
from jax._src import dispatch
import jax
import transformer_engine_jax
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
class BasePrimitive(metaclass=ABCMeta):
"""
......@@ -120,3 +129,7 @@ def register_primitive(cls):
outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
)
cls.outer_primitive = outer_p
for _name, _value in transformer_engine_jax.registrations().items():
ffi.register_ffi_target(_name, _value, platform="CUDA")
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom call"""
from dataclasses import dataclass
from enum import IntEnum
from packaging import version
import jax
from jax.interpreters import mlir
import transformer_engine_jax
from .misc import is_ffi_enabled
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
try:
from jaxlib.hlo_helpers import custom_call
except ImportError:
# Newer JAX changed its API, but we want to support a few JAX
# versions, so we still need this import.
class CustomCallAPIVersion(IntEnum):
"""Enum for selecting between old and new custom call registration API"""
OPAQUE = 0
FFI = 1
for _name, _value in transformer_engine_jax.registrations().items():
if _name.endswith("_ffi"):
if is_ffi_enabled():
ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
)
else:
ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
)
@dataclass
class CustomCallArgsWrapper:
"""
wrapper of XLA custom call args
"""
def __init__(
self,
output_types,
operands,
operand_shapes,
operand_specific_layouts=None,
output_specific_layouts=None,
):
self.output_types = output_types
self.operands = operands
self.operand_layouts = CustomCallArgsWrapper.generate_layouts(
operand_shapes, operand_specific_layouts
)
output_shapes = [x.shape for x in output_types]
self.output_layouts = CustomCallArgsWrapper.generate_layouts(
output_shapes, output_specific_layouts
)
@staticmethod
def generate_layouts(shapes, specific_layouts):
"""
set up layouts for the XLA custom call
"""
def default_layout(shape):
return range(len(shape) - 1, -1, -1)
if specific_layouts is None:
specific_layouts = {}
layouts = []
for idx, shape in enumerate(shapes):
if idx in specific_layouts:
layouts.append(specific_layouts[idx])
else:
layouts.append(default_layout(shape))
return layouts
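# XLA layouts are minor-to-major dimension orders, so the default above is
# plain row-major; a quick check for a rank-3 shape:
_shape = (4, 8, 16)
assert list(range(len(_shape) - 1, -1, -1)) == [2, 1, 0]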
def custom_caller(name, args, opaque, has_side_effect, **kwargs):
"""
XLA custom call wrapper
"""
if hasattr(mlir, "custom_call"):
out = mlir.custom_call(
name,
result_types=args.output_types,
operands=args.operands,
operand_layouts=args.operand_layouts,
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs,
).results
else:
# Need to disable one pylint error, as the second function
# parameter name changed recently in JAX. Otherwise we won't be
# compatible with multiple JAX versions.
out = custom_call( # pylint: disable=too-many-function-args
name,
args.output_types,
operands=args.operands,
operand_layouts=args.operand_layouts,
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs,
)
return out
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""
from typing import Tuple, Sequence, Union, Dict, List
from functools import partial, reduce
import operator
from transformer_engine_jax import get_device_compute_capability
import jax
import jax.numpy as jnp
from .base import BasePrimitive, register_primitive
from ..quantize import (
ScaledTensor,
ScalingMode,
Quantizer,
QuantizeConfig,
noop_quantizer_set,
)
__all__ = ["gemm", "grouped_gemm"]
num_cublas_streams = 4
def get_cublas_workspace_size_bytes() -> int:
"""Return 32 MiB if using Hopper, 4 MiB for all other architectures."""
if get_device_compute_capability(0) >= 90:
return 33_554_432
return 4_194_304
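# Arithmetic check for the constants above: with the four cuBLAS streams
# declared earlier, the total grouped-GEMM workspace comes to 128 MiB on
# Hopper and 16 MiB on other architectures.
assert 33_554_432 == 32 * 1024 * 1024        # 32 MiB per stream (Hopper)
assert 4 * 33_554_432 == 128 * 1024 * 1024   # 4 streams -> 128 MiB
assert 4 * 4_194_304 == 16 * 1024 * 1024     # 4 streams -> 16 MiB elsewhere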
class GroupedGemmPrimitive(BasePrimitive):
"""
Primitive for grouped GEMM
"""
name = "te_grouped_gemm_ffi"
multiple_results = True
impl_static_args = (6, 7, 8, 9)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
lhs_contig_aval,
lhs_scale_contig_aval,
rhs_contig_aval,
rhs_scale_contig_aval,
bias_contig_aval,
dim_list_aval,
*,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
):
del lhs_contig_aval, lhs_scale_contig_aval
del rhs_contig_aval, rhs_scale_contig_aval
del bias_contig_aval, dim_list_aval
del num_gemms, scaling_mode
out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype)
wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8)
return (out_flat_aval, wkspace_aval)
@staticmethod
def outer_abstract(*args, **kwargs):
(out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs)
return out_aval
@staticmethod
def lowering(
ctx,
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
*,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
    ):
del out_dtype, out_flat_size
return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
ctx,
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms=num_gemms,
scaling_mode=int(scaling_mode),
)
@staticmethod
def impl(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
) -> jnp.ndarray:
assert GroupedGemmPrimitive.inner_primitive is not None
out = GroupedGemmPrimitive.inner_primitive.bind(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms=num_gemms,
scaling_mode=scaling_mode.value,
out_dtype=out_dtype,
out_flat_size=out_flat_size,
)
return out[0] # out is [out_flat, wkspace], only return out_flat
register_primitive(GroupedGemmPrimitive)
def _shape_normalization(x, dimension_numbers, already_transposed: bool = False):
orig_order = list(range(x.ndim))
contracting_dims, batch_dims = dimension_numbers
contracting_order = [d for d in orig_order if d in contracting_dims]
batch_order = [d for d in orig_order if d in batch_dims]
non_contracting_order = [
d for d in orig_order if d not in contracting_dims and d not in batch_dims
]
batch_shape = [x.shape[d] for d in batch_order]
rows_shape = [x.shape[d] for d in non_contracting_order]
cols_shape = [x.shape[d] for d in contracting_order]
new_order = batch_order + non_contracting_order + contracting_order
rows, cols, batches = (
reduce(operator.mul, rows_shape, 1),
reduce(operator.mul, cols_shape, 1),
reduce(operator.mul, batch_shape, 1),
)
# Remove this transpose when non-TN dot is supported
if not already_transposed:
t = jnp.transpose(x, new_order)
else:
t = x
return jnp.reshape(t, (batches, rows, cols))
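# Editorial example, not part of the original change: collapsing a 4D operand
# into the (batches, rows, cols) form consumed by the GEMM lowerings below.
def _shape_normalization_example():
    """Minimal sketch of _shape_normalization on an assumed 4D input."""
    x = jnp.zeros((2, 3, 4, 5))
    # Contract over the last dim, treat dim 0 as the batch dim:
    y = _shape_normalization(x, ((3,), (0,)))
    assert y.shape == (2, 12, 5)  # batches=2, rows=3*4, cols=5
    return y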
def _calculate_remaining_shape(shape, contracting_dims):
return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)
def _dequantize(x, scale_inv, dq_dtype):
return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
# Apply jit to guarantee correctness of FP8 GEMM.
@partial(
jax.jit,
static_argnums=(
2,
3,
4,
),
)
def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)
# Reshape + Transpose
# [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T")
# _shape_normalization ensures contracting_dims=2 and batch_dims=0
dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
)
return out_3d
def _jax_gemm_delayed_scaling_fp8(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""FP8 GEMM for XLA pattern match"""
assert (
rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.layout == "T":
lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
if rhs.layout == "T":
rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
lhs_dn = (lhs_contract, lhs_batch)
rhs_dn = (rhs_contract, rhs_batch)
lhs_remain_shape = _calculate_remaining_shape(lhs.data.shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs.data.shape, rhs_contract)
precision = (
jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
)
out_3d = __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
# Reshape [B, M, N] -> [..., M, N]
out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
return out
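# Editorial note, not part of the original change: the explicit
# dequantize -> dot_general sequence above keeps the Q->DQ->dot pattern
# visible to XLA, which can then rewrite it into a single FP8 GEMM; routing
# through a helper such as lhs.dequantize() could hide the pattern from the
# matcher.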
def _jax_gemm_mxfp8_1d(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""
JAX GEMM for MXFP8 via scaled_matmul
"""
assert (
rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING
), "rhs does not have MXFP8 1D scaling mode"
from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
expected_lhs_is_colwise = lhs_contract[-1] != lhs.data.ndim - 1
expected_rhs_is_colwise = rhs_contract[-1] != rhs.data.ndim - 1
assert lhs.is_colwise is expected_lhs_is_colwise, (
f"LHS with unexpected quantize dimension.\nExpect is_colwise={expected_lhs_is_colwise}, got"
f" {lhs.is_colwise}"
)
assert rhs.is_colwise is expected_rhs_is_colwise, (
f"RHS with unexpected quantize dimension.\nExpect is_colwise={expected_rhs_is_colwise}, got"
f" {rhs.is_colwise}"
)
# Reshape + Transpose (if needed)
# [..., M, K] -> [1, reduce(..., M), K]
# [..., K, M] -> [1, reduce(..., M), K]
lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch))
rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch))
lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch))
rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch))
# Slice out the padding as scaled_matmul does not support padded scales yet
    lhs_scale_3d = jnp.asarray(lhs_scale_3d[:, : lhs_3d.shape[1], : lhs_3d.shape[2] // 32])
    rhs_scale_3d = jnp.asarray(rhs_scale_3d[:, : rhs_3d.shape[1], : rhs_3d.shape[2] // 32])
# JAX scaled_matmul only supports NT now (TN-gemm)
# * Expected shape:
# * lhs_data (B, M, K) * rhs_data (B, N, K)
# * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block)
out_3d = scaled_matmul_wrapper(
lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype
)
# Reshape [1, reduce(..., M), N] -> [..., M, N]
lhs_remain_shape = tuple(
lhs.data.shape[dim] for dim in range(len(lhs.data.shape)) if dim not in lhs_contract
)
rhs_remain_shape = tuple(
rhs.data.shape[dim] for dim in range(len(rhs.data.shape)) if dim not in rhs_contract
)
out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
return out
def _jax_gemm(
lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
) -> jnp.ndarray:
"""
FP8 GEMM via JAX
"""
dim_nums = (contracting_dims, ((), ()))
def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}")
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
return _jax_gemm_fp8_impl(lhs, rhs)
if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor):
if quantizer_set != noop_quantizer_set:
assert type(quantizer_set.x) is type(quantizer_set.kernel)
(((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums
lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
# Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm)
lhs_q = quantizer_set.x.quantize(
lhs,
is_rowwise=lhs_is_rowwise,
is_colwise=not lhs_is_rowwise,
)
rhs_q = quantizer_set.kernel.quantize(
rhs,
is_rowwise=rhs_is_rowwise,
is_colwise=not rhs_is_rowwise,
)
return _jax_gemm_fp8_impl(lhs_q, rhs_q)
if (
isinstance(lhs, jnp.ndarray)
and isinstance(rhs, jnp.ndarray)
and quantizer_set == noop_quantizer_set
):
return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype)
raise NotImplementedError("Not supporting multiplication of ScaledTensor and jnp.array")
def gemm(
lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
) -> jnp.ndarray:
"""General matrix multiplication with optional quantization.
Args:
lhs: First input matrix.
rhs: Second input matrix.
contracting_dims: Tuple of two sequences representing the contracting dimensions.
The first sequence represents the contracting dimensions of the first matrix,
and the second sequence represents the contracting dimensions of the second matrix.
        quantizer_set: Set of quantizers used to quantize lhs and rhs for the FP8 GEMM.
            If noop_quantizer_set, no quantization is applied and the GEMM runs in the
            input dtype.
    Returns:
        The matrix multiplication result with shape (*lhs_remaining, *rhs_remaining),
        e.g. (M, N) for 2D inputs. Its dtype matches the inputs on the unquantized
        path, and the quantizers' dq_dtype on the quantized path.
    """
return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set)
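# Editorial usage sketch, not part of the original change; shapes and dtypes
# are illustrative assumptions.
def _gemm_usage_example():
    """Minimal sketch: without a quantizer_set, gemm lowers to a plain dot_general."""
    a = jnp.zeros((128, 256), dtype=jnp.bfloat16)
    b = jnp.zeros((256, 512), dtype=jnp.bfloat16)
    # Default contracting_dims=((1,), (0,)) contracts a's dim 1 with b's dim 0.
    out = gemm(a, b)
    assert out.shape == (128, 512)
    return out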
def swizzled_scale(scales):
"""Swizzle the scale tensor for FP8 GEMM"""
assert scales.ndim == 2
rows, cols = scales.shape
scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
scales = jnp.transpose(scales, (0, 3, 2, 1, 4))
return scales
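# Editorial shape walk-through, not part of the original change: for an
# assumed (256, 8) scale tensor, the reshape above gives (2, 4, 32, 2, 4) and
# the transpose yields (2, 2, 32, 4, 4), i.e. the 128x4-block swizzled layout
# cuBLAS expects for MXFP8 block scales.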
def grouped_gemm(
lhs_list: List[Union[jnp.ndarray, ScaledTensor]],
rhs_list: List[Union[jnp.ndarray, ScaledTensor]],
contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]],
    bias_list: Optional[List[jnp.ndarray]] = None,
) -> List[jnp.ndarray]:
"""Grouped GEMM for multiple pairs of tensors."""
assert (
len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
), "lhs_list, rhs_list, contracting_dims_list must have the same length"
# Flatten inputs and save their shapes
num_gemms = len(lhs_list)
out_flat_size = 0
dims = []
lhs_contig_ = []
rhs_contig_ = []
lhs_scale_inv_contig_ = []
rhs_scale_inv_contig_ = []
bias_contig_ = []
out_offsets = []
remain_shape_list = []
for i in range(num_gemms):
lhs = lhs_list[i]
rhs = rhs_list[i]
contracting_dims = contracting_dims_list[i]
dim_nums = (contracting_dims, ((), ()))
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
scaling_mode = lhs.scaling_mode
lhs_shape = lhs.data.shape
rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2"
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
if lhs.layout == "T":
lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
if rhs.layout == "T":
rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else:
# For jnp.ndarray, only consider contracting_dims, layout is always NN
scaling_mode = ScalingMode.NVTE_NO_SCALING
lhs_shape = lhs.shape
rhs_shape = rhs.shape
out_dtype = lhs.dtype
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
lhs_dn = (lhs_contract, lhs_batch)
rhs_dn = (rhs_contract, rhs_batch)
lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
if scaling_mode == ScalingMode.NVTE_NO_SCALING:
lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T")
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn)
lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
else:
raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}")
# Note: if _shape_normalization() is updated to support non-TN, need to update here
# already_transposed doesn't matter for the output shape
# x.shape = [B, D1, D2]
# contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
# contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
# x.shape = [D1, D2]
# contracting_dims = (1, ) --> output.shape = [1, D1, D2]
# contracting_dims = (0, ) --> output.shape = [1, D2, D1]
bm = lhs_remain_shape[0]
bn = rhs_remain_shape[0]
kl = lhs_3d.shape[-1]
kr = rhs_3d.shape[-1]
remain_shape_list.append(((bm,), (bn,)))
assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})"
k = kl
        assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0, (
            f"grouped_gemm input pair {i} has an invalid problem shape for lowering: "
            f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires each problem dimension to be "
            "a multiple of 16"
        )
dims.append((bm, bn, k))
lhs_contig_.append(lhs_3d.reshape(-1))
rhs_contig_.append(rhs_3d.reshape(-1))
if scaling_mode == ScalingMode.NVTE_NO_SCALING:
lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1))
if bias_list is not None:
bias_contig_.append(bias_list[i].reshape(-1))
out_flat_size += bm * bn
out_offsets.append(out_flat_size)
lhs_contig = jnp.concatenate(lhs_contig_)
rhs_contig = jnp.concatenate(rhs_contig_)
lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_)
rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_)
bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_)
dim_list = jnp.array(dims, dtype=jnp.int32)
# Perform batched GEMM on flattened inputs
out_contig = GroupedGemmPrimitive.outer_primitive.bind(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms=num_gemms,
scaling_mode=scaling_mode,
out_dtype=out_dtype,
out_flat_size=out_flat_size,
)
# Split the output back into tensors
out_offsets = jnp.array(out_offsets)
out_flat_list = jnp.split(out_contig, out_offsets[:-1])
out_tensors = []
for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list):
out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape))
return out_tensors
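# Editorial usage sketch, not part of the original change; the problem sizes
# are illustrative and chosen as multiples of 16 to satisfy the cuBLAS
# constraint checked above.
def _grouped_gemm_usage_example():
    """Minimal sketch: two independent GEMMs dispatched in one grouped call."""
    lhs_list = [jnp.zeros((64, 32), jnp.bfloat16), jnp.zeros((128, 32), jnp.bfloat16)]
    rhs_list = [jnp.zeros((32, 48), jnp.bfloat16), jnp.zeros((32, 16), jnp.bfloat16)]
    outs = grouped_gemm(lhs_list, rhs_list, [((1,), (0,)), ((1,), (0,))])
    # outs[0].shape == (64, 48); outs[1].shape == (128, 16)
    return outs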
......@@ -11,14 +11,17 @@ from packaging.version import Version as PkgVersion
import numpy as np
import jax.numpy as jnp
import jax
from jax import dtypes
import jax.numpy as jnp
from jax.interpreters.mlir import dtype_to_ir_type
from transformer_engine_jax import DType as TEDType
import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis
TEDType = transformer_engine_jax.DType
def te_dtype_to_jax_dtype(te_dtype):
......@@ -104,7 +107,7 @@ def normalize_axis_boundary(axis, ndim):
return axis if axis >= 0 else ndim + axis
def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1):
"""
te_cast_transpose_p multi-dims transpose
......@@ -158,17 +161,6 @@ def jax_version_meet_requirement(version: str):
return jax_version >= jax_version_required
def is_ffi_enabled():
"""
Helper function checking if XLA Custom Call with FFI is enabled
"""
is_supported = jax_version_meet_requirement("0.4.35")
# New APIs with FFI are enabled by default
is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
return is_supported and is_enabled
def get_xla_flag(flag: str, default=None, cast=str):
"""
Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value.
......@@ -189,3 +181,86 @@ def get_xla_flag(flag: str, default=None, cast=str):
if name == flag:
return True
return default
def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None):
"""
Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
calculate dbias separately. This function checks if the workaround should be applied.
"""
arch_l_100 = False
for local_gpu_id in range(len(jax.local_devices())):
if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100:
arch_l_100 = True
break
return (
quantizer is not None
and quantizer.q_axis == QuantizeAxis.ROWWISE
and arch_l_100
and is_dbias
)
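# Editorial usage sketch, not part of the original change; 'quantize_fn' and
# 'x' are hypothetical stand-ins for a TE quantize call and its input:
#
#     if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=True, quantizer=q):
#         out = quantize_fn(x, quantizer=q)  # 1x quantization without fused dbias
#         dbias = jnp.sum(x.reshape(-1, x.shape[-1]), axis=0)  # dbias computed separately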
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
"""
Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling.
It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result.
If 'f' returns a tuple, the first output must be the only ScaledTensor output.
@param f: function to call
@param args: positional arguments to pass to 'f'
@param quantizer: quantizer to use
@param kwargs: keyword arguments to pass to 'f'
@return: the output of 'f' with the colwise output calculated
"""
should_apply_war = (
quantizer is not None
and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
and quantizer.is_2x2x()
)
if not should_apply_war:
return None
# 2x is not supported by TE kernels for delayed scaling
# so revert to 1x and transpose in JAX
quantizer.q_axis = QuantizeAxis.ROWWISE
rowwise = f(*args, **kwargs, quantizer=quantizer)
other_outputs = None
if isinstance(rowwise, tuple):
other_outputs = rowwise[1:]
rowwise = rowwise[0]
quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE
colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1)))
output_2x = ScaledTensorFactory.create(
data=rowwise.data,
scale_inv=rowwise.scale_inv,
colwise_data=colwise_data,
colwise_scale_inv=rowwise.scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=rowwise.dq_dtype,
q_axis=QuantizeAxis.ROWWISE_COLWISE,
layout=quantizer.get_layout(),
)
if other_outputs is not None:
return (output_2x,) + other_outputs
return output_2x
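# Editorial usage sketch, not part of the original change; 'rowwise_kernel'
# is a hypothetical TE call that only supports 1x delayed scaling:
#
#     out = try_apply_delayed_scaling_2x_war(rowwise_kernel, x, quantizer=q)
#     if out is None:  # WAR not applicable; the kernel handles this quantizer directly
#         out = rowwise_kernel(x, quantizer=q)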
class NamedSharding(jax.sharding.NamedSharding):
"""
Wrapper around jax.sharding.NamedSharding that adds a string description field as metadata for easier debugging.
"""
def __init__(self, *args, desc: str = None, **kwargs):
super().__init__(*args, **kwargs)
self.desc = desc
def __repr__(self):
return f"NamedSharding({self.mesh}, {self.spec}, desc={self.desc})"
def duplicate_with_new_description(self, desc: str):
"""
Create a new NamedSharding with the same mesh and spec but with a new description.
"""
return NamedSharding(self.mesh, self.spec, desc=desc)
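# Editorial usage sketch, not part of the original change; the mesh and
# PartitionSpec are assumed to exist at the call site:
#
#     sharding = NamedSharding(mesh, PartitionSpec("dp", None), desc="mlp.input")
#     repr(sharding)  # "NamedSharding(<mesh>, PartitionSpec('dp', None), desc=mlp.input)"
#     out_sharding = sharding.duplicate_with_new_description("mlp.output")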
......@@ -2,33 +2,38 @@
#
# See LICENSE for license information.
"""JAX/TE custom ops for normalization"""
import operator
import os
import warnings
from functools import partial, reduce, cache
import operator
from functools import partial, cache, reduce
from typing import Optional, Union
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.sharding import PartitionSpec
import transformer_engine_jax
from transformer_engine_jax import NVTE_Norm_Type
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
get_padded_spec,
check_valid_batch_dims,
jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype,
is_ffi_enabled,
NamedSharding,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
Quantizer,
QuantizeAxis,
DelayedScaleQuantizer,
ScalingMode,
)
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
......@@ -41,8 +46,8 @@ __all__ = [
"layernorm_bwd",
"rmsnorm_fwd",
"rmsnorm_bwd",
"layernorm_fwd_fp8",
"rmsnorm_fwd_fp8",
"normalization_fwd",
"normalization_bwd",
]
......@@ -58,325 +63,520 @@ def get_backward_sm_margin():
return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
class LayerNormFwdPrimitive(BasePrimitive):
class NormFwdPrimitive(BasePrimitive):
"""
Layer Normalization Forward Primitive
    Normalization Forward Primitive (LayerNorm/RMSNorm, with optional FP8 output)
"""
name = "te_layernorm_forward"
name = "te_norm_forward_ffi"
multiple_results = True
impl_static_args = (3, 4) # zero_centered_gamma, epsilon
impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(x_aval, gamma_aval, beta_aval, **kwargs):
def abstract(
x_aval,
scale_aval,
gamma_aval,
beta_aval,
*,
norm_type,
zero_centered_gamma,
epsilon,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
LayerNorm fwd inner primitive abstract
"""
del scale_shapes
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
        mu_rsigma_dtype = jnp.float32
        out_aval = x_aval
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigma_dtype)
if norm_type == NVTE_Norm_Type.LayerNorm:
assert gamma_aval.size == beta_aval.size
hidden_size = gamma_aval.size
assert x_aval.size % hidden_size == 0
(wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
True,
kwargs["zero_centered_gamma"],
kwargs["epsilon"],
(wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # itype
jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype
jax_dtype_to_te_dtype(out_dtype),
norm_type,
scaling_mode.value,
zero_centered_gamma,
epsilon,
get_forward_sm_margin(),
is_2x,
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigma_dtype)
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_aval = mu_aval.update(shape=(1,))
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x_aval.shape, is_padded=not is_outer
)
wkspace_aval = out_aval.update(
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
colwise_out_aval = jax.core.ShapedArray(
shape=x_aval.shape if is_2x else (1,), dtype=out_dtype
)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return out_aval, mu_aval, rsigma_aval, wkspace_aval
return (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
wkspace_aval,
)
@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm fwd outer primitive abstract
"""
out_aval, mu_aval, rsigma_aval, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs)
return out_aval, mu_aval, rsigma_aval
(
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
_,
) = NormFwdPrimitive.abstract(*args, **kwargs)
return (
out_aval,
colwise_out_aval,
scale_inv_aval,
colwise_scale_inv_aval,
updated_amax_aval,
mu_aval,
rsigma_aval,
)
@staticmethod
def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
def lowering(
ctx,
x,
scale,
gamma,
beta,
*,
norm_type,
zero_centered_gamma,
epsilon,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
LayerNorm fwd lowering rules
"""
x_aval, gamma_aval, beta_aval = ctx.avals_in
assert gamma_aval.dtype == beta_aval.dtype
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
del out_dtype, scale_dtype, scale_shapes, is_outer
x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape
if norm_type == NVTE_Norm_Type.LayerNorm:
assert gamma_aval.dtype == beta_aval.dtype
b_type = ir.RankedTensorType(beta.type)
b_shape = b_type.shape
assert g_type == b_type
assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_ffi"
sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name)(
return ffi.ffi_lowering(NormFwdPrimitive.name)(
ctx,
x,
scale,
gamma,
beta,
norm_type=norm_type.value,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
epsilon=epsilon,
sm_margin=sm_margin,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
)
else:
# Output shape is same as the input shape, but the output type is same as the weight type.
# See ln_api.cpp
output_type = g_type.element_type
ir_mu_dtype = ir.F32Type.get()
ir_rsigma_dtype = ir.F32Type.get()
out_shape = x_shape
hidden_size = reduce(operator.mul, g_shape)
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(out_shape, output_type),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [x, gamma, beta]
operand_shapes = [x_shape, g_shape, b_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = get_forward_sm_margin()
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
@staticmethod
def impl(
x,
scale,
gamma,
beta,
norm_type,
zero_centered_gamma,
epsilon,
sm_margin,
)
out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False)
return out
@staticmethod
def impl(x, gamma, beta, zero_centered_gamma, epsilon):
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
to describe implementation
"""
assert LayerNormFwdPrimitive.inner_primitive is not None
out, mu, rsigma, _ = LayerNormFwdPrimitive.inner_primitive.bind(
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
return out, mu, rsigma
del is_outer
assert NormFwdPrimitive.inner_primitive is not None
(
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
mu,
rsigma,
_,
) = NormFwdPrimitive.inner_primitive.bind(
x,
scale,
gamma,
beta,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=False,
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x.shape, is_padded=False
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv = scale_inv.flatten()[
: reduce(operator.mul, rowwise_scale_inv_shape)
].reshape(rowwise_scale_inv_shape)
if is_2x:
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape)
].reshape(colwise_scale_inv_shape)
return (
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
mu,
rsigma,
) # Exclude wkspace
@staticmethod
def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
def batcher(
batched_args,
batch_dims,
*,
norm_type,
zero_centered_gamma,
epsilon,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
to describe batch rules for vmap
"""
del is_outer
check_valid_batch_dims(batch_dims)
assert LayerNormFwdPrimitive.outer_primitive is not None
x, gamma, beta = batched_args
x_bdim, _, _ = batch_dims
out_bdims = x_bdim, x_bdim, x_bdim
assert NormFwdPrimitive.outer_primitive is not None
x, scale, gamma, beta = batched_args
x_bdim, scale_bdim, _, _ = batch_dims
out_bdims = (
            x_bdim,  # rowwise output
            x_bdim,  # colwise output
            scale_bdim,  # rowwise scale_inv
            scale_bdim,  # colwise scale_inv
scale_bdim, # amax
x_bdim, # mu
x_bdim, # rsigma
)
return (
LayerNormFwdPrimitive.outer_primitive.bind(
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
NormFwdPrimitive.outer_primitive.bind(
                x,
                scale,
gamma,
beta,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_outer=True,
),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
del zero_centered_gamma, epsilon, result_infos
def infer_sharding_from_operands(
norm_type,
zero_centered_gamma,
epsilon,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
result_infos,
):
del zero_centered_gamma, epsilon, out_dtype, result_infos
del scale_dtype, scale_shapes, is_outer
x_spec = get_padded_spec(arg_infos[0])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! "
f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance."
)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
return (out_sharding, mu_sharding, rsigma_sharding)
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out"
)
if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
)
else:
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
)
rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma"
)
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu")
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
)
amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
output = (
out_sharding,
colwise_out_sharding,
scale_inv_sharding, # rowwise
scale_inv_sharding, # colwise
amax_sharding,
mu_sharding,
rsigma_sharding,
)
return output
@staticmethod
def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
del result_infos
x_spec, g_spec, b_spec = map(get_padded_spec, arg_infos)
def partition(
norm_type,
zero_centered_gamma,
epsilon,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
result_infos,
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0])
g_spec = get_padded_spec(arg_infos[2])
b_spec = get_padded_spec(arg_infos[3])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! "
f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdPrimitive.name} does not support sharding of parameter gamma "
f"{NormFwdPrimitive.name} does not support sharding of parameter gamma "
"Enforcing no sharding of parameters hidden dim! "
)
if b_spec[-1] is not None:
warnings.warn(
f"{LayerNormFwdPrimitive.name} does not support sharding of parameter beta "
f"{NormFwdPrimitive.name} does not support sharding of parameter beta "
"Enforcing no sharding of parameters hidden dim! "
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
b_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
arg_shardings = (x_sharding, g_sharding, b_sharding)
out_shardings = (out_sharding, mu_sharding, rsigma_sharding)
impl = partial(
LayerNormFwdPrimitive.impl, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
x_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x"
)
return mesh, impl, out_shardings, arg_shardings
register_primitive(LayerNormFwdPrimitive)
def _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps):
"""
JAX native layernorm implementation
"""
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
gamma += 1.0
return jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
def _jax_rmsnorm(x, gamma, zero_centered_gamma, eps):
"""
JAX native rmsnorm implementation
"""
x_ = jnp.asarray(x, jnp.float32)
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
normed_input = x_ * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
gamma += 1.0
return jnp.asarray(normed_input * gamma).astype(x.dtype)
def _jax_layernorm_fp8(x, gamma, beta, scale, amax, out_dtype, zero_centered_gamma, eps):
"""
JAX native layernorm fp8 implementation
"""
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + eps)
normed_input = (x_ - mean) * rsigma
if zero_centered_gamma:
gamma += 1.0
output = normed_input * gamma + beta
casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype)
return casted_output, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1), updated_amax
g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma")
b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta")
out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out")
if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
)
else:
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
)
rsigma_sharding = NamedSharding(
mesh,
PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]),
desc="NormFwdPrimitive.rsigma",
)
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu")
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
scale_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale"
)
scale_inv_sharding = scale_sharding.duplicate_with_new_description(
"NormFwdPrimitive.scale_inv"
)
amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
)
arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding)
out_shardings = (
out_sharding,
colwise_out_sharding,
scale_inv_sharding, # rowwise
scale_inv_sharding, # colwise
amax_sharding,
mu_sharding,
rsigma_sharding,
)
def sharded_impl(x, scale, gamma, beta):
            # Expect TP and DP to yield the same local shape, or TP to match the global shape
(
local_x,
local_colwise_x,
local_scale_inv,
local_colwise_scale_inv,
local_amax,
local_mu,
local_rsigma,
) = NormFwdPrimitive.impl(
x,
scale,
gamma,
beta,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=True,
)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
return (
local_x,
local_colwise_x,
local_scale_inv,
local_colwise_scale_inv,
global_updated_amax,
local_mu,
local_rsigma,
)
def _jax_rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps):
"""
JAX native rmsnorm fp8 implementation
"""
x_ = jnp.asarray(x, jnp.float32)
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + eps)
normed_input = x_ * rsigma
if zero_centered_gamma:
gamma += 1.0
output = normed_input * gamma
casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype)
return casted_output, jnp.squeeze(rsigma, axis=-1), updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
def layernorm_fwd(
x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float
):
"""
Wrapper for TE layernorm fwd
"""
if not LayerNormFwdPrimitive.enabled():
x_ = jnp.asarray(x, jnp.float32)
mu = jnp.mean(x_, axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_ - mu), axis=-1, keepdims=True) + epsilon)
return (
_jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon),
jnp.squeeze(mu, axis=-1),
jnp.squeeze(rsigma, axis=-1),
)
return LayerNormFwdPrimitive.outer_primitive.bind(
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
register_primitive(NormFwdPrimitive)
class LayerNormBwdPrimitive(BasePrimitive):
class NormBwdPrimitive(BasePrimitive):
"""
Layer Normalization Backward Primitive
"""
name = "te_layernorm_backward"
name = "te_norm_backward_ffi"
multiple_results = True
impl_static_args = (5, 6) # zero_centered_gamma, epsilon
impl_static_args = (5, 6) # norm_type, zero_centered_gamma
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs):
def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_centered_gamma):
"""
Layernorm bwd inner primitive abstract
bwd inner primitive abstract
"""
w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype)
rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)
assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
assert dz_aval.shape == x_aval.shape
if norm_type == NVTE_Norm_Type.LayerNorm:
mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype)
assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
assert mu_dtype == rsigma_dtype == jnp.float32
dx_aval = dz_aval
dgamma_aval = dbeta_aval = gamma_aval
if norm_type != NVTE_Norm_Type.LayerNorm:
dbeta_aval = dbeta_aval.update(shape=(1,))
(wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
(wkspace_info,) = transformer_engine_jax.get_norm_bwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
True,
kwargs["zero_centered_gamma"],
kwargs["epsilon"],
norm_type,
zero_centered_gamma,
get_backward_sm_margin(),
)
wkspace_aval = dx_aval.update(
......@@ -395,17 +595,14 @@ class LayerNormBwdPrimitive(BasePrimitive):
"""
LayerNorm bwd outer primitive abstract
"""
dx_aval, dgamma_aval, dbeta_aval, _ = LayerNormBwdPrimitive.abstract(*args, **kwargs)
dx_aval, dgamma_aval, dbeta_aval, _ = NormBwdPrimitive.abstract(*args, **kwargs)
return dx_aval, dgamma_aval, dbeta_aval
@staticmethod
def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma):
"""
Layernorm bwd lowering rules
bwd lowering rules
"""
_, x_aval, _, _, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape
b_type = ir.RankedTensorType(gamma.type)
......@@ -413,1124 +610,644 @@ class LayerNormBwdPrimitive(BasePrimitive):
assert g_type == b_type
assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_backward_ffi"
sm_margin = get_backward_sm_margin()
out = ffi.ffi_lowering(name)(
return ffi.ffi_lowering(NormBwdPrimitive.name)(
ctx,
dz,
x,
mu,
rsigma,
gamma,
norm_type=norm_type.value,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
dz_shape = ir.RankedTensorType(dz.type).shape
mu_shape = ir.RankedTensorType(mu.type).shape
rsigma_shape = ir.RankedTensorType(rsigma.type).shape
hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
operands = [dz, mu, rsigma, x, gamma]
operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = get_backward_sm_margin()
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
zero_centered_gamma,
epsilon,
sm_margin,
)
out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False)
return out
@staticmethod
def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon):
assert LayerNormBwdPrimitive.inner_primitive is not None
dx, dgamma, dbeta, _ = LayerNormBwdPrimitive.inner_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma):
assert NormBwdPrimitive.inner_primitive is not None
dx, dgamma, dbeta, _ = NormBwdPrimitive.inner_primitive.bind(
dz, x, mu, rsigma, gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma
)
return dx, dgamma, dbeta
@staticmethod
def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma):
check_valid_batch_dims(batch_dims)
assert LayerNormBwdPrimitive.outer_primitive is not None
assert NormBwdPrimitive.outer_primitive is not None
dz, x, mu, rsigma, gamma = batched_args
_, x_bdim, _, _, gamma_bdim = batch_dims
out_bdims = x_bdim, gamma_bdim, gamma_bdim
return (
LayerNormBwdPrimitive.outer_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
NormBwdPrimitive.outer_primitive.bind(
dz,
x,
mu,
rsigma,
gamma,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
del zero_centered_gamma, epsilon, result_infos
def infer_sharding_from_operands(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos):
del norm_type, zero_centered_gamma, result_infos
x_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! "
f"Does not support to shard hidden dim in {NormBwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance."
)
g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None:
warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients "
"of gamma and beta of Layernorm "
f"{NormBwdPrimitive.name} does not support sharding of gradients "
"of gamma and beta of "
"Enforcing no sharding of parameters hidden dim! "
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
dx_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormBwdPrimitive.dx"
)
dgamma_sharding = dbeta_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormBwdPrimitive.dgamma"
)
return dx_sharding, dgamma_sharding, dbeta_sharding
@staticmethod
def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
def partition(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! "
f"Does not support to shard hidden dim in {NormBwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance."
)
g_b_spec = get_padded_spec(arg_infos[4])
if g_b_spec[-1] is not None:
warnings.warn(
f"{LayerNormBwdPrimitive.name} does not support sharding of gradients "
"of gamma and beta of Layernorm "
f"{NormBwdPrimitive.name} does not support sharding of gradients "
"of gamma and beta of "
"Enforcing no sharding of parameters hidden dim! "
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
dx_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormBwdPrimitive.dx"
)
dgamma_sharding = dbeta_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormBwdPrimitive.dgamma"
)
out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding.
mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2
arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(None)))
rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1]), desc="NormBwdPrimitive.rsigma"
)
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormBwdPrimitive.mu")
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormBwdPrimitive.mu")
arg_shardings = (
*x_shardings,
mu_sharding,
rsigma_sharding,
NamedSharding(mesh, PartitionSpec(None), desc="NormBwdPrimitive.gamma"),
)
def sharded_impl(dz, x, mu, rsigma, gamma):
local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
local_dx, local_dgamma, local_dbeta = NormBwdPrimitive.impl(
dz,
x,
mu,
rsigma,
gamma,
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
)
global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
if norm_type == NVTE_Norm_Type.LayerNorm:
global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh)
else:
global_dbeta = local_dbeta
return local_dx, global_dgamma, global_dbeta
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(LayerNormBwdPrimitive)
register_primitive(NormBwdPrimitive)
def layernorm_bwd(
dz: jnp.ndarray,
x: jnp.ndarray,
mu: jnp.ndarray,
rsigma: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
zero_centered_gamma: bool,
epsilon: float,
):
def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None):
"""
Wrapper for TE layernorm bwd
JAX native layernorm implementation
"""
if not LayerNormBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_layernorm, zero_centered_gamma=zero_centered_gamma, eps=epsilon),
x,
gamma,
beta,
)
return vjp_func(dz)
return LayerNormBwdPrimitive.outer_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + epsilon)
normed_input = (x_ - mean) * rsigma
if zero_centered_gamma:
gamma += 1.0
output = normed_input * gamma + beta
if quantizer:
ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
else:
ln_out = jnp.asarray(output).astype(x.dtype)
class RmsNormFwdPrimitive(BasePrimitive):
"""
RMS Normalization Forward Primitive
"""
name = "te_rmsnorm_forward"
multiple_results = True
impl_static_args = (2,) # epsilon
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(x_aval, gamma_aval, **kwargs):
"""
RMSNorm fwd inner primitive abstract
"""
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        rsigma_dtype = jnp.float32
        out_aval = x_aval
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigma_dtype)
hidden_size = gamma_aval.size
assert x_aval.size % hidden_size == 0
(wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // hidden_size, # batch size
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
jax_dtype_to_te_dtype(x_aval.dtype), # out te_dtype (same as input for Fp16/Bf16)
False,
False,
kwargs["epsilon"],
get_forward_sm_margin(),
)
wkspace_aval = out_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return out_aval, rsigma_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm fwd outer primitive abstract
"""
out_aval, rsigma_aval, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs)
return out_aval, rsigma_aval
@staticmethod
def lowering(ctx, x, gamma, *, epsilon):
"""
RMSNorm fwd lowering rules
"""
if is_ffi_enabled():
name = "te_rmsnorm_forward_ffi"
sm_margin = get_forward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
x,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
x_aval, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape
rsigma_element_type = ir.F32Type.get()
out_shape = x_shape
hidden_size = reduce(operator.mul, g_shape)
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(out_shape, x_type.element_type),
ir.RankedTensorType.get(batch_shape, rsigma_element_type),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [x, gamma]
operand_shapes = [x_shape, g_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = get_forward_sm_margin()
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
)
out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)
return out
@staticmethod
def impl(x, gamma, epsilon):
"""
to describe implementation
"""
assert RmsNormFwdPrimitive.inner_primitive is not None
out, rsigma, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
return out, rsigma
@staticmethod
def batcher(batched_args, batch_dims, *, epsilon):
"""
to describe batch rules for vmap
"""
check_valid_batch_dims(batch_dims)
assert RmsNormFwdPrimitive.outer_primitive is not None
x, gamma = batched_args
x_bdim, _ = batch_dims
out_bdims = x_bdim, x_bdim
return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon), out_bdims
@staticmethod
def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
del epsilon, result_infos
x_spec = get_padded_spec(arg_infos[0])
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance."
)
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
return (out_sharding, rsigma_sharding)
@staticmethod
def partition(epsilon, mesh, arg_infos, result_infos):
del result_infos
x_spec, g_spec = map(get_padded_spec, arg_infos)
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! "
"Force to not shard the hidden dim, which might introduce extra collective ops, "
"and hurt performance."
)
if g_spec[-1] is not None:
warnings.warn(
f"{RmsNormFwdPrimitive.name} does not support sharding of parameter gamma "
"Enforcing no sharding of parameters hidden dim! "
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
arg_shardings = (x_sharding, g_sharding)
out_shardings = (out_sharding, rsigma_sharding)
impl = partial(RmsNormFwdPrimitive.impl, epsilon=epsilon)
return mesh, impl, out_shardings, arg_shardings
register_primitive(RmsNormFwdPrimitive)
def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
"""
Wrapper for TE rmsnorm fwd
"""
if not RmsNormFwdPrimitive.enabled():
x_ = jnp.asarray(x, jnp.float32)
rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + epsilon)
return _jax_rmsnorm(x, gamma, zero_centered_gamma=False, eps=epsilon), jnp.squeeze(
rsigma, axis=-1
)
return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon)
class RmsNormBwdPrimitive(BasePrimitive):
"""
RMS Normalization Backward Primitive
"""
name = "te_rmsnorm_backward"
multiple_results = True
impl_static_args = (4,) # epsilon
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs):
"""
RMSNorm bwd inner primitive abstract
"""
w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)
assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
assert dz_aval.shape == x_aval.shape
assert rsigma_aval.shape == x_aval.shape[:-1]
assert rsigma_dtype == jnp.float32
dx_aval = dz_aval
dgamma_aval = gamma_aval
(wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype
False,
False,
kwargs["epsilon"],
get_backward_sm_margin(),
)
wkspace_aval = dx_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return dx_aval, dgamma_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
RMSNorm bwd outer primitive abstract
"""
dx_aval, dgamma_aval, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs)
return dx_aval, dgamma_aval
@staticmethod
def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
"""
RMSNorm bwd lowering rules
"""
if is_ffi_enabled():
name = "te_rmsnorm_backward_ffi"
sm_margin = get_backward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
dz,
x,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
_, x_aval, _, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape
dz_shape = ir.RankedTensorType(dz.type).shape
rsigma_shape = ir.RankedTensorType(rsigma.type).shape
hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(g_shape, g_type.element_type),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [dz, rsigma, x, gamma]
operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = get_backward_sm_margin()
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
wkspace_aval.size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
sm_margin,
)
out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)
return out
@staticmethod
def impl(dz, x, rsigma, gamma, epsilon):
assert RmsNormBwdPrimitive.inner_primitive is not None
dx, dgamma, _ = RmsNormBwdPrimitive.inner_primitive.bind(
dz, x, rsigma, gamma, epsilon=epsilon
)
return dx, dgamma
@staticmethod
def batcher(batched_args, batch_dims, *, epsilon):
check_valid_batch_dims(batch_dims)
assert RmsNormBwdPrimitive.outer_primitive is not None
dz, x, rsigma, gamma = batched_args
_, x_bdim, _, gamma_bdim = batch_dims
out_bdims = x_bdim, gamma_bdim
return (
RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
del epsilon, result_infos
x_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {RmsNormBwdPrimitive.name}! "
                "Forcing it to be unsharded, which may introduce extra collective ops "
                "and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        if g_spec[-1] is not None:
            warnings.warn(
                f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma. "
                "Enforcing no sharding of the parameter's hidden dimension!"
            )
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
return dx_sharding, dgamma_sharding
@staticmethod
def partition(epsilon, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {RmsNormBwdPrimitive.name}! "
                "Forcing it to be unsharded, which may introduce extra collective ops "
                "and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        if g_spec[-1] is not None:
            warnings.warn(
                f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma. "
                "Enforcing no sharding of the parameter's hidden dimension!"
            )
dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
out_shardings = dx_sharding, dgamma_sharding
x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding.
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(None)))
def sharded_impl(dz, x, rsigma, gamma):
local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon)
global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
return local_dx, global_dgamma
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(RmsNormBwdPrimitive)
def rmsnorm_bwd(
dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray, epsilon: float
):
"""
    Wrapper for TE rmsnorm bwd
"""
if not RmsNormBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_rmsnorm, zero_centered_gamma=False, eps=epsilon), x, gamma
)
return vjp_func(dz)
return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
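# Illustrative sketch (not part of the diff): the disabled-primitive fallback in
# rmsnorm_bwd computes (dx, dgamma) through jax.vjp, passing a zero cotangent for
# rsigma because only the normalized output receives a gradient.
import jax
import jax.numpy as jnp
from functools import partial


def _rmsnorm_ref(x, gamma, epsilon=1e-6):
    rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + epsilon)
    return x * rsigma * gamma, jnp.squeeze(rsigma, axis=-1)


example_x = jnp.arange(8.0).reshape(2, 4)
example_gamma = jnp.ones((4,))
example_dz = jnp.ones((2, 4))
(_, example_rsigma), example_vjp = jax.vjp(
    partial(_rmsnorm_ref, epsilon=1e-6), example_x, example_gamma
)
example_dx, example_dgamma = example_vjp((example_dz, jnp.zeros_like(example_rsigma)))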
class LayerNormFwdFp8Primitive(BasePrimitive):
"""
Layer Normalization Forward FP8 Primitive
"""
name = "te_layernorm_forward_fp8"
multiple_results = True
impl_static_args = (6, 7, 8) # out_type, zero_centered_gamma, epsilon
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
x_aval,
gamma_aval,
beta_aval,
amax_aval,
scale_aval,
scale_inv_aval,
*,
out_dtype,
zero_centered_gamma,
epsilon,
):
"""
LayerNorm fwd (fp8 out) inner primitive abstract
"""
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
        mu_rsigma_dtype = jnp.float32
assert gamma_aval.size == beta_aval.size
(wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
jax_dtype_to_te_dtype(x_aval.dtype), # in type
jax_dtype_to_te_dtype(gamma_aval.dtype), # weight type
jax_dtype_to_te_dtype(out_dtype),
True,
zero_centered_gamma,
epsilon,
get_forward_sm_margin(),
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigma_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
LayerNorm fwd (fp8 out) outer primitive abstract
"""
out_aval, mu_aval, rsigma_aval, updated_amax_aval, _ = LayerNormFwdFp8Primitive.abstract(
*args, **kwargs
)
return out_aval, mu_aval, rsigma_aval, updated_amax_aval
@staticmethod
def lowering(
ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_centered_gamma, epsilon
):
"""
LayerNorm fwd (fp8 out) lowering rules
"""
x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        # Currently only support casting to E4M3 in the C side.
assert out_dtype == jnp.float8_e4m3fn
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gamma_aval.dtype == beta_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
g_type = ir.RankedTensorType(gamma.type)
g_shape = g_type.shape
b_type = ir.RankedTensorType(beta.type)
b_shape = b_type.shape
assert g_type == b_type
assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_fp8_ffi"
sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name, operand_output_aliases={3: 3})(
ctx,
x,
gamma,
beta,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_mu_dtype = ir.F32Type.get()
ir_rsigma_dtype = ir.F32Type.get()
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
out_shape = x_shape
hidden_size = reduce(operator.mul, g_shape)
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [
x_shape,
g_shape,
b_shape,
ir_amax_shape,
ir_scale_shape,
ir_scale_inv_shape,
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            sm_margin = get_forward_sm_margin()

            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                zero_centered_gamma,
                epsilon,
                sm_margin,
            )

            out = custom_caller(
                LayerNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={3: 3}
            )

        return out


def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
    """
    JAX native rmsnorm implementation
    """
    x_ = jnp.asarray(x, jnp.float32)
    var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
    rsigma = jax.lax.rsqrt(var + epsilon)
    normed_input = x_ * rsigma
    if zero_centered_gamma:
        gamma += 1.0
    output = normed_input * gamma
    if quantizer:
        ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
    else:
        ln_out = jnp.asarray(output).astype(x.dtype)
    return ln_out, jnp.squeeze(rsigma, axis=-1)
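# Illustrative check (not part of the diff): with zero_centered_gamma=True the
# effective scale is (gamma + 1), so a zero-initialized gamma acts as identity.
# For unit inputs rsigma is ~1, hence the output is ~x. Hypothetical values:
_zc_out, _zc_rsigma = _jax_rmsnorm(
    jnp.ones((1, 3), jnp.float32),
    jnp.zeros((3,), jnp.float32),
    zero_centered_gamma=True,
    epsilon=1e-6,
)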
    @staticmethod
    def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, epsilon):
        """
        to describe implementation
        """
        assert LayerNormFwdFp8Primitive.inner_primitive is not None
        out, mu, rsigma, updated_amax, _ = LayerNormFwdFp8Primitive.inner_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
        )
        return out, mu, rsigma, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert LayerNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, beta, amax, scale, scale_inv = batched_args
        x_bdim, _, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return (
            LayerNormFwdFp8Primitive.outer_primitive.bind(
                x,
                gamma,
                beta,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                zero_centered_gamma=zero_centered_gamma,
                epsilon=epsilon,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos
    ):
        del out_dtype, zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {LayerNormFwdFp8Primitive.name}! "
                "Forcing it to be unsharded, which may introduce extra collective ops "
                "and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        return (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
g_spec = get_padded_spec(arg_infos[1])
b_spec = get_padded_spec(arg_infos[2])
if x_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {LayerNormFwdFp8Primitive.name}! "
                "Forcing it to be unsharded, which may introduce extra collective ops "
                "and hurt performance."
            )
        if g_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter gamma. "
                "Enforcing no sharding of the parameter's hidden dimension!"
            )
        if b_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter beta. "
                "Enforcing no sharding of the parameter's hidden dimension!"
            )
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
b_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
mu_sharding = rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1])
)
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
fp8_meta_sharding = amax_sharding
arg_shardings = (x_sharding, g_sharding, b_sharding) + (fp8_meta_sharding,) * 3
out_shardings = (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)
        def sharded_impl(x, gamma, beta, amax, scale, scale_inv):
            local_x, local_mu, local_rsigma, local_amax = LayerNormFwdFp8Primitive.impl(
                x,
                gamma,
                beta,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                zero_centered_gamma=zero_centered_gamma,
                epsilon=epsilon,
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_x, local_mu, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(LayerNormFwdFp8Primitive)
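# Illustrative sketch (not part of the diff): sharded_impl reduces each device's
# local amax with an all-reduce MAX so every data-parallel replica ends up with
# identical scale statistics. A minimal stand-in using jax.pmap / lax.pmax:
import jax
import jax.numpy as jnp

per_device_amax = jnp.arange(jax.local_device_count(), dtype=jnp.float32) + 1.0
global_amax = jax.pmap(lambda a: jax.lax.pmax(a, axis_name="dp"), axis_name="dp")(
    per_device_amax
)  # every entry equals the maximum across devices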
def layernorm_fwd_fp8(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    zero_centered_gamma: bool,
    epsilon: float,
):
    """
    Wrapper for TE layernorm fwd (fp8 out)
    """
    if not LayerNormFwdFp8Primitive.enabled():
        return _jax_layernorm_fp8(
            x,
            gamma,
            beta,
            scale,
            amax,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            eps=epsilon,
        )
    return LayerNormFwdFp8Primitive.outer_primitive.bind(
        x,
        gamma,
        beta,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
    )


def layernorm_fwd(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
    quantizer: Optional[Quantizer],
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]:
    """Layer normalization forward pass with optional quantization.

    Args:
        x: Input tensor to be normalized.
            Shape: (..., K) where K is the hidden size.
        gamma: Scale parameter for normalization.
            Shape: (K,)
        beta: Bias parameter for normalization.
            Shape: (K,)
        zero_centered_gamma: If True, gamma is zero-centered.
        epsilon: Small constant for numerical stability.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        A tuple containing:
        - If quantizer is None:
            The normalized input tensor. Shape: (..., K)
          If quantizer is provided:
            A ScaledTensor containing the quantized normalized input.
        - Mean of the input tensor. Shape: (..., 1)
        - Reciprocal of the standard deviation of the input tensor. Shape: (..., 1)
    """
    if not NormFwdPrimitive.enabled():
        return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)

    # TE/common does not support normalization with colwise only quantization yet
    if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
        return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)

    scale = (
        quantizer.scale
        if isinstance(quantizer, DelayedScaleQuantizer)
        else jnp.ones((1,), dtype=jnp.float32)
    )

    if quantizer is None:
        output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind(
            x,
            scale,
            gamma,
            beta,
            norm_type=NVTE_Norm_Type.LayerNorm,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            out_dtype=x.dtype,
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
            is_2x=False,
            scale_dtype=jnp.float32,
            scale_shapes=((1,), (1,)),
            is_outer=True,
        )
        return output, mu, rsigma

    is_2x2x = quantizer.is_2x2x()
    # TE/common normalization doesn't support 2x delayed scaling
    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
        is_2x2x = False
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        mu,
        rsigma,
    ) = NormFwdPrimitive.outer_primitive.bind(
        x,
        scale,
        gamma,
        beta,
        norm_type=NVTE_Norm_Type.LayerNorm,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode,
        is_2x=is_2x2x,
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(x.shape),
        is_outer=True,
    )
    quantizer.update(updated_amax)

    # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
        colwise_casted_output = jnp.transpose(
            rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
        )
        colwise_scale_inv = rowwise_scale_inv

    # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
    # So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
    # The ScaledTensorFactory takes care of padding when creating the ScaledTensor
    if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
        rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
            x.shape, is_padded=False
        )
        rowwise_scale_inv = rowwise_scale_inv.flatten()[
            : reduce(operator.mul, rowwise_unpadded_shape)
        ].reshape(rowwise_unpadded_shape)
        colwise_scale_inv = colwise_scale_inv.flatten()[
            : reduce(operator.mul, colwise_unpadded_shape)
        ].reshape(colwise_unpadded_shape)

    scaled_tensor = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_axis=quantizer.q_axis,
        layout=quantizer.get_layout(),
    )
    return scaled_tensor, mu, rsigma


def layernorm_bwd(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    mu: jnp.ndarray,
    rsigma: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
):
    """Layer normalization backward pass.

    Args:
        dz: Gradient of the output with respect to the normalized output.
            Shape: (..., K) where K is the hidden size.
        x: Input tensor that was normalized in the forward pass.
            Shape: (..., K)
        mu: Mean of the input tensor from the forward pass.
            Shape: (..., 1)
        rsigma: Reciprocal of the standard deviation from the forward pass.
            Shape: (..., 1)
        gamma: Scale parameter for normalization.
            Shape: (K,)
        beta: Bias parameter for normalization.
            Shape: (K,)
        zero_centered_gamma: If True, gamma is zero-centered.
        epsilon: Small constant for numerical stability.

    Returns:
        A tuple containing:
        - Gradient of the input tensor.
            Shape: (..., K)
        - Gradient of the scale parameter (gamma).
            Shape: (K,)
        - Gradient of the bias parameter (beta).
            Shape: (K,)
    """
    if not NormBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_layernorm, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon),
            x,
            gamma,
            beta,
        )
        mu_empty = jnp.zeros(mu.shape, mu.dtype)
        rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
        return vjp_func((dz, mu_empty, rsigma_empty))
    return NormBwdPrimitive.outer_primitive.bind(
        dz,
        x,
        mu,
        rsigma,
        gamma,
        norm_type=NVTE_Norm_Type.LayerNorm,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
    )


def rmsnorm_fwd(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
    quantizer: Optional[Quantizer],
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]:
    """Root mean square normalization forward pass with optional quantization.

    Args:
        x: Input tensor to be normalized.
            Shape: (..., K) where K is the hidden size.
        gamma: Scale parameter for normalization.
            Shape: (K,)
        zero_centered_gamma: If True, gamma is zero-centered.
        epsilon: Small constant for numerical stability.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        A tuple containing:
        - If quantizer is None:
            The normalized input tensor.
            Shape: (..., K)
          If quantizer is provided:
            A ScaledTensor containing the quantized normalized input.
        - Reciprocal of the root mean square of the input tensor.
            Shape: (..., 1)
    """
    if not NormFwdPrimitive.enabled():
        return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)

    # TE/common does not support normalization with colwise only quantization yet
    if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
        return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)

    scale = (
        quantizer.scale
        if isinstance(quantizer, DelayedScaleQuantizer)
        else jnp.ones((1,), dtype=jnp.float32)
    )
    beta = jnp.ones((1,), dtype=jnp.float32)

    if quantizer is None:
        output, _, _, _, _, _, rsigma = NormFwdPrimitive.outer_primitive.bind(
            x,
            scale,
            gamma,
            beta,
            norm_type=NVTE_Norm_Type.RMSNorm,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            out_dtype=x.dtype,
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
            is_2x=False,
            scale_dtype=jnp.float32,
            scale_shapes=((), ()),
            is_outer=True,
        )
        return output, rsigma

    is_2x2x = quantizer.is_2x2x()
    # TE/common normalization doesn't support 2x delayed scaling
    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
        is_2x2x = False
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        _,
        rsigma,
    ) = NormFwdPrimitive.outer_primitive.bind(
        x,
        scale,
        gamma,
        beta,
        norm_type=NVTE_Norm_Type.RMSNorm,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode,
        is_2x=is_2x2x,
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(x.shape),
        is_outer=True,
    )
    quantizer.update(updated_amax)

    # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
        colwise_casted_output = jnp.transpose(
            rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
        )
        colwise_scale_inv = rowwise_scale_inv

    # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
    # So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
    # The ScaledTensorFactory takes care of padding when creating the ScaledTensor
    if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
        rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
            x.shape, is_padded=False
        )
        rowwise_scale_inv = rowwise_scale_inv.flatten()[
            : reduce(operator.mul, rowwise_unpadded_shape)
        ].reshape(rowwise_unpadded_shape)
        colwise_scale_inv = colwise_scale_inv.flatten()[
            : reduce(operator.mul, colwise_unpadded_shape)
        ].reshape(colwise_unpadded_shape)

    scaled_tensor = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_axis=quantizer.q_axis,
        layout=quantizer.get_layout(),
    )
    return scaled_tensor, rsigma


class RmsNormFwdFp8Primitive(BasePrimitive):
    """
    RMS Normalization Forward FP8 Primitive
    """

    name = "te_rmsnorm_forward_fp8"
    multiple_results = True
    impl_static_args = (5, 6)  # out_dtype, epsilon
    inner_primitive = None
    outer_primitive = None
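# Illustrative sketch (not part of the diff): how the padded MXFP8 scale_inv
# buffers above are trimmed -- flatten, keep the leading unpadded element count,
# then reshape to the unpadded shape (hypothetical sizes; float32 stands in for
# the real scale dtype).
import operator
from functools import reduce
import jax.numpy as jnp

padded_scale_inv = jnp.zeros((4, 4), jnp.float32)
unpadded_shape = (3, 4)  # e.g. from get_scale_shapes(x.shape, is_padded=False)
unpadded_scale_inv = padded_scale_inv.flatten()[
    : reduce(operator.mul, unpadded_shape)
].reshape(unpadded_shape)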
    @staticmethod
    def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) inner primitive abstract
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)

        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        rsigma_dtype = jnp.float32

        (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
            x_aval.size // hidden_size,  # batch_size
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),  # in te_dtype
            jax_dtype_to_te_dtype(gamma_aval.dtype),  # weight te_dtype
            jax_dtype_to_te_dtype(out_dtype),  # out te_dtype
            False,
            False,
            epsilon,
            get_forward_sm_margin(),
        )

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigma_dtype)
        amax_aval = out_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        wkspace_aval = x_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out_aval, rsigma_aval, amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        RMSNorm fwd (fp8 out) outer primitive abstract
        """
        out_aval, rsigma_aval, amax_aval, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs)
        return out_aval, rsigma_aval, amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) lowering rules
        """
        # Currently only support casting to E4M3 in the C side.
        assert out_dtype == jnp.float8_e4m3fn
        if is_ffi_enabled():
            name = "te_rmsnorm_forward_fp8_ffi"
            sm_margin = get_forward_sm_margin()
            zero_centered_gamma = False  # RMSNorm doesn't support zero_centered_gamma
            out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
                ctx,
                x,
                gamma,
                amax,
                scale,
                scale_inv,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in

            assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
            assert amax_aval.dtype == jnp.float32
            assert scale_aval.dtype == jnp.float32
            assert scale_inv_aval.dtype == jnp.float32

            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            g_type = ir.RankedTensorType(gamma.type)
            g_shape = g_type.shape

            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_rsigma_dtype = ir.F32Type.get()
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape

            out_shape = x_shape
            hidden_size = reduce(operator.mul, g_shape)
            batch_shape = out_shape[:-1]
            batch_size = reduce(operator.mul, x_shape) // hidden_size

            wkspace_aval = ctx.avals_out[-1]

            out_types = [
                ir.RankedTensorType.get(out_shape, ir_out_dtype),
                ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
            ]
            operands = [x, gamma, amax, scale, scale_inv]
            operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            sm_margin = get_forward_sm_margin()

            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                False,  # RMSNorm doesn't support zero_centered_gamma
                epsilon,
                sm_margin,
            )

            out = custom_caller(
                RmsNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={2: 2}
            )

        return out
@staticmethod
def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon):
"""
to describe implementation
"""
assert RmsNormFwdFp8Primitive.inner_primitive is not None
out, rsigma, amax, _ = RmsNormFwdFp8Primitive.inner_primitive.bind(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
)
return out, rsigma, amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, epsilon):
"""
to describe batch rules for vmap
"""
check_valid_batch_dims(batch_dims)
assert RmsNormFwdFp8Primitive.outer_primitive is not None
x, gamma, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, x_bdim, amax_bdim
return (
RmsNormFwdFp8Primitive.outer_primitive.bind(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_infos):
del out_dtype, epsilon, result_infos
x_spec = get_padded_spec(arg_infos[0])
if x_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {RmsNormFwdFp8Primitive.name}! "
                "Forcing it to be unsharded, which may introduce extra collective ops "
                "and hurt performance."
            )
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
return (out_sharding, rsigma_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
g_spec = get_padded_spec(arg_infos[1])
if x_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {RmsNormFwdFp8Primitive.name}! "
                "Forcing it to be unsharded, which may introduce extra collective ops "
                "and hurt performance."
            )
        if g_spec[-1] is not None:
            warnings.warn(
                f"{RmsNormFwdFp8Primitive.name} does not support sharding of parameter gamma. "
                "Enforcing no sharding of the parameter's hidden dimension!"
            )
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
out_sharding = x_sharding
rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
fp8_meta_sharding = amax_sharding
arg_shardings = (x_sharding, g_sharding) + (fp8_meta_sharding,) * 3
out_shardings = (out_sharding, rsigma_sharding, amax_sharding)
def sharded_impl(x, gamma, amax, scale, scale_inv):
local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_x, local_rsigma, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(RmsNormFwdFp8Primitive)


def rmsnorm_bwd(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    rsigma: jnp.ndarray,
    gamma: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
):
    """Root mean square normalization backward pass.

    Args:
        dz: Gradient of the output with respect to the normalized output.
            Shape: (..., K) where K is the hidden size.
        x: Input tensor that was normalized in the forward pass.
            Shape: (..., K)
        rsigma: Reciprocal of the root mean square from the forward pass.
            Shape: (..., 1)
        gamma: Scale parameter for normalization.
            Shape: (K,)
        zero_centered_gamma: If True, gamma is zero-centered.
        epsilon: Small constant for numerical stability.

    Returns:
        A tuple containing:
        - Gradient of the input tensor.
            Shape: (..., K)
        - Gradient of the scale parameter (gamma).
            Shape: (K,)
    """
    if not NormBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_rmsnorm, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon),
            x,
            gamma,
        )
        rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
        return vjp_func((dz, rsigma_empty))
    mu = jnp.empty(())
    dx, dgamma, _ = NormBwdPrimitive.outer_primitive.bind(
        dz,
        x,
        mu,
        rsigma,
        gamma,
        norm_type=NVTE_Norm_Type.RMSNorm,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
    )
    return (dx, dgamma)


def normalization_fwd(
x: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
zero_centered_gamma: bool,
epsilon: float,
norm_type: str,
quantizer: Optional[Quantizer],
):
"""Common wrapper for normalization forward pass.
Args:
x: Input tensor to be normalized.
Shape: (..., K) where K is the hidden size.
gamma: Scale parameter for normalization.
Shape: (K,)
beta: Bias parameter for normalization.
Shape: (K,)
zero_centered_gamma: If True, gamma is zero-centered.
epsilon: Small constant for numerical stability.
norm_type: Type of normalization to apply. Must be one of:
- 'layernorm': Layer normalization
- 'rmsnorm': Root mean square normalization
quantizer: Optional quantizer for FP8 quantization of the output.
Returns:
A tuple containing:
- If quantizer is None:
The normalized input tensor.
Shape: (..., K)
If quantizer is provided:
A ScaledTensor containing the quantized normalized input.
- Mean of the input tensor (None for RMSNorm).
Shape: (..., 1)
- Reciprocal of the standard deviation (or root mean square for RMSNorm).
Shape: (..., 1)
Note:
zero_centered_gamma is not supported if norm_type is 'rmsnorm'.
"""
if norm_type == "layernorm":
output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
elif norm_type == "rmsnorm":
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
output, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer)
mu = None
else:
raise ValueError(f"{norm_type=} is not supported.")
    return output, mu, rsigma
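# Illustrative usage sketch (not part of the diff): the common wrapper hides the
# layernorm/rmsnorm split; for rmsnorm, beta is ignored and mu comes back None.
# Hypothetical call, assuming inputs of shape (..., K):
#
#   output, mu, rsigma = normalization_fwd(
#       x, gamma, beta, zero_centered_gamma=False, epsilon=1e-6,
#       norm_type="rmsnorm", quantizer=None,
#   )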
def rmsnorm_fwd_fp8(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    epsilon: float,
):
    """
    Wrapper for TE rmsnorm fwd (fp8 out)
    """
    if not RmsNormFwdFp8Primitive.enabled():
        return _jax_rmsnorm_fp8(
            x, gamma, scale, amax, out_dtype=out_dtype, zero_centered_gamma=False, eps=epsilon
        )
    return RmsNormFwdFp8Primitive.outer_primitive.bind(
        x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
    )


def normalization_bwd(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    mu: jnp.ndarray,
    rsigma: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_type: str,
):
"""Common wrapper for normalization backward pass.
Args:
dz: Gradient of the output with respect to the normalized output.
Shape: (..., K) where K is the hidden size.
x: Input tensor that was normalized in the forward pass.
Shape: (..., K)
mu: Mean of the input tensor from the forward pass (None for RMSNorm).
Shape: (..., 1)
rsigma: Reciprocal of the standard deviation (or root mean square) from the forward pass.
Shape: (..., 1)
gamma: Scale parameter for normalization.
Shape: (K,)
beta: Bias parameter for normalization.
Shape: (K,)
zero_centered_gamma: If True, gamma is zero-centered.
epsilon: Small constant for numerical stability.
norm_type: Type of normalization used in the forward pass. Must be one of:
- 'layernorm': Layer normalization
- 'rmsnorm': Root mean square normalization
Returns:
A tuple containing:
- Gradient of the input tensor.
Shape: (..., K)
- Gradient of the scale parameter (gamma).
Shape: (K,)
- Gradient of the bias parameter (beta) (None for RMSNorm).
Shape: (K,)
Note:
zero_centered_gamma is not supported if norm_type is 'rmsnorm'.
"""
if norm_type == "layernorm":
dx, dgamma, dbeta = layernorm_bwd(
dz, x, mu, rsigma, gamma, beta, zero_centered_gamma, epsilon
)
elif norm_type == "rmsnorm":
assert (
not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, zero_centered_gamma, epsilon)
dbeta = None
else:
raise ValueError(f"{norm_type=} is not supported.")
return dx, dgamma, dbeta
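# Illustrative sketch (not part of the diff): one way a caller could pair
# normalization_fwd/normalization_bwd under jax.custom_vjp. Simplified to
# layernorm without a quantizer; the residual layout here is hypothetical.
import jax


@jax.custom_vjp
def _ln_example(x, gamma, beta):
    out, _, _ = normalization_fwd(x, gamma, beta, False, 1e-6, "layernorm", None)
    return out


def _ln_example_fwd(x, gamma, beta):
    out, mu, rsigma = normalization_fwd(x, gamma, beta, False, 1e-6, "layernorm", None)
    return out, (x, mu, rsigma, gamma, beta)


def _ln_example_bwd(res, dz):
    x, mu, rsigma, gamma, beta = res
    return normalization_bwd(
        dz, x, mu, rsigma, gamma, beta,
        zero_centered_gamma=False, epsilon=1e-6, norm_type="layernorm",
    )


_ln_example.defvjp(_ln_example_fwd, _ln_example_bwd)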
......@@ -2,28 +2,29 @@
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
from typing import Tuple
from typing import Tuple, Optional
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.sharding import PartitionSpec
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
get_padded_spec,
check_valid_batch_dims,
te_dtype_to_jax_dtype,
jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype,
is_ffi_enabled,
multidim_transpose,
should_apply_1x_fused_dbias_war_for_arch_l_100,
NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeAxis, DelayedScaleQuantizer, ScalingMode
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
......@@ -31,166 +32,591 @@ else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["cast_fp8"]
__all__ = ["quantize", "quantize_dbias"]
def _jax_quantize(x, scale, q_dtype):
    """
    Quantize with scale
    """
    compute_dtype = scale.dtype
    dtype_max = (jnp.finfo(q_dtype).max).astype(compute_dtype)
    scaled_x = x.astype(compute_dtype) * scale
    clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
    return clipped_scaled_x.astype(q_dtype)


def _jax_cast_fp8(inputs, scale, amax, out_dtype):
    """
    JAX native fp8 casting implementation
    """
    casted_output = _jax_quantize(inputs, scale, q_dtype=out_dtype)
    updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))
    return casted_output, updated_amax


class CastFP8Primitive(BasePrimitive):
    """
    Cast Primitive
    """

    name = "te_quantize"


class DBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
multiple_results = True
impl_static_args = (4,)
impl_static_args = (
2,
3,
4,
5,
6,
7,
8,
) # out_dtype, scaling_mode, q_axis, scale_dtype, scale_shapes, is_dbias, is_outer
inner_primitive = None
outer_primitive = None
    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_cast abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return casted_x_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_cast lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_quantize_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
                ctx, x, amax, scale, scale_inv
            )
        else:
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape

            out_types = [
                ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [x, amax, scale, scale_inv]
            operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            opaque = transformer_engine_jax.pack_common_descriptor(
                ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype)
            )

            out = custom_caller(
                CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
            )

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        te_cast implementation
        """
        assert CastFP8Primitive.inner_primitive is not None
        casted_x, updated_amax = CastFP8Primitive.inner_primitive.bind(
            x, amax, scale, scale_inv, out_dtype=out_dtype
        )
        return casted_x, updated_amax

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        del scale_shapes
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        rowwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
        if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
            rowwise_out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)

        colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
        dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
        wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
            t_shape = multidim_transpose(x_aval.shape)
            if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
                # Don't transpose output for MXFP8
                t_shape = x_aval.shape
            colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
            colwise_scale_inv_aval = jax.core.ShapedArray(
                shape=colwise_scale_inv_shape, dtype=scale_dtype
            )

        if is_dbias:
            gi_hidden_size = x_aval.shape[-1]
            dbias_shape = (gi_hidden_size,)
            dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
            )
            wkspace_aval = x_aval.update(
                shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
            )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
@staticmethod
def lowering(
ctx,
x,
scale,
*,
out_dtype,
scaling_mode,
q_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
"""
te_dbias_quantize_p lowering rules
"""
del out_dtype, scale_dtype, scale_shapes, is_outer
x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == jnp.float32
return ffi.ffi_lowering(DBiasQuantizePrimitive.name)(
ctx,
x,
scale,
scaling_mode=scaling_mode,
q_axis=q_axis,
is_dbias=is_dbias,
)
    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        check_valid_batch_dims(batch_dims)
        assert CastFP8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return (
            CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype),
            out_bdims,
        )

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert DBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_axis=q_axis,
            scale_dtype=scale_dtype,
            scale_shapes=scale_shapes,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False)
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
                scale_inv = jax.lax.slice(
                    scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
                )
            if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
                colwise_scale_inv = jax.lax.slice(
                    colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
                )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DBiasQuantizePrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            DBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_axis=q_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
            ),
            out_bdims,
        )
    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, amax_sharding)

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, scale_shapes, is_dbias, is_outer)  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-1], x_spec[-1]),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )
        scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*get_padded_spec(arg_infos[1])),
            desc="DBiasQuantizePrimitive.scale_inv",
        )
        amax_sharding = scale_inv_sharding.duplicate_with_new_description(
            desc="DBiasQuantizePrimitive.amax_sharding"
        )
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
            )
        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
            "DBiasQuantizePrimitive.colwise_scale_inv"
        )
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(x_spec[-1]),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )
    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_updated_amax = CastFP8Primitive.impl(
                x, amax, scale, scale_inv, out_dtype=out_dtype
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)

            return local_cx, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-1], x_spec[-1]),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )
        scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*get_padded_spec(arg_infos[1])),
            desc="DBiasQuantizePrimitive.scale_inv",
        )
        amax_sharding = scale_inv_sharding.duplicate_with_new_description(
            desc="DBiasQuantizePrimitive.amax_sharding"
        )
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
            )
        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
            "DBiasQuantizePrimitive.colwise_scale_inv"
        )
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(x_spec[-1]),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = DBiasQuantizePrimitive.impl(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_axis=q_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                is_outer=True,
            )
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias
            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastFP8Primitive)
register_primitive(DBiasQuantizePrimitive)
def _jax_quantize(x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None):
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype)
def _jax_dbias(dx: jnp.ndarray):
dbias = jnp.sum(
dx,
axis=tuple(range(dx.ndim - 1)),
keepdims=False,
)
dbias = dbias.ravel() # C++ function returns an 1D array for dbias
return dbias
def _jax_quantize_dbias(
x,
quantizer: Quantizer = None,
dq_dtype: Optional[jnp.dtype] = None,
):
if quantizer is None:
return x, None
return quantizer.quantize(x, dq_dtype=dq_dtype), _jax_dbias(x)
def _jax_dbias(
dx: jnp.ndarray,
):
dbias = jnp.sum(
dx.astype(jnp.float32),
axis=tuple(range(dx.ndim - 1)),
keepdims=False,
)
dbias = dbias.ravel() # C++ function returns an 1D array for dbias
return dbias.astype(dx.dtype)
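# Illustrative sketch (not part of the diff): dbias is the incoming gradient
# summed over every leading axis, accumulated in float32 and cast back, matching
# _jax_dbias above. Hypothetical shapes:
import jax.numpy as jnp

dx_example = jnp.ones((2, 3, 4), jnp.bfloat16)
dbias_example = jnp.sum(dx_example.astype(jnp.float32), axis=(0, 1)).astype(
    dx_example.dtype
)  # shape (4,)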
def cast_fp8(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    if not CastFP8Primitive.enabled():
        return _jax_cast_fp8(x, scale, amax, out_dtype=out_dtype)
    return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)


def _quantize_impl(
    x: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
assert (dq_dtype is None) or (
quantizer is not None
), "quantizer must be provided if dq_dtype is provided"
if not DBiasQuantizePrimitive.enabled():
if is_dbias:
return _jax_quantize_dbias(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
)
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None
# TE/common doesn't support colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if is_dbias:
return _jax_quantize_dbias(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
)
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None
scale = jnp.empty((), jnp.float32)
# TE/common dbias_quantize does not support 1x on arch < 100
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = _quantize_impl(
x=x,
is_dbias=False,
quantizer=quantizer,
dq_dtype=dq_dtype,
)
dbias = _jax_dbias(x)
return out, dbias
if quantizer is None:
if is_dbias:
return x, _jax_dbias(x)
return x, None
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
(
rowwise_casted_output,
colwise_casted_output,
rowwise_scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
) = DBiasQuantizePrimitive.outer_primitive.bind(
x,
scale,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_axis=quantizer.q_axis.value,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
is_dbias=is_dbias,
is_outer=True,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
out = ScaledTensorFactory.create(
data=rowwise_casted_output,
scale_inv=rowwise_scale_inv,
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=dq_dtype if dq_dtype is not None else x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
)
return out, dbias
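# Illustrative sketch (not part of the diff): the delayed-scaling cast performed
# on-device -- scale, clip to the FP8 max, cast, and track amax. A hypothetical
# standalone equivalent (mirrors _jax_quantize/_jax_cast_fp8 above):
import jax.numpy as jnp


def _delayed_scaling_cast(x, scale, q_dtype=jnp.float8_e4m3fn):
    fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, jnp.float32)
    y = jnp.clip(x.astype(jnp.float32) * scale, -fp8_max, fp8_max).astype(q_dtype)
    amax = jnp.max(jnp.abs(x)).astype(jnp.float32)
    return y, amax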
# TODO(Phuong): do not expose dq_dtype to users
def quantize(
x: jnp.ndarray,
quantizer: Quantizer,
dq_dtype: Optional[jnp.dtype] = None,
) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer.
Args:
x: Input tensor to be quantized.
Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output.
dq_dtype: Optional dtype for dequantization.
If None, uses the same dtype as the input tensor.
Returns:
A ScaledTensor containing the quantized input tensor.
"""
out, _ = _quantize_impl(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
)
return out
# TODO(Phuong): do not expose dq_dtype to users
def quantize_dbias(
dz: jnp.ndarray,
quantizer: Quantizer,
is_dbias: bool = True,
dq_dtype: Optional[jnp.dtype] = None,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient.
Args:
dz: Input tensor to be quantized and used for bias gradient computation.
Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output.
is_dbias: If True, compute bias gradient. Defaults to True.
dq_dtype: Optional dtype for dequantization.
If None, uses the same dtype as the input tensor.
Returns:
A tuple containing:
- A ScaledTensor containing the quantized input tensor.
The ScaledTensor includes both the quantized data and scaling factors.
- The bias gradient tensor.
Shape: (K,) or empty if is_dbias is False.
"""
return _quantize_impl(
dz,
quantizer=quantizer,
is_dbias=is_dbias,
dq_dtype=dq_dtype,
)
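# Illustrative usage sketch (not part of the diff): pure-JAX reference for what
# is_dbias=True adds on top of quantize(); `my_quantizer` is hypothetical and
# would come from the quantize module's quantizer factories:
#
#   scaled, dbias = quantize_dbias(dz, quantizer=my_quantizer)
#
import jax.numpy as jnp

dz_example = jnp.ones((8, 16), jnp.bfloat16)
dbias_reference = jnp.sum(dz_example.astype(jnp.float32), axis=0).astype(dz_example.dtype)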
......@@ -11,14 +11,10 @@ from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine_jax
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from .misc import get_padded_spec, check_valid_batch_dims
from ..softmax import SoftmaxType
if version.parse(jax.__version__) >= version.parse("0.5.0"):
......@@ -38,30 +34,6 @@ __all__ = [
]
def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
return jax.nn.softmax(scale_factor * logits)
def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
if mask is not None:
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
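# Illustrative sketch (not part of the diff): the upper-triangular mask above is
# the usual causal mask; for one (q_seqlen, k_seqlen) block it is equivalent to:
import jax
import jax.numpy as jnp

logits_example = jnp.zeros((4, 4))
causal = jnp.tril(jnp.ones((4, 4))) == 0  # True above the diagonal
softmax_reference = jax.nn.softmax(jnp.where(causal, -1e10, logits_example))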
def is_softmax_kernel_available(
softmax_type: SoftmaxType,
batch: int,
......@@ -139,38 +111,7 @@ class SoftmaxPrimitive(BasePrimitive):
"""
softmax_forward lowering rules
"""
if is_ffi_enabled():
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor)
else:
(i_aval,) = ctx.avals_in
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, i_shape[:-3])
pad_batch = batch
heads = i_shape[-3]
q_seqlen = i_shape[-2]
k_seqlen = i_shape[-1]
out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
operands = [logits]
operand_shapes = [i_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(
batch,
pad_batch,
heads,
q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(i_aval.dtype),
scale_factor,
)
out = custom_caller(name, args, opaque, False)
return out
return ffi.ffi_lowering(name)(ctx, logits, scale_factor=scale_factor)
@staticmethod
def forward_impl(primitive, logits, scale_factor):
......@@ -250,43 +191,7 @@ class SoftmaxPrimitive(BasePrimitive):
"""
softmax_backward lowering rules
"""
if is_ffi_enabled():
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor)
else:
dz_aval, _ = ctx.avals_in
dz_type = ir.RankedTensorType(dz.type)
dz_shape = dz_type.shape
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, dz_shape[:-3])
pad_batch = batch # unused
heads = dz_shape[-3]
q_seqlen = dz_shape[-2]
k_seqlen = dz_shape[-1]
softmax_out_type = ir.RankedTensorType(softmax_out.type)
softmax_out_shape = softmax_out_type.shape
out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
operands = [dz, softmax_out]
operand_shapes = [dz_shape, softmax_out_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(
batch,
pad_batch,
heads,
q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(dz_aval.dtype),
scale_factor,
)
out = custom_caller(name, args, opaque, False)
return out
return ffi.ffi_lowering(name)(ctx, dz, softmax_out, scale_factor=scale_factor)
@staticmethod
def backward_impl(primitive, dz, softmax_out, scale_factor):
......@@ -356,7 +261,7 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
Scaled Softmax Fwd Primitive
"""
name = "te_scaled_softmax_forward"
name = "te_scaled_softmax_forward_ffi"
multiple_results = False
impl_static_args = (1,) # scale_factor
inner_primitive = None
......@@ -429,22 +334,12 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledSoftmaxFwdPrimitive)
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
"""
scaled_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledSoftmaxFwdPrimitive.enabled():
return _jax_scaled_softmax(logits, scale_factor)
return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
"""
Scaled Softmax Bwd Primitive
"""
name = "te_scaled_softmax_backward"
name = "te_scaled_softmax_backward_ffi"
multiple_results = False
impl_static_args = (2,) # scale_factor
inner_primitive = None
......@@ -530,7 +425,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
Scaled Masked Softmax Fwd Primitive
"""
name = "te_scaled_masked_softmax_forward"
name = "te_scaled_masked_softmax_forward_ffi"
multiple_results = False
impl_static_args = (2,) # scale_factor
inner_primitive = None
......@@ -591,42 +486,10 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
"""
te_scaled_masked_softmax_forward lowering rules
"""
if is_ffi_enabled():
ffi_name = "te_scaled_masked_softmax_forward_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, scale_factor=scale_factor)
else:
logits_aval, _ = ctx.avals_in
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, i_shape[:-3])
heads = i_shape[-3]
q_seqlen = i_shape[-2]
k_seqlen = i_shape[-1]
mask_type = ir.RankedTensorType(mask.type)
mask_shape = mask_type.shape
pad_batch = reduce(operator.mul, mask_shape[:-3])
out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
operands = [logits, mask]
operand_shapes = [i_shape, mask_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(
batch,
pad_batch,
heads,
q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(logits_aval.dtype),
scale_factor,
return ffi.ffi_lowering(ScaledMaskedSoftmaxFwdPrimitive.name)(
ctx, logits, mask, scale_factor=scale_factor
)
out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)
return out
@staticmethod
def impl(logits, mask, scale_factor):
assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
......@@ -666,26 +529,12 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
def scaled_masked_softmax_fwd(
logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
"""
scaled_masked_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_masked_softmax(logits, mask, scale_factor)
return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, mask, scale_factor=scale_factor
)
class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
"""
Scaled Masked Softmax Bwd Primitive
"""
name = "te_scaled_masked_softmax_backward"
name = "te_scaled_masked_softmax_backward_ffi"
multiple_results = False
impl_static_args = (2,) # scale_factor
inner_primitive = None
......@@ -712,12 +561,10 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
"""
te_scaled_masked_softmax_backward lowering rules
"""
out = SoftmaxPrimitive.backward_lowering(
return SoftmaxPrimitive.backward_lowering(
ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
)
return out
@staticmethod
def impl(dz, softmax_out, scale_factor):
return SoftmaxPrimitive.backward_impl(
......@@ -753,33 +600,12 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
def scaled_masked_softmax_bwd(
dz: jnp.ndarray,
softmax_out: jnp.ndarray,
logits: jnp.ndarray,
mask: jnp.ndarray,
scale_factor: float,
) -> jnp.ndarray:
"""
scaled_masked_softmax_backward wrapper
Return FP16/BF16 tensor
"""
if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
)
return vjp_func(dz)[0]
return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
dz, softmax_out, scale_factor=scale_factor
)
class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
"""
Scaled Upper Triang Masked Softmax Fwd Primitive
"""
name = "te_scaled_upper_triang_masked_softmax_forward"
name = "te_scaled_upper_triang_masked_softmax_forward_ffi"
multiple_results = False
impl_static_args = (1,) # scale_factor
inner_primitive = None
......@@ -860,24 +686,12 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
"""
scaled_upper_triang_masked_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, scale_factor=scale_factor
)
class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
"""
Scaled Upper Triang Masked Softmax Bwd Primitive
"""
name = "te_scaled_upper_triang_masked_softmax_backward"
name = "te_scaled_upper_triang_masked_softmax_backward_ffi"
multiple_results = False
impl_static_args = (2,) # scale_factor
inner_primitive = None
......@@ -904,7 +718,7 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
"""
te_scaled_upper_triang_masked_softmax_backward lowering rules
"""
out = SoftmaxPrimitive.backward_lowering(
return SoftmaxPrimitive.backward_lowering(
ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
ctx,
dz,
......@@ -912,8 +726,6 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
scale_factor=scale_factor,
)
return out
@staticmethod
def impl(dz, softmax_out, scale_factor):
return SoftmaxPrimitive.backward_impl(
......@@ -953,6 +765,87 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
return jax.nn.softmax(scale_factor * logits)
def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
if mask is not None:
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
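# Editor's sketch (not part of the original commit; shapes and the PRNG key
# are illustrative assumptions): the reference above is algebraically
# equivalent to adding a causal bias with jnp.where before scaling and softmax.
def _editor_sketch_upper_triang_reference():
    key = jax.random.PRNGKey(0)
    logits = jax.random.normal(key, (2, 4, 8, 8), dtype=jnp.float32)
    scale = 0.125
    # 0 on and below the diagonal, -1e10 above it, matching the tril mask above
    causal_bias = jnp.where(jnp.tril(jnp.ones((8, 8), dtype=bool)), 0.0, -1e10)
    expected = jax.nn.softmax((logits + causal_bias) * scale)
    actual = _jax_scaled_upper_triang_masked_softmax(logits, scale)
    assert jnp.allclose(actual, expected)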
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
"""
scaled_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledSoftmaxFwdPrimitive.enabled():
return _jax_scaled_softmax(logits, scale_factor)
return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
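# Editor's usage sketch (illustrative shapes, not from the commit): every
# wrapper in this file follows the same dispatch pattern -- the TE primitive
# when enabled, the pure-JAX reference otherwise -- over the assumed
# [...Batch, Head, Q_Seqlen, K_Seqlen] layout.
def _editor_sketch_scaled_softmax_usage():
    logits = jnp.zeros((2, 4, 128, 128), jnp.bfloat16)
    probs = scaled_softmax_fwd(logits, scale_factor=0.125)
    assert probs.shape == logits.shape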
def scaled_masked_softmax_fwd(
logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
"""
scaled_masked_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_masked_softmax(logits, mask, scale_factor)
return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, mask, scale_factor=scale_factor
)
def scaled_masked_softmax_bwd(
dz: jnp.ndarray,
softmax_out: jnp.ndarray,
logits: jnp.ndarray,
mask: jnp.ndarray,
scale_factor: float,
) -> jnp.ndarray:
"""
scaled_masked_softmax_backward wrapper
Return FP16/BF16 tensor
"""
if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
)
return vjp_func(dz)[0]
return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
dz, softmax_out, scale_factor=scale_factor
)
def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
"""
scaled_upper_triang_masked_softmax_forward wrapper
Return FP16/BF16 tensor
"""
if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, scale_factor=scale_factor
)
def scaled_upper_triang_masked_softmax_bwd(
dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for transpose"""
import operator
from functools import partial, reduce
from typing import Tuple, Sequence, Union, Callable
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
check_valid_batch_dims,
jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype,
get_padded_spec,
multidim_transpose,
normalize_axis_boundary,
is_ffi_enabled,
)
from .activation import ActivationEnum
from .activation import _jax_act_lu
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [
"transpose",
"cast_transpose",
"dbias_cast_transpose",
"dact_lu_dbias_cast_transpose",
"dgated_act_lu_cast_transpose",
]
def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary):
"""
JAX native transpose implementation
"""
axes = multidim_transpose(range(inputs.ndim), static_axis_boundary, transpose_axis_boundary)
return jnp.transpose(inputs, axes=axes)
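# Editor's sketch (illustrative; assumes multidim_transpose from .misc keeps
# the static axes fixed and moves the trailing axis group in front of the
# middle one -- an assumption about a helper not shown in this diff):
def _editor_sketch_transpose():
    x = jnp.arange(24).reshape(2, 3, 4)
    xt = _jax_transpose(x, static_axis_boundary=-1, transpose_axis_boundary=-1)
    # trailing group (4,) moves in front of the leading group (2, 3)
    assert xt.shape == (4, 2, 3)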
def _jax_cast_transpose(
inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
):
"""
JAX native cast_transpose implementation
"""
casted_output, updated_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype)
casted_transposed_output = _jax_transpose(
casted_output, static_axis_boundary, transpose_axis_boundary
)
return casted_output, casted_transposed_output, updated_amax
def _jax_dbias_cast_transpose(
dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
):
"""
JAX native dbias_cast_transpose implementation
"""
casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose(
dz,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
dbias = jnp.sum(
dz,
axis=tuple(
range(
transpose_axis_boundary
if transpose_axis_boundary > 0
else transpose_axis_boundary + dz.ndim
)
),
keepdims=False,
)
dbias = dbias.ravel() # C++ function returns a 1D array for dbias
return casted_dz, cast_transposed_dz, dbias, updated_amax
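# Editor's sketch (illustrative shapes; relies on _jax_cast_fp8 from
# .quantization, which is not shown in this diff): dbias reduces dz over every
# axis in front of the transpose boundary, so a (2, 3, 4) gradient with
# boundary -1 yields a 4-element bias summed over the 2*3 leading positions.
def _editor_sketch_dbias():
    dz = jnp.ones((2, 3, 4), jnp.float32)
    amax = jnp.zeros((1,), jnp.float32)
    scale = jnp.ones((1,), jnp.float32)
    _, _, dbias, _ = _jax_dbias_cast_transpose(
        dz, amax, scale, jnp.float8_e4m3fn, -1, -1
    )
    assert dbias.shape == (4,)
    assert jnp.all(dbias == 6.0)  # 2 * 3 ones summed per hidden element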
class TransposePrimitive(BasePrimitive):
"""
Transpose Primitive
"""
name = "te_transpose"
multiple_results = False
impl_static_args = (1, 2)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
"""
_transpose abstract
"""
transposed_x_shape = multidim_transpose(
x_aval.shape, static_axis_boundary, transpose_axis_boundary
)
xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype)
return xt_aval
@staticmethod
def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
"""
_transpose cuda lowering
"""
x_aval = ctx.avals_in[0]
assert x_aval.dtype in [
jnp.float32,
jnp.float16,
jnp.bfloat16,
jnp.float8_e4m3fn,
jnp.float8_e5m2,
]
if is_ffi_enabled():
name = "te_transpose_ffi"
out = ffi.ffi_lowering(name)(ctx, x, transpose_axis=transpose_axis_boundary)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
if static_axis_boundary >= 0:
for i in range(static_axis_boundary + 1):
assert ir_x_shape[i] == 1
transposed_x_shape = multidim_transpose(
ir_x_shape, static_axis_boundary, transpose_axis_boundary
)
out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)]
operands = [x]
operand_shapes = [ir_x_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
contracted_x_shape = (
reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]),
)
opaque = transformer_engine_jax.pack_common_descriptor(
contracted_x_shape, te_dtype, te_dtype
)
out = custom_caller(TransposePrimitive.name, args, opaque, False)
return out
@staticmethod
def impl(x, static_axis_boundary, transpose_axis_boundary):
"""
te_transpose implementation
"""
assert TransposePrimitive.inner_primitive is not None
transposed_x = TransposePrimitive.inner_primitive.bind(
x,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
return transposed_x
@staticmethod
def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
check_valid_batch_dims(batch_dims)
assert TransposePrimitive.outer_primitive is not None
assert static_axis_boundary < 0
(x,) = batched_args
(x_bdim,) = batch_dims
# Minus batch dim.
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
transpose_axis_boundary += 1 # Plus batch dim
out_bdims = x_bdim
return (
TransposePrimitive.outer_primitive.bind(
x, static_axis_boundary=x_bdim, transpose_axis_boundary=transpose_axis_boundary
),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(
static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
return transposed_x_sharding
@staticmethod
def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = transposed_x_sharding
impl = partial(
TransposePrimitive.impl,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
return mesh, impl, out_shardings, arg_shardings
register_primitive(TransposePrimitive)
def transpose(
x: jnp.ndarray, static_axis_boundary: int, transpose_axis_boundary: int
) -> jnp.ndarray:
"""
transpose wrapper
"""
if not TransposePrimitive.enabled():
return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary)
return TransposePrimitive.outer_primitive.bind(
x,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
class CastTransposePrimitive(BasePrimitive):
"""
Cast Transpose Primitive
"""
name = "te_cast_transpose"
multiple_results = True
impl_static_args = (4, 5, 6)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
x_aval,
amax_aval,
scale_aval,
scale_inv_aval,
*,
out_dtype,
static_axis_boundary,
transpose_axis_boundary
):
"""
te_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
transposed_x_shape = multidim_transpose(
x_aval.shape, static_axis_boundary, transpose_axis_boundary
)
casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
return casted_x_aval, casted_xt_aval, updated_amax_aval
@staticmethod
def lowering(
ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
):
"""
te_cast_transpose_p lowering rules
"""
x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 2})(
ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
if static_axis_boundary >= 0:
for i in range(static_axis_boundary + 1):
assert ir_x_shape[i] == 1
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_x_shape = multidim_transpose(
ir_x_shape, static_axis_boundary, transpose_axis_boundary
)
out_types = [
ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
contracted_x_shape = (
reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]),
)
opaque = transformer_engine_jax.pack_common_descriptor(
contracted_x_shape,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
)
out = custom_caller(
CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2}
)
return out
@staticmethod
def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
"""
te_cast_transpose implementation
"""
assert CastTransposePrimitive.inner_primitive is not None
casted_x, casted_transposed_x, updated_amax = CastTransposePrimitive.inner_primitive.bind(
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
return casted_x, casted_transposed_x, updated_amax
@staticmethod
def batcher(
batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
):
check_valid_batch_dims(batch_dims)
assert CastTransposePrimitive.outer_primitive is not None
assert static_axis_boundary < 0
x, amax, scale, scale_inv = batched_args
x_bdim, amax_bdim, *_ = batch_dims
# Minus batch dim.
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
transpose_axis_boundary += 1 # Plus batch dim
out_bdims = x_bdim, x_bdim, amax_bdim
return (
CastTransposePrimitive.outer_primitive.bind(
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=x_bdim,
transpose_axis_boundary=transpose_axis_boundary,
),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(
out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
):
del out_dtype, result_infos
x_spec = get_padded_spec(arg_infos[0])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)
@staticmethod
def partition(
out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)
def sharded_impl(x, amax, scale, scale_inv):
local_cx, local_cxt, local_updated_amax = CastTransposePrimitive.impl(
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)
return local_cx, local_cxt, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(CastTransposePrimitive)
def cast_transpose(
x: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: jnp.dtype,
static_axis_boundary: int,
transpose_axis_boundary: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
cast transpose wrapper
Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale`
"""
if not CastTransposePrimitive.enabled():
return _jax_cast_transpose(
x, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
)
return CastTransposePrimitive.outer_primitive.bind(
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
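# Editor's usage sketch (delayed-scaling FP8; shapes and scale values are
# illustrative assumptions): one call produces both the casted tensor and its
# transpose, plus the updated amax.
def _editor_sketch_cast_transpose():
    x = jnp.zeros((2, 16, 32), jnp.bfloat16)
    amax = jnp.zeros((1,), jnp.float32)
    scale = jnp.ones((1,), jnp.float32)
    scale_inv = jnp.ones((1,), jnp.float32)
    cx, cxt, new_amax = cast_transpose(
        x, amax, scale, scale_inv,
        out_dtype=jnp.float8_e4m3fn,
        static_axis_boundary=-1,
        transpose_axis_boundary=-1,
    )
    # transposed output moves the trailing axis group in front
    assert cx.shape == x.shape and cxt.shape == (32, 2, 16)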
class DBiasCastTransposePrimitive(BasePrimitive):
"""
DBias Cast Transpose Primitive
"""
name = "te_dbias_cast_transpose"
multiple_results = True
# out_dtype, static_axis_boundary, transpose_axis_boundary
impl_static_args = (4, 5, 6)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
dz_aval,
amax_aval,
scale_aval,
scale_inv_aval,
*,
out_dtype,
static_axis_boundary,
transpose_axis_boundary
):
"""
te_dbias_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:])
t_shape = multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
dbias_shape = (*dz_aval.shape[: static_axis_boundary + 1], gi_hidden_size)
dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
(wkspace_info,) = transformer_engine_jax.get_dbias_ct_workspace_sizes(
dz_aval.size // gi_hidden_size,
gi_hidden_size,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
)
wkspace_aval = dz_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return out, t_out, dbias, updated_amax_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
te_dbias_cast_transpose_p outer abstract
"""
out, t_out, dbias, updated_amax_aval, _ = DBiasCastTransposePrimitive.abstract(
*args, **kwargs
)
return out, t_out, dbias, updated_amax_aval
@staticmethod
def lowering(
ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
):
"""
te_dbias_cast_transpose_p lowering rules
"""
dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_dbias_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 3})(
ctx, dz, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
)
else:
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:])
contracted_dz_shape = (batch_size, ir_hidden_size)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_dz_shape = multidim_transpose(
ir_dz_shape, static_axis_boundary, transpose_axis_boundary
)
dbias_shape = (*ir_dz_shape[: static_axis_boundary + 1], ir_hidden_size)
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [dz, amax, scale, scale_inv]
operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_dz_shape,
wkspace_aval.shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
)
out = custom_caller(
DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3}
)
return out
@staticmethod
def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
"""
te_dbias_cast_transpose implementation
"""
assert DBiasCastTransposePrimitive.inner_primitive is not None
out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind(
dz,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
return out, t_out, dbias, updated_amax
@staticmethod
def batcher(
batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
):
"""
te_dbias_cast_transpose batch rules for vmap
"""
del static_axis_boundary
check_valid_batch_dims(batch_dims)
assert DBiasCastTransposePrimitive.outer_primitive is not None
dz, amax, scale, scale_inv = batched_args
dz_bdim, amax_bdim, _, _ = batch_dims
# Minus batch dim.
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
transpose_axis_boundary += 1 # Plus batch dim
out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
return (
DBiasCastTransposePrimitive.outer_primitive.bind(
dz,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=dz_bdim,
transpose_axis_boundary=transpose_axis_boundary,
),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(
out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
):
del out_dtype, result_infos
x_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
)
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)
@staticmethod
def partition(
out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
):
del result_infos
x_spec = get_padded_spec(arg_infos[0])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
)
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
casted_x_sharding,
casted_transposed_x_sharding,
dbias_sharding,
amax_sharding,
)
def sharded_impl(dz, amax, scale, scale_inv):
local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl(
dz,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_out, local_t_out, global_dbias, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DBiasCastTransposePrimitive)
def dbias_cast_transpose(
dz: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: TEDType,
static_axis_boundary: int,
transpose_axis_boundary: int = -1,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
cast transpose dbias partial fusion wrapper
Return FP8(inputs), dbias
"""
if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes
if not DBiasCastTransposePrimitive.enabled():
return _jax_dbias_cast_transpose(
dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
)
return DBiasCastTransposePrimitive.outer_primitive.bind(
dz,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
class DActLuDBiasCastTransposePrimitive(BasePrimitive):
"""
DActLu DBias Cast Transpose Primitive
"""
name = "te_dact_lu_dbias_cast_transpose"
multiple_results = True
# out_dtype, static_axis_boundary, act_enum
impl_static_args = (5, 6, 7)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
dz_aval,
x_aval,
amax_aval,
scale_aval,
scale_inv_aval,
*,
out_dtype,
static_axis_boundary,
act_enum
): # pylint: disable=unused-argument
"""
te_dact_lu_dbias_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_hidden_size = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
assert ir_hidden_size == gi_hidden_size
t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
dbias_shape = (*x_aval.shape[: static_axis_boundary + 1], gi_hidden_size)
dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
(wkspace_info,) = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes(
x_aval.size // gi_hidden_size,
gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return out, t_out, dbias, updated_amax_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
te_dact_lu_dbias_cast_transpose_p outer abstract
"""
out, t_out, dbias, updated_amax_aval, _ = DActLuDBiasCastTransposePrimitive.abstract(
*args, **kwargs
)
return out, t_out, dbias, updated_amax_aval
@staticmethod
def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
"""
te_dact_lu_dbias_cast_transpose_p lowering rules
"""
dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_dact_lu_dbias_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={2: 3})(
ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
)
else:
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_size == x_batch_size
ir_hidden_size = ir_dz_shape[-1]
contracted_x_shape = (x_batch_size, ir_hidden_size)
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_size)
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
]
operands = [dz, x, amax, scale, scale_inv]
operand_shapes = [
ir_dz_shape,
x_shape,
ir_amax_shape,
ir_scale_shape,
ir_scale_inv_shape,
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_wk_descriptor(
contracted_x_shape,
wkspace_aval.shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
act_enum,
)
out = custom_caller(
DActLuDBiasCastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={2: 3},
)
return out
@staticmethod
def impl(
dz,
x,
amax,
scale,
scale_inv,
out_dtype,
static_axis_boundary,
act_enum,
):
"""
te_dact_lu_dbias_cast_transpose implementation
"""
assert DActLuDBiasCastTransposePrimitive.inner_primitive is not None
out, t_out, dbias, updated_amax, _ = DActLuDBiasCastTransposePrimitive.inner_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
act_enum=act_enum,
)
return out, t_out, dbias, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
"""
te_dact_lu_dbias_cast_transpose batch rules for vmap
"""
del static_axis_boundary
check_valid_batch_dims(batch_dims)
assert DActLuDBiasCastTransposePrimitive.outer_primitive is not None
dz, x, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
return (
DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=x_bdim,
act_enum=act_enum,
),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(
out_dtype,
static_axis_boundary,
act_enum,
mesh,
arg_infos,
result_infos,
):
del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
)
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)
@staticmethod
def partition(
out_dtype,
static_axis_boundary,
act_enum,
mesh,
arg_infos,
result_infos,
):
del result_infos
x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
dbias_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
)
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
casted_x_sharding,
casted_transposed_x_sharding,
dbias_sharding,
amax_sharding,
)
def sharded_impl(dz, x, amax, scale, scale_inv):
local_out, local_t_out, local_dbias, local_amax = (
DActLuDBiasCastTransposePrimitive.impl(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
act_enum=act_enum,
)
)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_out, local_t_out, global_dbias, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DActLuDBiasCastTransposePrimitive)
def dact_lu_dbias_cast_transpose(
dz: jnp.ndarray,
x: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: TEDType,
static_axis_boundary: int,
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
cast transpose dact_lu and dbias fusion wrapper
Return FP8(dact_lu(inputs)), dbias
ONLY supports non-gated activation types
"""
if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes
if not DActLuDBiasCastTransposePrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
(dx,) = vjp_func(dz)
transpose_axis_boundary = -2
return _jax_dbias_cast_transpose(
dx, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
)
act_type_id = ActivationEnum[activation_type]
return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
act_enum=act_type_id,
)
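# Editor's usage sketch (illustrative shapes; assumes the _jax_act_lu fallback
# consumes a size-1 activation axis at position -2 for non-gated activations,
# a convention taken from activation.py and not visible in this diff):
def _editor_sketch_dact_dbias():
    x = jnp.ones((4, 1, 32), jnp.bfloat16)   # forward input, act axis of size 1
    dz = jnp.ones((4, 32), jnp.bfloat16)     # incoming gradient
    amax = jnp.zeros((1,), jnp.float32)
    scale = jnp.ones((1,), jnp.float32)
    scale_inv = jnp.ones((1,), jnp.float32)
    out, t_out, dbias, new_amax = dact_lu_dbias_cast_transpose(
        dz, x, amax, scale, scale_inv,
        out_dtype=jnp.float8_e4m3fn,
        static_axis_boundary=-1,
        activation_type=("gelu",),
    )
    assert out.shape == (4, 1, 32) and dbias.shape == (32,)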
class DgatedActLuCastTransposePrimitive(BasePrimitive):
"""
Dgated ActLu Cast Transpose Primitive
"""
name = "te_dgated_act_lu_cast_transpose"
multiple_results = True
impl_static_args = (5, 6, 7) # out_dtype, static_axis_boundary, act_enum
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
dz_aval,
x_aval,
amax_aval,
scale_aval,
scale_inv_aval,
*,
out_dtype,
static_axis_boundary,
act_enum
): # pylint: disable=unused-argument
"""
te_dgated_act_lu_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
assert x_aval.shape[-2] == 2 # gated activation packs [linear, gate] halves
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_hidden_size = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
assert ir_hidden_size == gi_hidden_size
t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
return out, t_out, updated_amax_aval
@staticmethod
def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
"""
te_dgated_act_lu_cast_transpose_p lowering rules
"""
dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_dgated_act_lu_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
)
else:
ir_dz_type = ir.RankedTensorType(dz.type)
ir_dz_shape = ir_dz_type.shape
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
x_batch_size = reduce(operator.mul, x_shape[:-2])
assert dz_batch_size == x_batch_size
assert x_shape[-2] == 2 # gated activation packs [linear, gate] halves
ir_hidden_size = ir_dz_shape[-1]
gi_hidden_size = x_shape[-1]
assert ir_hidden_size == gi_hidden_size
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [dz, x, amax, scale, scale_inv]
operand_shapes = [
ir_dz_shape,
x_shape,
ir_amax_shape,
ir_scale_shape,
ir_scale_inv_shape,
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
contracted_x_shape = (x_batch_size, x_shape[-1])
opaque = transformer_engine_jax.pack_common_descriptor(
contracted_x_shape,
jax_dtype_to_te_dtype(dz_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum,
)
out = custom_caller(
DgatedActLuCastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={2: 2},
)
return out
@staticmethod
def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, act_enum):
"""
te_dgated_act_lu_cast_transpose implementation
"""
assert DgatedActLuCastTransposePrimitive.inner_primitive is not None
out, t_out, updated_amax = DgatedActLuCastTransposePrimitive.inner_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
act_enum=act_enum,
)
return out, t_out, updated_amax
@staticmethod
def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
"""
te_dgated_act_lu_cast_transpose batch rules for vmap
"""
del static_axis_boundary
check_valid_batch_dims(batch_dims)
assert DgatedActLuCastTransposePrimitive.outer_primitive is not None
dz, x, amax, scale, scale_inv = batched_args
x_bdim, _, amax_bdim, _, _ = batch_dims
out_bdims = x_bdim, x_bdim, amax_bdim
return (
DgatedActLuCastTransposePrimitive.outer_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=x_bdim,
act_enum=act_enum,
),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(
out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos
):
del out_dtype, result_infos, act_enum
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
return (out_sharding, transposed_out_sharding, amax_sharding)
@staticmethod
def partition(out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos):
del result_infos
x_spec = get_padded_spec(arg_infos[1])
casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)
def sharded_impl(dz, x, amax, scale, scale_inv):
local_out, local_t_out, local_amax = DgatedActLuCastTransposePrimitive.impl(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
act_enum=act_enum,
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_out, local_t_out, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
register_primitive(DgatedActLuCastTransposePrimitive)
def dgated_act_lu_cast_transpose(
dz: jnp.ndarray,
x: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: TEDType,
static_axis_boundary: int,
activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
cast transpose d_gated_act_lu fusion wrapper
Return FP8(dgated_act_lu(inputs))
"""
act_type_id = ActivationEnum[activation_type]
if not DgatedActLuCastTransposePrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
(dx,) = vjp_func(dz)
return _jax_cast_transpose(
dx,
scale,
amax,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=-2,
)
return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
dz,
x,
amax,
scale,
scale_inv,
out_dtype=out_dtype,
static_axis_boundary=static_axis_boundary,
act_enum=act_type_id,
)
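# Editor's usage sketch (illustrative shapes; the ("gelu", "linear") spelling
# for the gated GeLU entry of ActivationEnum is an assumption, since the enum
# is not shown in this diff): the gated variant expects x to carry the pair
# axis of size 2 at position -2, while dz has the plain (..., hidden) shape.
def _editor_sketch_dgated():
    x = jnp.ones((4, 2, 32), jnp.bfloat16)
    dz = jnp.ones((4, 32), jnp.bfloat16)
    amax = jnp.zeros((1,), jnp.float32)
    scale = jnp.ones((1,), jnp.float32)
    scale_inv = jnp.ones((1,), jnp.float32)
    out, t_out, new_amax = dgated_act_lu_cast_transpose(
        dz, x, amax, scale, scale_inv,
        out_dtype=jnp.float8_e4m3fn,
        static_axis_boundary=-1,
        activation_type=("gelu", "linear"),
    )
    assert out.shape == x.shape and t_out.shape == (2, 32, 4)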
......@@ -13,6 +13,7 @@
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include <cassert>
......@@ -33,226 +34,42 @@
namespace transformer_engine {
namespace jax {
// Phuong: These 3 functions need to stay in the header file for compilation purposes
// 1.
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
// 2.
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
return pybind11::bytes(str);
}
// 3.
template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
if (opaque_len != sizeof(T)) {
throw std::runtime_error("Invalid opaque object size");
}
return reinterpret_cast<const T *>(opaque);
}
// Packing
struct CustomCallCommonDescriptor {
Shape shape;
DType in_dtype;
DType out_dtype;
size_t act_enum;
};
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype, size_t act_enum = 0);
struct CustomCallCommonWkDescriptor {
Shape shape;
Shape wkshape;
DType in_dtype;
DType out_dtype;
DType wk_dtype;
size_t act_enum;
};
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
const std::vector<size_t> &wkshape, DType in_dtype,
DType out_dtype, DType wk_dtype,
size_t act_enum = 0);
struct CustomCallNormDescriptor {
size_t batch_size;
size_t hidden_size;
size_t wkspace_size;
DType x_dtype;
DType w_dtype;
DType wkspace_dtype;
bool zero_centered_gamma;
float eps;
int sm_margin;
};
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, DType x_dtype, DType w_dtype,
DType wkspace_dtype, bool zero_centered_gamma,
float eps, int sm_margin);
struct SoftmaxDescriptor {
size_t batch_size;
size_t padding_size;
size_t head_dim;
size_t q_seqlen;
size_t k_seqlen;
DType dtype;
float scale_factor;
};
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
size_t head_dim, size_t q_seqlen, size_t k_seqlen,
DType dtype, float scale_factor);
struct CustomCallFusedAttnDescriptor {
size_t input_batch;
size_t bias_batch;
size_t q_max_seqlen;
size_t kv_max_seqlen;
size_t attn_heads;
size_t num_gqa_groups;
size_t bias_heads;
size_t head_dim;
size_t max_segments_per_seq;
size_t wkspace_size;
float scaling_factor;
float dropout_probability;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
NVTE_QKV_Layout qkv_layout;
DType dtype;
DType wkspace_dtype;
bool is_training;
bool deterministic;
int64_t window_size_left;
int64_t window_size_right;
};
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic, int64_t window_size_left, int64_t window_size_right);
// Transpose
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasCastTransposeHandler);
// Activation
size_t get_activation_len(NVTE_Activation_Type activation_enum);
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler);
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler);
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler);
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler);
// Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype, DType out_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps, int sm_margin);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardHandler);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardFP8Handler);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype,
bool is_layer_norm, bool zero_centered_gamma,
float eps, int sm_margin);
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, DType out_dtype,
NVTE_Norm_Type norm_type, int scaling_mode,
bool zero_centered_gamma, float epsilon, int sm_margin,
bool is_training);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardHandler);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardFP8Handler);
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormBackwardHandler);
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, NVTE_Norm_Type norm_type,
bool zero_centered_gamma, int sm_margin);
// Quantization
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
// Softmax
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
int scaling_mode, bool is_2x);
// Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);
......@@ -266,9 +83,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Attention
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
......@@ -285,10 +102,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
......@@ -297,9 +110,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
// CuBLAS helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);
} // namespace jax
} // namespace transformer_engine
......
......@@ -5,328 +5,136 @@
************************************************************************/
#include "transformer_engine/activation.h"
#include <cuda_runtime.h>
#include "extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/transpose.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
// TODO: We won't need this function anymore when we move to the new XLA custom calls
size_t get_activation_len(NVTE_Activation_Type activation_enum) {
switch (activation_enum) {
case NVTE_Activation_Type::GELU:
return 1;
case NVTE_Activation_Type::GEGLU:
return 2;
case NVTE_Activation_Type::SILU:
return 1;
case NVTE_Activation_Type::SWIGLU:
return 2;
case NVTE_Activation_Type::RELU:
return 1;
case NVTE_Activation_Type::REGLU:
return 2;
case NVTE_Activation_Type::QGELU:
return 1;
case NVTE_Activation_Type::QGEGLU:
return 2;
case NVTE_Activation_Type::SRELU:
return 1;
case NVTE_Activation_Type::SREGLU:
return 2;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
}
}
void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output,
NVTE_Activation_Type act_enum, size_t act_len) {
auto input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
scale, scale_inverse);
switch (act_enum) {
case NVTE_Activation_Type::GELU:
nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::GEGLU:
nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SILU:
nvte_silu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SWIGLU:
nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::RELU:
nvte_relu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::REGLU:
nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::QGELU:
nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::QGEGLU:
nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SRELU:
nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SREGLU:
nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
}
}
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
auto act_len = get_activation_len(act_enum);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output,
act_enum, act_len);
}
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf,
int64_t act_enum) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data();
auto *output = output_buf->untyped_data();
auto input_dims = input_buf.dimensions();
auto m = product(input_dims, 0, input_dims.size() - 2);
auto n = input_dims.back();
auto act_len = input_dims.end()[-2];
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
ActLuImpl(input, m, n, in_dtype, out_dtype, nullptr, stream, nullptr, nullptr, output, act_type,
act_len);
return ffi_with_cuda_error_check();
}
namespace {
bool is_gated(NVTE_Activation_Type act_type) {
return act_type == NVTE_Activation_Type::GEGLU || act_type == NVTE_Activation_Type::SWIGLU ||
act_type == NVTE_Activation_Type::REGLU || act_type == NVTE_Activation_Type::QGEGLU ||
act_type == NVTE_Activation_Type::SREGLU;
}
} // namespace
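// Sanity examples for the helper above:
//   is_gated(NVTE_Activation_Type::SWIGLU) -> true  (two stacked inputs)
//   is_gated(NVTE_Activation_Type::GELU)   -> false (single input)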
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
.Attr<int64_t>("act_enum"),
FFI_CudaGraph_Traits);
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
auto act_len = get_activation_len(act_enum);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, output,
act_enum, act_len);
}
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum,
bool is_2x_int) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data();
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *colwise_output = colwise_output_buf->untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto input_dims = input_buf.dimensions();
auto m = product(input_dims, 0, input_dims.size() - 2);
auto n = input_dims.back();
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto act_len = input_dims[input_dims.size() - 2];
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto is_2x = static_cast<bool>(is_2x_int);
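// is_2x requests two quantized copies of the result in one pass -- a rowwise
// output plus a columnwise (transposed) companion -- so downstream GEMMs can
// consume either layout without a separate transpose kernel. scaling_mode
// selects the NVTEScalingMode (delayed tensor scaling vs. block scaling) and
// determines which of the scale/scale_inv/amax buffers below are meaningful.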
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuFP8Handler, ActLuFP8FFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // amax_out
.Attr<int64_t>("act_enum"),
FFI_CudaGraph_Traits);
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *act_input = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n * act_len};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
switch (act_enum) {
case NVTE_Activation_Type::GELU:
nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::GEGLU:
nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SILU:
nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SWIGLU:
nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::RELU:
nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::REGLU:
nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::QGELU:
nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::QGEGLU:
nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SRELU:
nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SREGLU:
nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
auto input_shape = std::vector<size_t>{m, act_len * n};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(scaling_mode);
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
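// The scale_inv output arrives with whatever leading layout the caller chose;
// it is handed to TE/common as a flattened 2D view
// [prod(leading dims), last dim], which is sufficient for addressing
// per-tensor or per-block scale factors.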
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
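// amax is zeroed before launch because the kernel reduces the new running
// maximum into this buffer; a stale value from a previous step would
// otherwise leak into the current scale computation.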
if (is_2x) {
output_tensor.set_columnwise_data(colwise_output, static_cast<DType>(out_dtype), output_shape);
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
}
  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Attr<int64_t>("act_enum")
                                  .Attr<int64_t>("scaling_mode")
                                  .Attr<bool>("is_2x"),
                              FFI_CudaGraph_Traits);
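// FFI_CudaGraph_Traits advertises the handler as command-buffer compatible,
// allowing XLA to capture it inside a CUDA graph; all work must therefore be
// enqueued on the provided stream with no host-side synchronization.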
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
                                                   int scaling_mode, bool is_2x) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
auto output_shape = std::vector<size_t>{batch_size, hidden_size};
......@@ -344,13 +152,34 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_
auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
auto dact_input_tensor =
TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype);
  auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
  auto output_tensor = TensorWrapper(static_cast<NVTEScalingMode>(scaling_mode));
  output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
if (is_2x) {
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype,
output_trans_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
}
if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
std::vector<size_t>{1});
}
TensorWrapper dummy_workspace;
// For now, all dbias_dact(-s) have the same workspace size
nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(),
dbias_tensor.data(), dummy_workspace.data(), nullptr);
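// Workspace-size query pattern: the op is invoked once with an empty
// dummy_workspace tensor and a null stream. Instead of launching work, the
// library records the required workspace shape/dtype into dummy_workspace,
// which is returned to Python below so XLA can allocate the real buffer.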
......@@ -359,101 +188,26 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_
return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *act_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *output_trans = buffers[6];
auto *dbias = buffers[7];
float *amax_out = reinterpret_cast<float *>(buffers[8]);
void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DActLuDBiasCastTranspose primitive.");
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto dbias_shape = std::vector<size_t>{n};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape);
output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);
switch (act_enum) {
case NVTE_Activation_Type::GELU:
nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
case NVTE_Activation_Type::SILU:
nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
case NVTE_Activation_Type::RELU:
nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
case NVTE_Activation_Type::QGELU:
nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
case NVTE_Activation_Type::SRELU:
nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
dbias_tensor.data(), workspace.data(), stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
}
}
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                  Buffer_Type act_input_buf, Buffer_Type scale_buf,
                                  Result_Type output_buf, Result_Type output_trans_buf,
                                  Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf,
                                  Result_Type amax_out_buf, Result_Type dbias_buf,
                                  Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x,
                                  bool is_dbias, int64_t act_enum) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data();
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data();
void *workspace = workspace_buf->untyped_data();
auto input_dims = input_buf.dimensions();
auto act_input_dims = act_input_buf.dimensions();
......@@ -461,212 +215,156 @@ Error_Type DActLuDBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_bu
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// n = ir_dz_shape[-1], ir_dz_shape == input_dims
auto act_input_ranks = act_input_dims.size();
auto m = product(act_input_dims, 0, act_input_dims.size() - 1);
// 'n' will be 2x the size of input_dims.back() if the dactivation is dgated
auto n = act_input_dims.back();
auto input_shape = std::vector<size_t>{m, input_dims.back()};
auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{m, n};
auto dbias_shape = std::vector<size_t>{n};
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode);
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax_out, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
}
}
if (is_2x) {
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf;
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
}
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(is_gated(act_type) && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x &&
is_gated(act_type)),
"TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) {
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum = ", act_enum, " with dbias = True");
        break;
    }
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler, DActLuDBiasCastTransposeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act_input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // output_trans
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // amax_out
.Ret<Buffer_Type>() // workspace
.Attr<int64_t>("act_enum"),
FFI_CudaGraph_Traits);
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *act_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *output_trans = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
auto input_shape = desc.shape.to_vector();
auto act_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
auto output_trans_shape = std::vector<size_t>{n * 2, m};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape);
output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
switch (act_enum) {
case NVTE_Activation_Type::GEGLU:
nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
stream);
break;
case NVTE_Activation_Type::SWIGLU:
nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
output_tensor.data(), stream);
} else {
switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
}
}
Error_Type DGatedActLuCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type amax_out_buf, int64_t act_enum) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data();
float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(amax == amax_out,
"amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
if (!use_fp8(out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto input_dims = input_buf.dimensions();
auto act_input_dims = act_input_buf.dimensions();
auto act_input_ranks = act_input_dims.size();
auto m = product(act_input_dims, 0, act_input_ranks - 2);
auto n = product(act_input_dims, act_input_ranks - 1, act_input_ranks);
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
auto output_trans_shape = std::vector<size_t>{n * 2, m};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
switch (act_type) {
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
break;
}
}
return ffi_with_cuda_error_check();
}
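// Dispatch summary for DActLuDBiasQuantizeFFI: is_dbias selects the fused
// dact+dbias+quantize kernels (non-gated activations only); otherwise the
// plain dact kernels write directly into the (possibly quantized) output
// tensor, with gated variants handled by their own cases.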
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<int64_t>("scaling_mode")
                                  .Attr<bool>("is_2x")
                                  .Attr<bool>("is_dbias")
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);
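// Binding contract: Arg/Ret slots bind positionally to the operands emitted
// by the JAX lowering, while Attr entries bind by name ("scaling_mode",
// "is_2x", "is_dbias", "act_enum"); a mismatch fails at FFI binding time
// rather than silently misreading a buffer.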
} // namespace jax
} // namespace transformer_engine
......@@ -301,39 +301,6 @@ static void FusedAttnForwardImpl(
nvte_tensor_pack_destroy(&aux_output_tensors);
}
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
auto is_ragged = nvte_get_qkv_format(descriptor.qkv_layout) == NVTE_QKV_Format::NVTE_THD;
/* Input buffers from XLA */
void *q = buffers[0];
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3];
void *seed = buffers[4];
void *q_cu_seqlens = buffers[5];
void *kv_cu_seqlens = buffers[6];
void *q_seq_offsets = is_ragged ? buffers[7] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[8] : nullptr;
/* Output buffer from XLA */
void *output = buffers[9];
void *softmax_aux = buffers[10];
void *rng_state = buffers[11];
void *workspace = buffers[12];
FusedAttnForwardImpl(
stream, q, k, v, bias, seed, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets,
output, softmax_aux, rng_state, workspace, descriptor.input_batch, descriptor.bias_batch,
descriptor.q_max_seqlen, descriptor.kv_max_seqlen, descriptor.attn_heads,
descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim,
descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor,
descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type,
descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training,
descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right);
}
#define FUSED_ATTN_FFI_GET_ATTRS \
size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch"); \
size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch"); \
......@@ -608,45 +575,6 @@ static void FusedAttnBackwardImpl(
nvte_tensor_pack_destroy(&aux_input_tensors);
}
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
auto qkv_layout = descriptor.qkv_layout;
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
/* Input buffers from XLA */
void *q = buffers[0];
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3];
void *softmax_aux = buffers[4];
void *rng_state = buffers[5];
void *output = buffers[6];
void *doutput = buffers[7];
void *q_cu_seqlens = buffers[8];
void *kv_cu_seqlens = buffers[9];
void *q_seq_offsets = is_ragged ? buffers[10] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[11] : nullptr;
/* Output buffer from XLA */
void *dq = buffers[12];
void *dk = buffers[13];
void *dv = buffers[14];
void *dbias = buffers[15];
void *workspace = buffers[16];
FusedAttnBackwardImpl(
stream, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlens, kv_cu_seqlens,
q_seq_offsets, k_seq_offsets, dq, dk, dv, dbias, workspace, descriptor.input_batch,
descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen,
descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim,
descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor,
descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type,
descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training,
descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right);
}
Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
Buffer_Type v_buf, Buffer_Type bias_buf,
Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf,
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "transformer_engine/gemm.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
Error_Type CublasHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets,
Dictionary attrs) {
nvte_cublas_handle_init();
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(CublasHandleInitHandler, CublasHandleInitFFI,
FFI::Bind<FFI_Prepare>().RemainingArgs().RemainingRets().Attrs());
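// Binding with FFI::Bind<FFI_Prepare> runs this handler during XLA's prepare
// stage, so the expensive cuBLAS handle creation happens once up front rather
// than inside the first executed (or graph-captured) GEMM.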
} // namespace jax
} // namespace transformer_engine
......@@ -13,8 +13,9 @@ namespace jax {
// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
switch (type) {
// Using this for E8M0
case xla::ffi::DataType::U8:
      return DType::kFloat8E8M0;
      break;
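    // E8M0 is an 8-bit exponent-only format used for MXFP8 block scale
    // factors. XLA builds without a native F8E8M0FNU dtype ship these scales
    // as U8 buffers, so plain bytes are reinterpreted as E8M0 here (see also
    // the fallback in the default case below).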
case xla::ffi::DataType::S32:
return DType::kInt32;
......@@ -37,8 +38,12 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
case xla::ffi::DataType::F8E4M3FN:
return DType::kFloat8E4M3;
break;
// case xla::ffi::DataType::F8E8M0FNU:
// return DType::kFloat8E8M0;
// break;
    default: {
      auto type_num = static_cast<XLA_FFI_DataType>(type);
      if (static_cast<int>(type_num) == 33) return DType::kFloat8E8M0;
      NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType ",
                 static_cast<int>(type_num));
      break;
    }
......