Unverified Commit cf9a7c2f authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Refactor + MXFP8 + GroupedGEMM (#1627)



* refactor + mxfp8

* added grouped gemm

* rename linear to dense

* added cublas init phase for groupedGemm

* relax the tol of test encoder multiprocessing mxfp8 by 0.001
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarHua Huang <huah@nvidia.com>
Co-authored-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent be055eb0
...@@ -577,3 +577,11 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT ...@@ -577,3 +577,11 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
} }
} }
namespace transformer_engine {
// Shorthand for the singleton manager that owns the process-wide cuBLASLt handle(s).
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
// Forces creation of the cuBLAS handle now: the first GetHandle() call runs
// cublasCreate(), which allocates memory. Per the header, TE/JAX cudaGraph
// capture requires this initialization to happen outside of the capturing
// region, so XLA custom calls invoke this during their initialize phase.
void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }
}  // namespace transformer_engine
...@@ -119,6 +119,13 @@ namespace transformer_engine { ...@@ -119,6 +119,13 @@ namespace transformer_engine {
constexpr int num_streams = 4; constexpr int num_streams = 4;
/*! \brief TE/JAX cudaGraph requires the cuBLAS initialization to happen outside of the capturing
* region. This function is a helper to call cublasCreate() which allocate memory for the handle.
* The function will be called in the initialize phase of the related XLA custom calls.
*/
void nvte_cublas_handle_init();
} // namespace transformer_engine } // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_H_ #endif // TRANSFORMER_ENGINE_GEMM_H_
...@@ -149,6 +149,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor ...@@ -149,6 +149,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
void nvte_enable_cudnn_norm_fwd(bool enable); void nvte_enable_cudnn_norm_fwd(bool enable);
void nvte_enable_cudnn_norm_bwd(bool enable); void nvte_enable_cudnn_norm_bwd(bool enable);
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -80,7 +80,8 @@ enum NVTEScalingMode { ...@@ -80,7 +80,8 @@ enum NVTEScalingMode {
/*! Single scale per block of 32 elements consecutive in either /*! Single scale per block of 32 elements consecutive in either
rowwise or columnwise direction */ rowwise or columnwise direction */
NVTE_MXFP8_1D_SCALING = 1, NVTE_MXFP8_1D_SCALING = 1,
NVTE_INVALID_SCALING NVTE_INVALID_SCALING = 2,
NVTE_NO_SCALING = 3
}; };
/*! \brief TE Tensor type /*! \brief TE Tensor type
...@@ -346,6 +347,13 @@ enum class DType { ...@@ -346,6 +347,13 @@ enum class DType {
kNumTypes kNumTypes
}; };
/*! \brief Check if TE datatype is FP8
*
* Return true if TE datatype is FP8
* \param[in] DType TE Datatype of interest
*/
bool is_fp8_dtype(const DType t);
/*! \struct TensorWrapper /*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class. * \brief C++ wrapper for the NVTETensor class.
*/ */
......
...@@ -11,10 +11,12 @@ ...@@ -11,10 +11,12 @@
transformer_engine::ubuf_built_with_mpi*; transformer_engine::ubuf_built_with_mpi*;
*transformer_engine::rtc*; *transformer_engine::rtc*;
transformer_engine::nvte_cudnn_handle_init*; transformer_engine::nvte_cudnn_handle_init*;
transformer_engine::nvte_cublas_handle_init*;
transformer_engine::typeToSize*; transformer_engine::typeToSize*;
transformer_engine::is_fp8_dtype*;
*transformer_engine::CommOverlapBase*; *transformer_engine::CommOverlapBase*;
*transformer_engine::CommOverlapP2PBase*; *transformer_engine::CommOverlapP2PBase*;
*transformer_engine::CommOverlapCore* *transformer_engine::CommOverlapCore*
}; };
local: *; local: *;
}; };
\ No newline at end of file
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <cudnn.h> #include <cudnn.h>
#include <cudnn_frontend.h> #include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h> #include <cudnn_frontend_utils.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include <functional> #include <functional>
...@@ -137,7 +138,6 @@ struct BackwardKernelParams : public KernelParamsBase { ...@@ -137,7 +138,6 @@ struct BackwardKernelParams : public KernelParamsBase {
}; };
enum class NVTE_Norm_Backend { Te, Cudnn }; enum class NVTE_Norm_Backend { Te, Cudnn };
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
enum class NVTE_Norm_Stage { Forward, Backward }; enum class NVTE_Norm_Stage { Forward, Backward };
using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>; using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>;
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
"""Transformer Engine bindings for JAX""" """Transformer Engine bindings for JAX.
This module provides JAX bindings for NVIDIA's Transformer Engine, enabling
high-performance transformer operations with mixed precision and quantization
support. It includes implementations of key transformer components like attention,
linear layers, and layer normalization, optimized for NVIDIA GPUs.
The module exports various transformer operations and utilities:
- Attention mechanisms (self-attention, cross-attention)
- Linear transformations with optional quantization
- Layer normalization operations
- Activation functions
- Softmax operations
- Sharding utilities for distributed training
All operations are designed to work seamlessly with JAX's functional programming
model and support automatic differentiation.
"""
# pylint: disable=wrong-import-position,wrong-import-order # pylint: disable=wrong-import-position,wrong-import-order
import sys
import logging import logging
import importlib import importlib
import importlib.util import importlib.util
import ctypes
from importlib.metadata import version from importlib.metadata import version
import sys
from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension from transformer_engine.common import _get_sys_extension
_logger = logging.getLogger(__name__)
def _load_library(): def _load_library():
"""Load shared library with Transformer Engine C extensions""" """Load shared library with Transformer Engine C extensions"""
...@@ -41,7 +55,7 @@ def _load_library(): ...@@ -41,7 +55,7 @@ def _load_library():
if is_package_installed("transformer-engine-cu12"): if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name): if not is_package_installed(module_name):
_logger.info( logging.info(
"Could not find package %s. Install transformer-engine using " "Could not find package %s. Install transformer-engine using "
"'pip3 install transformer-engine[jax]==VERSION'", "'pip3 install transformer-engine[jax]==VERSION'",
module_name, module_name,
...@@ -67,8 +81,10 @@ def _load_library(): ...@@ -67,8 +81,10 @@ def _load_library():
_load_library() _load_library()
from . import flax from . import flax
from .fp8 import fp8_autocast, update_collections, get_delayed_scaling from . import quantize
from .fp8 import NVTE_FP8_COLLECTION_NAME
from .quantize import fp8_autocast
from .sharding import MeshResource from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType from .sharding import MajorShardingType, ShardingResource, ShardingType
...@@ -85,10 +101,7 @@ ShardingResource = deprecate_wrapper( ...@@ -85,10 +101,7 @@ ShardingResource = deprecate_wrapper(
) )
__all__ = [ __all__ = [
"NVTE_FP8_COLLECTION_NAME",
"fp8_autocast", "fp8_autocast",
"update_collections",
"get_delayed_scaling",
"MeshResource", "MeshResource",
"MajorShardingType", "MajorShardingType",
"ShardingResource", "ShardingResource",
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Activation functions for Transformer Engine in JAX.
This module provides optimized activation functions with quantization support.
"""
from typing import Sequence, Union, Callable, Optional
from functools import partial
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize.tensor import ScaledTensor
from .quantize.quantizer import Quantizer
def activation(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Apply a sequence of activation functions to ``x``, with optional quantization.

    String-based activation types (e.g. 'relu', 'gelu', ('gelu', 'linear')) are
    supported; the last input dimension must be divisible by the number of
    activations so each activation receives an equal slice.

    Args:
        x: Input tensor to apply activations to.
        activation_type: Sequence of activation functions.
        quantizer: Optional quantizer for quantizing the output.

    Returns:
        Activated output tensor.
    """
    # Each activation consumes an equal chunk of the last dimension.
    assert x.shape[-1] % len(activation_type) == 0
    return _activation(x, activation_type, quantizer)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation(x, activation_type, quantizer):
    """Core activation implementation with a custom VJP.

    ``activation_type`` is a non-differentiable argument (nondiff_argnums=(1,));
    the forward/backward rules are attached via ``defvjp`` below.

    Args:
        x: Input tensor.
        activation_type: Sequence of activation functions.
        quantizer: Optional quantizer.

    Returns:
        Activated tensor.
    """
    # Reuse the forward rule and discard the residuals kept for the backward pass.
    return _activation_fwd_rule(x, activation_type, quantizer)[0]
def _activation_fwd_rule(x, activation_type, quantizer):
    """Forward rule for the custom-VJP activation.

    Args:
        x: Input tensor.
        activation_type: Sequence of activation functions.
        quantizer: Optional quantizer.

    Returns:
        Tuple of (output, residuals) where residuals = (x, quantizer) is the
        context consumed by the backward rule.
    """
    out = tex.act_lu(x, activation_type, quantizer)
    # act_lu may return a quantized ScaledTensor; hand callers plain values.
    if isinstance(out, ScaledTensor):
        out = out.dequantize()
    residuals = (x, quantizer)
    return out, residuals
def _activation_bwd_rule(activation_type, ctx, g):
    """Backward rule for the custom-VJP activation.

    Args:
        activation_type: Sequence of activation functions (non-differentiable).
        ctx: Residuals saved by the forward rule, i.e. (x, quantizer).
        g: Incoming gradient from upstream.

    Returns:
        One-element-per-differentiable-input tuple: (dL/dx, None) — the
        quantizer argument gets no gradient.
    """
    x, _unused_quantizer = ctx
    assert x.dtype == g.dtype
    grad = tex.dact_lu(g, x, activation_type)
    # dact_lu may flatten/fuse dims; restore the original input shape.
    grad = jnp.reshape(grad, x.shape)
    return (grad, None)
_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule)
...@@ -7,4 +7,4 @@ from .attention import * ...@@ -7,4 +7,4 @@ from .attention import *
from .normalization import * from .normalization import *
from .quantization import * from .quantization import *
from .softmax import * from .softmax import *
from .transpose import * from .gemm import *
...@@ -13,8 +13,6 @@ from packaging import version ...@@ -13,8 +13,6 @@ from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes, lax from jax import dtypes, lax
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine_jax import transformer_engine_jax
...@@ -29,14 +27,12 @@ from transformer_engine.jax.attention import ( ...@@ -29,14 +27,12 @@ from transformer_engine.jax.attention import (
) )
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import ( from .misc import (
check_valid_batch_dims, check_valid_batch_dims,
jax_dtype_to_te_dtype, jax_dtype_to_te_dtype,
te_dtype_to_jax_dtype, te_dtype_to_jax_dtype,
get_padded_spec, get_padded_spec,
get_cudnn_version, get_cudnn_version,
is_ffi_enabled,
) )
from ..sharding import ( from ..sharding import (
global_mesh_resource, global_mesh_resource,
...@@ -227,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -227,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
Fused Attention Forward Primitive Fused Attention Forward Primitive
""" """
name = "te_fused_attn_forward" name = "te_fused_attn_forward_ffi"
multiple_results = True multiple_results = True
impl_static_args = (13,) impl_static_args = (13,)
inner_primitive = None inner_primitive = None
...@@ -400,90 +396,40 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -400,90 +396,40 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape) bias_batch = reduce(operator.mul, bias_batch_shape)
if is_ffi_enabled(): return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
name = "te_fused_attn_forward_ffi" ctx,
out = ffi.ffi_lowering(name)( q,
ctx, k,
q, v,
k, bias,
v, seed,
bias, q_cu_seqlen,
seed, kv_cu_seqlen,
q_cu_seqlen, q_seq_offsets,
kv_cu_seqlen, k_seq_offsets,
q_seq_offsets, _q_segment_ids,
k_seq_offsets, _kv_segment_ids,
_q_segment_ids, _q_segment_pos,
_kv_segment_ids, _kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering
_q_segment_pos, input_batch=input_batch,
_kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering bias_batch=bias_batch,
input_batch=input_batch, q_max_seqlen=q_max_seqlen,
bias_batch=bias_batch, kv_max_seqlen=kv_max_seqlen,
q_max_seqlen=q_max_seqlen, attn_heads=attn_heads,
kv_max_seqlen=kv_max_seqlen, num_gqa_groups=num_gqa_groups,
attn_heads=attn_heads, bias_heads=bias_heads,
num_gqa_groups=num_gqa_groups, head_dim=head_dim,
bias_heads=bias_heads, max_segments_per_seq=config.max_segments_per_seq,
head_dim=head_dim, scaling_factor=float(config.scaling_factor),
max_segments_per_seq=config.max_segments_per_seq, dropout_probability=float(config.dropout_probability),
scaling_factor=float(config.scaling_factor), bias_type=int(config.attn_bias_type.value),
dropout_probability=float(config.dropout_probability), mask_type=int(config.attn_mask_type.value),
bias_type=int(config.attn_bias_type.value), qkv_layout=int(config.qkv_layout.value),
mask_type=int(config.attn_mask_type.value), is_training=config.is_training,
qkv_layout=int(config.qkv_layout.value), deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
is_training=config.is_training, window_size_left=config.window_size[0],
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_right=config.window_size[1],
window_size_left=config.window_size[0], )
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod @staticmethod
def impl( def impl(
...@@ -681,7 +627,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -681,7 +627,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
Fused Attention Backward Primitive Fused Attention Backward Primitive
""" """
name = "te_fused_attn_backward" name = "te_fused_attn_backward_ffi"
multiple_results = True multiple_results = True
impl_static_args = (16,) impl_static_args = (16,)
inner_primitive = None inner_primitive = None
...@@ -813,96 +759,43 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -813,96 +759,43 @@ class FusedAttnBwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape) bias_batch = reduce(operator.mul, bias_batch_shape)
if is_ffi_enabled(): return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
name = "te_fused_attn_backward_ffi" ctx,
out = ffi.ffi_lowering(name)( q,
ctx, k,
q, v,
k, bias,
v, softmax_aux,
bias, rng_state,
softmax_aux, output,
rng_state, doutput,
output, q_cu_seqlen,
doutput, kv_cu_seqlen,
q_cu_seqlen, q_seq_offsets,
kv_cu_seqlen, k_seq_offsets,
q_seq_offsets, q_segment_ids,
k_seq_offsets, kv_segment_ids,
q_segment_ids, q_segment_pos,
kv_segment_ids, kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering
q_segment_pos, input_batch=input_batch,
kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering bias_batch=bias_batch,
input_batch=input_batch, q_max_seqlen=q_max_seqlen,
bias_batch=bias_batch, kv_max_seqlen=kv_max_seqlen,
q_max_seqlen=q_max_seqlen, attn_heads=attn_heads,
kv_max_seqlen=kv_max_seqlen, num_gqa_groups=num_gqa_groups,
attn_heads=attn_heads, bias_heads=bias_heads,
num_gqa_groups=num_gqa_groups, head_dim=head_dim,
bias_heads=bias_heads, max_segments_per_seq=config.max_segments_per_seq,
head_dim=head_dim, scaling_factor=float(config.scaling_factor),
max_segments_per_seq=config.max_segments_per_seq, dropout_probability=float(config.dropout_probability),
scaling_factor=float(config.scaling_factor), bias_type=int(config.attn_bias_type.value),
dropout_probability=float(config.dropout_probability), mask_type=int(config.attn_mask_type.value),
bias_type=int(config.attn_bias_type.value), qkv_layout=int(config.qkv_layout.value),
mask_type=int(config.attn_mask_type.value), is_training=config.is_training,
qkv_layout=int(config.qkv_layout.value), deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
is_training=config.is_training, window_size_left=config.window_size[0],
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_right=config.window_size[1],
window_size_left=config.window_size[0], )
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
@staticmethod @staticmethod
def impl( def impl(
......
...@@ -6,6 +6,7 @@ import os ...@@ -6,6 +6,7 @@ import os
import re import re
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from functools import partial from functools import partial
from packaging import version
from jax.extend import core from jax.extend import core
from jax.interpreters import xla, mlir from jax.interpreters import xla, mlir
...@@ -13,6 +14,14 @@ from jax.experimental.custom_partitioning import custom_partitioning ...@@ -13,6 +14,14 @@ from jax.experimental.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src import dispatch from jax._src import dispatch
import jax
import transformer_engine_jax
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
class BasePrimitive(metaclass=ABCMeta): class BasePrimitive(metaclass=ABCMeta):
""" """
...@@ -120,3 +129,7 @@ def register_primitive(cls): ...@@ -120,3 +129,7 @@ def register_primitive(cls):
outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results) outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
) )
cls.outer_primitive = outer_p cls.outer_primitive = outer_p
for _name, _value in transformer_engine_jax.registrations().items():
ffi.register_ffi_target(_name, _value, platform="CUDA")
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom call"""
from dataclasses import dataclass
from enum import IntEnum
from packaging import version
import jax
from jax.interpreters import mlir
import transformer_engine_jax
from .misc import is_ffi_enabled
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
try:
from jaxlib.hlo_helpers import custom_call
except ImportError:
# Newer JAX changed its API. But we want to support a few JAX
# version, so we still need this import.
pass
class CustomCallAPIVersion(IntEnum):
    """Enum for selecting between old and new custom call registration API"""

    # Values match the api_version accepted by ffi.register_ffi_target.
    OPAQUE = 0  # legacy custom calls configured via a packed opaque descriptor
    FFI = 1  # new typed-FFI custom calls (attributes passed by name)
# Register every custom-call target exported by the transformer_engine_jax
# C++ extension with XLA. Names ending in "_ffi" are typed-FFI handlers and
# are registered with the FFI api_version only when FFI is enabled; remaining
# targets are registered with the legacy opaque-descriptor api_version.
# NOTE(review): leading indentation was lost in this capture — confirm the
# `else` pairs with `endswith("_ffi")` (as reconstructed) rather than with
# `is_ffi_enabled()`.
for _name, _value in transformer_engine_jax.registrations().items():
    if _name.endswith("_ffi"):
        if is_ffi_enabled():
            ffi.register_ffi_target(
                _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
            )
    else:
        ffi.register_ffi_target(
            _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
        )
@dataclass
class CustomCallArgsWrapper:
    """
    Wrapper of XLA custom call args.

    Holds the operand/result types together with their buffer layouts so they
    can be handed to the MLIR custom-call builder in one object.
    """

    def __init__(
        self,
        output_types,
        operands,
        operand_shapes,
        operand_specific_layouts=None,
        output_specific_layouts=None,
    ):
        self.output_types = output_types
        self.operands = operands
        # Layouts for inputs, honoring any per-operand overrides.
        self.operand_layouts = CustomCallArgsWrapper.generate_layouts(
            operand_shapes, operand_specific_layouts
        )
        # Result layouts are derived from the shapes of the output types.
        result_shapes = [out.shape for out in output_types]
        self.output_layouts = CustomCallArgsWrapper.generate_layouts(
            result_shapes, output_specific_layouts
        )

    @staticmethod
    def generate_layouts(shapes, specific_layouts):
        """
        Setup layouts for XLA custom call.

        For each shape, uses the override in ``specific_layouts`` (a dict keyed
        by position) when present, otherwise the default row-major layout
        (dimensions listed from last to first).
        """
        overrides = specific_layouts if specific_layouts is not None else {}
        return [
            overrides.get(idx, range(len(shape) - 1, -1, -1))
            for idx, shape in enumerate(shapes)
        ]
def custom_caller(name, args, opaque, has_side_effect, **kwargs):
    """
    XLA custom call wrapper.

    Dispatches to ``mlir.custom_call`` on newer JAX, falling back to the
    jaxlib ``custom_call`` helper on older versions.
    """
    # Arguments shared by both API generations.
    common_kwargs = dict(
        operands=args.operands,
        operand_layouts=args.operand_layouts,
        result_layouts=args.output_layouts,
        backend_config=opaque,
        has_side_effect=has_side_effect,
        **kwargs,
    )
    if hasattr(mlir, "custom_call"):
        return mlir.custom_call(
            name,
            result_types=args.output_types,
            **common_kwargs,
        ).results
    # Need to disable one pylint error as the second function
    # parameter name changed recently in JAX. Otherwise we won't be
    # compatible with multiple JAX versions.
    return custom_call(  # pylint: disable=too-many-function-args
        name,
        args.output_types,
        **common_kwargs,
    )
This diff is collapsed.
...@@ -11,14 +11,17 @@ from packaging.version import Version as PkgVersion ...@@ -11,14 +11,17 @@ from packaging.version import Version as PkgVersion
import numpy as np import numpy as np
import jax.numpy as jnp import jax
from jax import dtypes from jax import dtypes
import jax.numpy as jnp
from jax.interpreters.mlir import dtype_to_ir_type from jax.interpreters.mlir import dtype_to_ir_type
from transformer_engine_jax import DType as TEDType
import transformer_engine_jax import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis
TEDType = transformer_engine_jax.DType
def te_dtype_to_jax_dtype(te_dtype): def te_dtype_to_jax_dtype(te_dtype):
...@@ -104,7 +107,7 @@ def normalize_axis_boundary(axis, ndim): ...@@ -104,7 +107,7 @@ def normalize_axis_boundary(axis, ndim):
return axis if axis >= 0 else ndim + axis return axis if axis >= 0 else ndim + axis
def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary): def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1):
""" """
te_cast_transpose_p multi-dims transpose te_cast_transpose_p multi-dims transpose
...@@ -158,17 +161,6 @@ def jax_version_meet_requirement(version: str): ...@@ -158,17 +161,6 @@ def jax_version_meet_requirement(version: str):
return jax_version >= jax_version_required return jax_version >= jax_version_required
def is_ffi_enabled():
"""
Helper function checking if XLA Custom Call with FFI is enabled
"""
is_supported = jax_version_meet_requirement("0.4.35")
# New APIs with FFI are enabled by default
is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
return is_supported and is_enabled
def get_xla_flag(flag: str, default=None, cast=str): def get_xla_flag(flag: str, default=None, cast=str):
""" """
Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value. Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value.
...@@ -189,3 +181,86 @@ def get_xla_flag(flag: str, default=None, cast=str): ...@@ -189,3 +181,86 @@ def get_xla_flag(flag: str, default=None, cast=str):
if name == flag: if name == flag:
return True return True
return default return default
def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None):
    """
    Fused dbias is not supported for arch < 100 for 1x quantization, so we need to
    apply a workaround to calculate dbias separately. This function checks if the
    workaround should be applied.
    """
    # True when any visible local GPU reports compute capability below 100.
    arch_below_100 = any(
        transformer_engine_jax.get_device_compute_capability(device_id) < 100
        for device_id in range(len(jax.local_devices()))
    )
    return (
        quantizer is not None
        and quantizer.q_axis == QuantizeAxis.ROWWISE
        and arch_below_100
        and is_dbias
    )
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
    """
    Applies a workaround for delayed scaling 2x and can be used when the TE common
    kernels do not yet support 2x delayed scaling. It will call the given function
    'f' with the given arguments and quantizer as 1x and calculate the colwise
    output by transposing the rowwise result.
    If 'f' returns a tuple, the first output must be the only ScaledTensor output.
    @param f: function to call
    @param args: positional arguments to pass to 'f'
    @param quantizer: quantizer to use
    @param kwargs: keyword arguments to pass to 'f'
    @return: the output of 'f' with the colwise output calculated, or None when
             the workaround does not apply
    """
    # WAR only applies to delayed tensor scaling quantizers configured for 2x
    # (rowwise + colwise) output.
    should_apply_war = (
        quantizer is not None
        and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
        and quantizer.is_2x2x()
    )
    if not should_apply_war:
        return None
    # 2x is not supported by TE kernels for delayed scaling
    # so revert to 1x and transpose in JAX.
    # NOTE: the quantizer is mutated in place here and restored to 2x below.
    quantizer.q_axis = QuantizeAxis.ROWWISE
    rowwise = f(*args, **kwargs, quantizer=quantizer)
    # 'f' may return (scaled_tensor, extra_outputs...) — peel off the extras
    # and re-attach them at the end.
    other_outputs = None
    if isinstance(rowwise, tuple):
        other_outputs = rowwise[1:]
        rowwise = rowwise[0]
    # Restore the caller's 2x setting on the quantizer.
    quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE
    # Colwise data = rowwise data with the last axis moved to the front.
    colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1)))
    output_2x = ScaledTensorFactory.create(
        data=rowwise.data,
        scale_inv=rowwise.scale_inv,
        colwise_data=colwise_data,
        # NOTE(review): assumes delayed scaling uses a single tensor-wide scale,
        # so the same scale_inv is valid for both orientations — confirm.
        colwise_scale_inv=rowwise.scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=rowwise.dq_dtype,
        q_axis=QuantizeAxis.ROWWISE_COLWISE,
        layout=quantizer.get_layout(),
    )
    if other_outputs is not None:
        return (output_2x,) + other_outputs
    return output_2x
class NamedSharding(jax.sharding.NamedSharding):
    """
    Wrapper around jax.sharding.NamedSharding that adds a string description field
    as metadata for easier debugging.
    """

    def __init__(self, *args, desc: str = None, **kwargs):
        # Forward mesh/spec/etc. to the base class untouched; only `desc` is
        # consumed here. NOTE(review): `desc` does not appear to participate in
        # the base class's equality/hash — confirm that is intentional.
        super().__init__(*args, **kwargs)
        self.desc = desc

    def __repr__(self):
        # Include the description so shardings can be told apart in logs.
        return f"NamedSharding({self.mesh}, {self.spec}, desc={self.desc})"

    def duplicate_with_new_description(self, desc: str):
        """
        Create a new NamedSharding with the same mesh and spec but with a new description.
        """
        # NOTE(review): only mesh and spec are propagated — any other base-class
        # option (e.g. memory_kind), if ever set, would be dropped here.
        return NamedSharding(self.mesh, self.spec, desc=desc)
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment