Unverified commit cf9a7c2f, authored by Phuong Nguyen, committed by GitHub

[JAX] Refactor + MXFP8 + GroupedGEMM (#1627)



* refactor + mxfp8

* added grouped gemm

* rename linear to dense

* added cublas init phase for groupedGemm

* relax the tol of test encoder multiprocessing mxfp8 by 0.001
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Hua Huang <huah@nvidia.com>
Co-authored-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent be055eb0
@@ -577,3 +577,11 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
}
}
+
+namespace transformer_engine {
+
+using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
+
+void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }
+
+}  // namespace transformer_engine
@@ -119,6 +119,13 @@ namespace transformer_engine {

constexpr int num_streams = 4;

+/*! \brief TE/JAX cudaGraph requires the cuBLAS initialization to happen outside of the capturing
+ * region. This function is a helper to call cublasCreate(), which allocates memory for the handle.
+ * It is called in the initialization phase of the related XLA custom calls.
+ */
+void nvte_cublas_handle_init();
+
} // namespace transformer_engine

#endif // TRANSFORMER_ENGINE_GEMM_H_
@@ -149,6 +149,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
void nvte_enable_cudnn_norm_fwd(bool enable);
void nvte_enable_cudnn_norm_bwd(bool enable);

+enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
+
#ifdef __cplusplus
} // extern "C"
#endif
......
@@ -80,7 +80,8 @@ enum NVTEScalingMode {
/*! Single scale per block of 32 elements consecutive in either
rowwise or columnwise direction */
NVTE_MXFP8_1D_SCALING = 1,
-NVTE_INVALID_SCALING
+NVTE_INVALID_SCALING = 2,
+NVTE_NO_SCALING = 3
};

/*! \brief TE Tensor type
@@ -346,6 +347,13 @@ enum class DType {
kNumTypes
};

+/*! \brief Check whether a TE datatype is FP8.
+ *
+ * Returns true if the TE datatype is FP8.
+ * \param[in] t TE datatype of interest.
+ */
+bool is_fp8_dtype(const DType t);
+
/*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class.
*/
......
@@ -11,7 +11,9 @@
transformer_engine::ubuf_built_with_mpi*;
*transformer_engine::rtc*;
transformer_engine::nvte_cudnn_handle_init*;
+transformer_engine::nvte_cublas_handle_init*;
transformer_engine::typeToSize*;
+transformer_engine::is_fp8_dtype*;
*transformer_engine::CommOverlapBase*;
*transformer_engine::CommOverlapP2PBase*;
*transformer_engine::CommOverlapCore*
......
@@ -10,6 +10,7 @@
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
+#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>

#include <functional>
@@ -137,7 +138,6 @@ struct BackwardKernelParams : public KernelParamsBase {
};

enum class NVTE_Norm_Backend { Te, Cudnn };
-enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
enum class NVTE_Norm_Stage { Forward, Backward };

using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>;
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
-"""Transformer Engine bindings for JAX"""
+"""Transformer Engine bindings for JAX.
+
+This module provides JAX bindings for NVIDIA's Transformer Engine, enabling
+high-performance transformer operations with mixed precision and quantization
+support. It includes implementations of key transformer components like attention,
+linear layers, and layer normalization, optimized for NVIDIA GPUs.
+
+The module exports various transformer operations and utilities:
+- Attention mechanisms (self-attention, cross-attention)
+- Linear transformations with optional quantization
+- Layer normalization operations
+- Activation functions
+- Softmax operations
+- Sharding utilities for distributed training
+
+All operations are designed to work seamlessly with JAX's functional programming
+model and support automatic differentiation.
+"""

# pylint: disable=wrong-import-position,wrong-import-order

+import sys
import logging
import importlib
import importlib.util
+import ctypes
from importlib.metadata import version
-import sys

from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension

-_logger = logging.getLogger(__name__)

def _load_library():
"""Load shared library with Transformer Engine C extensions"""
@@ -41,7 +55,7 @@ def _load_library():

if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
-_logger.info(
+logging.info(
"Could not find package %s. Install transformer-engine using "
"'pip3 install transformer-engine[jax]==VERSION'",
module_name,
@@ -67,8 +81,10 @@ def _load_library():

_load_library()

from . import flax
-from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
-from .fp8 import NVTE_FP8_COLLECTION_NAME
+from . import quantize
+from .quantize import fp8_autocast
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType
@@ -85,10 +101,7 @@ ShardingResource = deprecate_wrapper(
)

__all__ = [
-"NVTE_FP8_COLLECTION_NAME",
"fp8_autocast",
-"update_collections",
-"get_delayed_scaling",
"MeshResource",
"MajorShardingType",
"ShardingResource",
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Activation functions for Transformer Engine in JAX.
This module provides optimized activation functions with quantization support.
"""
from typing import Sequence, Union, Callable, Optional
from functools import partial
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize.tensor import ScaledTensor
from .quantize.quantizer import Quantizer
def activation(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
"""Apply activation functions to input tensor with optional quantization.
This function applies a sequence of activation functions to the input tensor.
Activation types may be given as strings or callables, e.g. ('relu',), ('gelu',), or the gated pair ('gelu', 'linear').
Args:
x: Input tensor to apply activations to
activation_type: Sequence of activation functions
quantizer: Optional quantizer for quantizing the output
Returns:
Activated output tensor
"""
assert x.shape[-1] % len(activation_type) == 0
output = _activation(x, activation_type, quantizer)
return output
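# Usage sketch (illustrative shapes; no quantizer, so the output stays in bf16):
#
#   >>> import jax.numpy as jnp
#   >>> x = jnp.ones((8, 128), dtype=jnp.bfloat16)
#   >>> y = activation(x, ("gelu",))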
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation(x, activation_type, quantizer):
"""Internal implementation of activation with custom VJP.
This function implements the core activation logic with support for
custom vector-Jacobian product (VJP) for automatic differentiation.
Args:
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
Returns:
Activated tensor
"""
_output, _ = _activation_fwd_rule(x, activation_type, quantizer)
return _output
def _activation_fwd_rule(x, activation_type, quantizer):
"""Forward pass rule for activation function.
Args:
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
Returns:
Tuple of (output, context) for backward pass
"""
fwd_output = tex.act_lu(x, activation_type, quantizer)
if isinstance(fwd_output, ScaledTensor):
fwd_output = fwd_output.dequantize()
return fwd_output, (x, quantizer)
def _activation_bwd_rule(activation_type, ctx, g):
"""Backward pass rule for activation function.
Args:
activation_type: Sequence of activation functions
ctx: Context from forward pass
g: Gradient from upstream
Returns:
Gradient with respect to input
"""
(x, _) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx, None)
_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule)
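# _activation carries a custom VJP, so the public activation() composes with JAX
# autodiff as usual. A sketch (illustrative shapes):
#
#   >>> import jax, jax.numpy as jnp
#   >>> loss = lambda x: activation(x, ("gelu",)).sum()
#   >>> dx = jax.grad(loss)(jnp.ones((8, 128), dtype=jnp.bfloat16))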
@@ -7,4 +7,4 @@ from .attention import *
from .normalization import *
from .quantization import *
from .softmax import *
-from .transpose import *
+from .gemm import *
@@ -13,8 +13,6 @@ from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes, lax
-from jax.interpreters import mlir
-from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

import transformer_engine_jax
@@ -29,14 +27,12 @@ from transformer_engine.jax.attention import (
)

from .base import BasePrimitive, register_primitive
-from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
check_valid_batch_dims,
jax_dtype_to_te_dtype,
te_dtype_to_jax_dtype,
get_padded_spec,
get_cudnn_version,
-is_ffi_enabled,
)
from ..sharding import (
global_mesh_resource,
@@ -227,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
Fused Attention Forward Primitive
"""

-name = "te_fused_attn_forward"
+name = "te_fused_attn_forward_ffi"
multiple_results = True
impl_static_args = (13,)
inner_primitive = None
@@ -400,9 +396,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)

-if is_ffi_enabled():
-name = "te_fused_attn_forward_ffi"
-out = ffi.ffi_lowering(name)(
+return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
ctx,
q,
k,
@@ -436,54 +430,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
-else:
-operands = [
-q,
-k,
-v,
-bias,
-seed,
-q_cu_seqlen,
-kv_cu_seqlen,
-q_seq_offsets,
-k_seq_offsets,
-]
-operand_shapes = map(lambda x: x.type.shape, operands)
-out_types = [
-ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
-for output in ctx.avals_out
-]
-args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
-wkspace_aval = ctx.avals_out[-1]
-opaque = transformer_engine_jax.pack_fused_attn_descriptor(
-input_batch,
-bias_batch,
-q_max_seqlen,
-kv_max_seqlen,
-attn_heads,
-num_gqa_groups,
-bias_heads,
-head_dim,
-config.max_segments_per_seq,
-wkspace_aval.size,
-config.scaling_factor,
-config.dropout_probability,
-config.attn_bias_type,
-config.attn_mask_type,
-config.qkv_layout,
-jax_dtype_to_te_dtype(q_aval.dtype),
-jax_dtype_to_te_dtype(wkspace_aval.dtype),
-config.is_training,
-not FusedAttnHelper.is_non_deterministic_allowed(),
-config.window_size[0],
-config.window_size[1],
-)
-out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
-return out

@staticmethod
def impl(
@@ -681,7 +627,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
Fused Attention Backward Primitive
"""

-name = "te_fused_attn_backward"
+name = "te_fused_attn_backward_ffi"
multiple_results = True
impl_static_args = (16,)
inner_primitive = None
@@ -813,9 +759,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)

-if is_ffi_enabled():
-name = "te_fused_attn_backward_ffi"
-out = ffi.ffi_lowering(name)(
+return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
ctx,
q,
k,
@@ -852,57 +796,6 @@ class FusedAttnBwdPrimitive(BasePrimitive):
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
-else:
-operands = [
-q,
-k,
-v,
-bias,
-softmax_aux,
-rng_state,
-output,
-doutput,
-q_cu_seqlen,
-kv_cu_seqlen,
-q_seq_offsets,
-k_seq_offsets,
-]
-operand_shapes = map(lambda x: x.type.shape, operands)
-out_types = [
-ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
-for output in ctx.avals_out
-]
-args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
-wkspace_aval = ctx.avals_out[-1]
-opaque = transformer_engine_jax.pack_fused_attn_descriptor(
-input_batch,
-bias_batch,
-q_max_seqlen,
-kv_max_seqlen,
-attn_heads,
-num_gqa_groups,
-bias_heads,
-head_dim,
-config.max_segments_per_seq,
-wkspace_aval.size,
-config.scaling_factor,
-config.dropout_probability,
-config.attn_bias_type,
-config.attn_mask_type,
-config.qkv_layout,
-jax_dtype_to_te_dtype(q_aval.dtype),
-jax_dtype_to_te_dtype(wkspace_aval.dtype),
-config.is_training,
-not FusedAttnHelper.is_non_deterministic_allowed(),
-config.window_size[0],
-config.window_size[1],
-)
-out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
-return out

@staticmethod
def impl(
......
@@ -6,6 +6,7 @@ import os
import re
from abc import ABCMeta, abstractmethod
from functools import partial
+from packaging import version

from jax.extend import core
from jax.interpreters import xla, mlir
@@ -13,6 +14,14 @@ from jax.experimental.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
from jax._src import dispatch

+import jax
+import transformer_engine_jax
+
+if version.parse(jax.__version__) >= version.parse("0.5.0"):
+from jax import ffi  # pylint: disable=ungrouped-imports
+else:
+from jax.extend import ffi  # pylint: disable=ungrouped-imports
+

class BasePrimitive(metaclass=ABCMeta):
"""
@@ -120,3 +129,7 @@ def register_primitive(cls):
outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
)
cls.outer_primitive = outer_p
+
+
+for _name, _value in transformer_engine_jax.registrations().items():
+ffi.register_ffi_target(_name, _value, platform="CUDA")
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom call"""
from dataclasses import dataclass
from enum import IntEnum
from packaging import version
import jax
from jax.interpreters import mlir
import transformer_engine_jax
from .misc import is_ffi_enabled
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
try:
from jaxlib.hlo_helpers import custom_call
except ImportError:
# Newer JAX changed its API. But we want to support a few JAX
# versions, so we still need this import.
pass
class CustomCallAPIVersion(IntEnum):
"""Enum for selecting between old and new custom call registration API"""
OPAQUE = 0
FFI = 1
for _name, _value in transformer_engine_jax.registrations().items():
if _name.endswith("_ffi"):
if is_ffi_enabled():
ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
)
else:
ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
)
@dataclass
class CustomCallArgsWrapper:
"""
wrapper of XLA custom call args
"""
def __init__(
self,
output_types,
operands,
operand_shapes,
operand_specific_layouts=None,
output_specific_layouts=None,
):
self.output_types = output_types
self.operands = operands
self.operand_layouts = CustomCallArgsWrapper.generate_layouts(
operand_shapes, operand_specific_layouts
)
output_shapes = [x.shape for x in output_types]
self.output_layouts = CustomCallArgsWrapper.generate_layouts(
output_shapes, output_specific_layouts
)
@staticmethod
def generate_layouts(shapes, specific_layouts):
"""
setup layouts for XLA custom call
"""
def default_layout(shape):
return range(len(shape) - 1, -1, -1)
if specific_layouts is None:
specific_layouts = {}
layouts = []
for idx, shape in enumerate(shapes):
if idx in specific_layouts:
layouts.append(specific_layouts[idx])
else:
layouts.append(default_layout(shape))
return layouts
def custom_caller(name, args, opaque, has_side_effect, **kwargs):
"""
XLA custom call wrapper
"""
if hasattr(mlir, "custom_call"):
out = mlir.custom_call(
name,
result_types=args.output_types,
operands=args.operands,
operand_layouts=args.operand_layouts,
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs,
).results
else:
# Need to disable one pylint error, as the second function
# parameter was renamed recently in JAX. Otherwise we won't be
# compatible with multiple JAX versions.
out = custom_call( # pylint: disable=too-many-function-args
name,
args.output_types,
operands=args.operands,
operand_layouts=args.operand_layouts,
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs,
)
return out
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""
from typing import Tuple, Sequence, Union, Dict, List, Optional
from functools import partial, reduce
import operator
from transformer_engine_jax import get_device_compute_capability
import jax
import jax.numpy as jnp
from .base import BasePrimitive, register_primitive
from ..quantize import (
ScaledTensor,
ScalingMode,
Quantizer,
QuantizeConfig,
noop_quantizer_set,
)
__all__ = ["gemm", "grouped_gemm"]
num_cublas_streams = 4
def get_cublas_workspace_size_bytes() -> int:
"""Return 32 MiB on Hopper and newer (sm90+), 4 MiB for all other architectures."""
if get_device_compute_capability(0) >= 90:
return 33_554_432
return 4_194_304
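# Note: the grouped-GEMM workspace reserved in GroupedGemmPrimitive.abstract()
# below is one such buffer per cuBLAS stream, e.g. on Hopper:
# 33_554_432 B * 4 streams = 128 MiB, carried as a flat uint8 array.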
class GroupedGemmPrimitive(BasePrimitive):
"""
Primitive for grouped GEMM
"""
name = "te_grouped_gemm_ffi"
multiple_results = True
impl_static_args = (6, 7, 8, 9)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
lhs_contig_aval,
lhs_scale_contig_aval,
rhs_contig_aval,
rhs_scale_contig_aval,
bias_contig_aval,
dim_list_aval,
*,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
):
del lhs_contig_aval, lhs_scale_contig_aval
del rhs_contig_aval, rhs_scale_contig_aval
del bias_contig_aval, dim_list_aval
del num_gemms, scaling_mode
out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype)
wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8)
return (out_flat_aval, wkspace_aval)
@staticmethod
def outer_abstract(*args, **kwargs):
(out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs)
return out_aval
@staticmethod
def lowering(
ctx,
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
*,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
) -> jnp.ndarray:
del out_dtype, out_flat_size
return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
ctx,
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms=num_gemms,
scaling_mode=int(scaling_mode),
)
@staticmethod
def impl(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
) -> jnp.ndarray:
assert GroupedGemmPrimitive.inner_primitive is not None
out = GroupedGemmPrimitive.inner_primitive.bind(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms=num_gemms,
scaling_mode=scaling_mode.value,
out_dtype=out_dtype,
out_flat_size=out_flat_size,
)
return out[0] # out is [out_flat, wkspace], only return out_flat
register_primitive(GroupedGemmPrimitive)
def _shape_normalization(x, dimension_numbers, already_transposed: bool = False):
orig_order = list(range(x.ndim))
contracting_dims, batch_dims = dimension_numbers
contracting_order = [d for d in orig_order if d in contracting_dims]
batch_order = [d for d in orig_order if d in batch_dims]
non_contracting_order = [
d for d in orig_order if d not in contracting_dims and d not in batch_dims
]
batch_shape = [x.shape[d] for d in batch_order]
rows_shape = [x.shape[d] for d in non_contracting_order]
cols_shape = [x.shape[d] for d in contracting_order]
new_order = batch_order + non_contracting_order + contracting_order
rows, cols, batches = (
reduce(operator.mul, rows_shape, 1),
reduce(operator.mul, cols_shape, 1),
reduce(operator.mul, batch_shape, 1),
)
# Remove this transpose when non-TN dot is supported
if not already_transposed:
t = jnp.transpose(x, new_order)
else:
t = x
return jnp.reshape(t, (batches, rows, cols))
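# Worked example (illustrative shapes): for x of shape (4, 8, 16) and
# dimension_numbers ((1,), (0,)), dim 0 is the batch and dim 1 contracts, so
# batches=4, rows=16, cols=8 and the transpose yields (B, M, K) == (4, 16, 8):
#
#   >>> import jax.numpy as jnp
#   >>> _shape_normalization(jnp.zeros((4, 8, 16)), ((1,), (0,))).shape
#   (4, 16, 8)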
def _calculate_remaining_shape(shape, contracting_dims):
return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)
def _dequantize(x, scale_inv, dq_dtype):
return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
# Apply jit to guarantee correctness of the FP8 GEMM: under jit, XLA can
# pattern-match the dequantize + dot_general below into a single FP8 GEMM.
@partial(
jax.jit,
static_argnums=(
2,
3,
4,
),
)
def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)
# Reshape + Transpose
# [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T")
# _shape_normalization ensures contracting_dims=2 and batch_dims=0
dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
)
return out_3d
def _jax_gemm_delayed_scaling_fp8(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""FP8 GEMM for XLA pattern match"""
assert (
rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.layout == "T":
lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
if rhs.layout == "T":
rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
lhs_dn = (lhs_contract, lhs_batch)
rhs_dn = (rhs_contract, rhs_batch)
lhs_remain_shape = _calculate_remaining_shape(lhs.data.shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs.data.shape, rhs_contract)
precision = (
jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
)
out_3d = __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
# Reshape [B, M, N] -> [..., M, N]
out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
return out
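# A sketch (hypothetical fp8 operands and unit scales) of the QDQ pattern XLA is
# expected to match under jit: upcast both fp8 tensors with their scales, then
# one dot_general, which XLA can rewrite into a single FP8 GEMM:
#
#   >>> import jax, jax.numpy as jnp
#   >>> a = jnp.ones((16, 32), dtype=jnp.float8_e4m3fn)
#   >>> b = jnp.ones((32, 16), dtype=jnp.float8_e4m3fn)
#   >>> dq = lambda t: t.astype(jnp.bfloat16) * jnp.bfloat16(1.0)
#   >>> f = jax.jit(lambda: jax.lax.dot_general(dq(a), dq(b), (((1,), (0,)), ((), ())),
#   ...                                         preferred_element_type=jnp.bfloat16))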
def _jax_gemm_mxfp8_1d(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""
JAX GEMM for MXFP8 via scaled_matmul
"""
assert (
rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING
), "rhs does not have MXFP8 1D scaling mode"
from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
expected_lhs_is_colwise = lhs_contract[-1] != lhs.data.ndim - 1
expected_rhs_is_colwise = rhs_contract[-1] != rhs.data.ndim - 1
assert lhs.is_colwise is expected_lhs_is_colwise, (
f"LHS with unexpected quantize dimension.\nExpect is_colwise={expected_lhs_is_colwise}, got"
f" {lhs.is_colwise}"
)
assert rhs.is_colwise is expected_rhs_is_colwise, (
f"RHS with unexpected quantize dimension.\nExpect is_colwise={expected_rhs_is_colwise}, got"
f" {rhs.is_colwise}"
)
# Reshape + Transpose (if needed)
# [..., M, K] -> [1, reduce(..., M), K]
# [..., K, M] -> [1, reduce(..., M), K]
lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch))
rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch))
lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch))
rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch))
# Slice out the padding as scaled_matmul does not support padded scales yet
lhs_scale_3d = jnp.asarray(lhs_scale_3d[:, : lhs_3d.shape[1], : lhs_3d.shape[2] // 32])
rhs_scale_3d = jnp.asarray(rhs_scale_3d[:, : rhs_3d.shape[1], : rhs_3d.shape[2] // 32])
# JAX scaled_matmul only supports NT now (TN-gemm)
# * Expected shape:
# * lhs_data (B, M, K) * rhs_data (B, N, K)
# * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block)
out_3d = scaled_matmul_wrapper(
lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype
)
# Reshape [1, reduce(..., M), N] -> [..., M, N]
lhs_remain_shape = tuple(
lhs.data.shape[dim] for dim in range(len(lhs.data.shape)) if dim not in lhs_contract
)
rhs_remain_shape = tuple(
rhs.data.shape[dim] for dim in range(len(rhs.data.shape)) if dim not in rhs_contract
)
out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
return out
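# Scale-shape sketch (illustrative sizes): MXFP8 1D scaling keeps one scale per
# 32 consecutive elements along K, so for lhs_3d of shape (1, 128, 64) the
# sliced lhs_scale_3d above is (1, 128, 64 // 32) == (1, 128, 2), matching the
# (B, M, K_block) layout that scaled_matmul expects.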
def _jax_gemm(
lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
quantizer_set: Dict[str, Quantizer] = noop_quantizer_set,
) -> jnp.ndarray:
"""
FP8 GEMM via JAX
"""
dim_nums = (contracting_dims, ((), ()))
def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}")
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
return _jax_gemm_fp8_impl(lhs, rhs)
if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor):
if quantizer_set != noop_quantizer_set:
assert type(quantizer_set.x) is type(quantizer_set.kernel)
(((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums
lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
# Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm)
lhs_q = quantizer_set.x.quantize(
lhs,
is_rowwise=lhs_is_rowwise,
is_colwise=not lhs_is_rowwise,
)
rhs_q = quantizer_set.kernel.quantize(
rhs,
is_rowwise=rhs_is_rowwise,
is_colwise=not rhs_is_rowwise,
)
return _jax_gemm_fp8_impl(lhs_q, rhs_q)
if (
isinstance(lhs, jnp.ndarray)
and isinstance(rhs, jnp.ndarray)
and quantizer_set == noop_quantizer_set
):
return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype)
raise NotImplementedError("Multiplying a ScaledTensor with a plain jnp.ndarray is not supported")
def gemm(
lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
quantizer_set: Dict[str, Quantizer] = noop_quantizer_set,
) -> jnp.ndarray:
"""General matrix multiplication with optional quantization.
Args:
lhs: First input matrix.
rhs: Second input matrix.
contracting_dims: Tuple of two sequences representing the contracting dimensions.
The first sequence represents the contracting dimensions of the first matrix,
and the second sequence represents the contracting dimensions of the second matrix.
quantizer_set: Set of quantizers for FP8 quantization of the output.
If None, no quantization is applied and the output has the same dtype as the inputs.
Returns:
If quantizer_set is None:
The matrix multiplication result.
Shape: (M, N)
Dtype: Same as input dtype
If quantizer_set is provided:
A ScaledTensor containing the quantized matrix multiplication result.
"""
return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set)
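# Usage sketch (illustrative shapes): a plain bf16 matmul over the default
# contracting dims ((1,), (0,)), i.e. (M, K) @ (K, N). FP8 paths require either
# ScaledTensor inputs or a real quantizer_set:
#
#   >>> import jax.numpy as jnp
#   >>> a = jnp.ones((128, 256), dtype=jnp.bfloat16)
#   >>> b = jnp.ones((256, 512), dtype=jnp.bfloat16)
#   >>> gemm(a, b).shape
#   (128, 512)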
def swizzled_scale(scales):
"""Swizzle the scale tensor for FP8 GEMM"""
assert scales.ndim == 2
rows, cols = scales.shape
scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
scales = jnp.transpose(scales, (0, 3, 2, 1, 4))
return scales
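# Layout sketch: a (rows, cols) scale tensor is viewed as
# (rows // 128, 4, 32, cols // 4, 4) and the two size-4 axes are swapped, so
# rows must be a multiple of 128 and cols a multiple of 4:
#
#   >>> import jax.numpy as jnp
#   >>> swizzled_scale(jnp.zeros((128, 4))).shape
#   (1, 1, 32, 4, 4)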
def grouped_gemm(
lhs_list: List[Union[jnp.ndarray, ScaledTensor]],
rhs_list: List[Union[jnp.ndarray, ScaledTensor]],
contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]],
bias_list: Optional[List[jnp.ndarray]] = None,
) -> List[jnp.ndarray]:
"""Grouped GEMM for multiple pairs of tensors."""
assert (
len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
), "lhs_list, rhs_list, contracting_dims_list must have the same length"
# Flatten inputs and save their shapes
num_gemms = len(lhs_list)
out_flat_size = 0
dims = []
lhs_contig_ = []
rhs_contig_ = []
lhs_scale_inv_contig_ = []
rhs_scale_inv_contig_ = []
bias_contig_ = []
out_offsets = []
remain_shape_list = []
for i in range(num_gemms):
lhs = lhs_list[i]
rhs = rhs_list[i]
contracting_dims = contracting_dims_list[i]
dim_nums = (contracting_dims, ((), ()))
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
scaling_mode = lhs.scaling_mode
lhs_shape = lhs.data.shape
rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2"
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
if lhs.layout == "T":
lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
if rhs.layout == "T":
rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else:
# For jnp.ndarray, only consider contracting_dims, layout is always NN
scaling_mode = ScalingMode.NVTE_NO_SCALING
lhs_shape = lhs.shape
rhs_shape = rhs.shape
out_dtype = lhs.dtype
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
lhs_dn = (lhs_contract, lhs_batch)
rhs_dn = (rhs_contract, rhs_batch)
lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
if scaling_mode == ScalingMode.NVTE_NO_SCALING:
lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T")
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn)
lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
else:
raise NotImplementedError(f"Unsupported ScalingMode: {scaling_mode}")
# Note: if _shape_normalization() is updated to support non-TN, need to update here
# already_transposed doesn't matter for the output shape
# x.shape = [B, D1, D2]
# contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
# contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
# x.shape = [D1, D2]
# contracting_dims = (1, ) --> output.shape = [1, D1, D2]
# contracting_dims = (0, ) --> output.shape = [1, D2, D1]
bm = lhs_remain_shape[0]
bn = rhs_remain_shape[0]
kl = lhs_3d.shape[-1]
kr = rhs_3d.shape[-1]
remain_shape_list.append(((bm,), (bn,)))
assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})"
k = kl
assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0, (
f"grouped_gemm input pair {i} has an invalid problem shape for lowering: "
f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires m, n, and k to be multiples of 16"
)
dims.append((bm, bn, k))
lhs_contig_.append(lhs_3d.reshape(-1))
rhs_contig_.append(rhs_3d.reshape(-1))
if scaling_mode == ScalingMode.NVTE_NO_SCALING:
lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1))
if bias_list is not None:
bias_contig_.append(bias_list[i].reshape(-1))
out_flat_size += bm * bn
out_offsets.append(out_flat_size)
lhs_contig = jnp.concatenate(lhs_contig_)
rhs_contig = jnp.concatenate(rhs_contig_)
lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_)
rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_)
bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_)
dim_list = jnp.array(dims, dtype=jnp.int32)
# Perform batched GEMM on flattened inputs
out_contig = GroupedGemmPrimitive.outer_primitive.bind(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms=num_gemms,
scaling_mode=scaling_mode,
out_dtype=out_dtype,
out_flat_size=out_flat_size,
)
# Split the output back into tensors
out_offsets = jnp.array(out_offsets)
out_flat_list = jnp.split(out_contig, out_offsets[:-1])
out_tensors = []
for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list):
out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape))
return out_tensors
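# Usage sketch (illustrative shapes; every m, n, and k must be a multiple of 16):
#
#   >>> import jax.numpy as jnp
#   >>> a0, b0 = jnp.ones((32, 64)), jnp.ones((64, 16))
#   >>> a1, b1 = jnp.ones((16, 128)), jnp.ones((128, 32))
#   >>> outs = grouped_gemm([a0, a1], [b0, b1], [((1,), (0,)), ((1,), (0,))])
#   >>> [o.shape for o in outs]
#   [(32, 16), (16, 32)]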