Commit a207db1d authored by yuguo
parents fbee8990 69365f88
...@@ -201,8 +201,9 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
                         max_fp8 = Quantized_Limits<DType>::max_norm;);
   // Update scale
-  compute_scale_from_amax_kernel<<<1, 1>>>(reinterpret_cast<const float *>(output.amax.dptr),
-                                           reinterpret_cast<float *>(output.scale.dptr), max_fp8,
-                                           config.force_pow_2_scales, config.amax_epsilon);
+  compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
+      reinterpret_cast<const float *>(output.amax.dptr),
+      reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
+      config.amax_epsilon);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
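The hunk above switches the kernel launch onto the caller-provided stream instead of the default stream, so the scale update is ordered with the rest of the quantization work. For intuition only, here is a minimal host-side sketch of the arithmetic such a kernel presumably performs; the helper name and exact formula are assumptions, not TE's implementation:

import jax.numpy as jnp

def compute_scale_from_amax(amax, max_fp8, force_pow_2_scales=False, amax_epsilon=0.0):
    # Hypothetical replica: map the observed |x| range onto the FP8 representable range.
    amax = jnp.maximum(amax, amax_epsilon)  # floor tiny amax values
    scale = jnp.where(amax > 0.0, max_fp8 / amax, 1.0)
    if force_pow_2_scales:
        scale = jnp.exp2(jnp.floor(jnp.log2(scale)))  # round down to a power of two
    return scale

# With E4M3's max_norm of 448, an amax of 3.5 gives a scale of exactly 128.
print(compute_scale_from_amax(jnp.float32(3.5), 448.0, force_pow_2_scales=True))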
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-"""Transformer Engine bindings for JAX"""
+"""Transformer Engine bindings for JAX.
+
+This module provides JAX bindings for NVIDIA's Transformer Engine, enabling
+high-performance transformer operations with mixed precision and quantization
+support. It includes implementations of key transformer components like attention,
+linear layers, and layer normalization, optimized for NVIDIA GPUs.
+
+The module exports various transformer operations and utilities:
+- Attention mechanisms (self-attention, cross-attention)
+- Linear transformations with optional quantization
+- Layer normalization operations
+- Activation functions
+- Softmax operations
+- Sharding utilities for distributed training
+
+All operations are designed to work seamlessly with JAX's functional programming
+model and support automatic differentiation.
+"""
 # pylint: disable=wrong-import-position,wrong-import-order
+import sys
 import logging
 import importlib
 import importlib.util
+import ctypes
 from importlib.metadata import version
-import sys
 from transformer_engine.common import get_te_path, is_package_installed
 from transformer_engine.common import _get_sys_extension
-
-_logger = logging.getLogger(__name__)

 def _load_library():
     """Load shared library with Transformer Engine C extensions"""
...@@ -41,7 +55,7 @@ def _load_library():
     if is_package_installed("transformer-engine-cu12"):
         if not is_package_installed(module_name):
-            _logger.info(
+            logging.info(
                 "Could not find package %s. Install transformer-engine using "
                 "'pip3 install transformer-engine[jax]==VERSION'",
                 module_name,
...@@ -67,8 +81,10 @@ def _load_library():
 _load_library()

 from . import flax
-from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
-from .fp8 import NVTE_FP8_COLLECTION_NAME
+from . import quantize
+from .quantize import fp8_autocast
 from .sharding import MeshResource
 from .sharding import MajorShardingType, ShardingResource, ShardingType
...@@ -85,10 +101,7 @@ ShardingResource = deprecate_wrapper(
 )

 __all__ = [
-    "NVTE_FP8_COLLECTION_NAME",
     "fp8_autocast",
-    "update_collections",
-    "get_delayed_scaling",
     "MeshResource",
     "MajorShardingType",
     "ShardingResource",
...
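With this change the public FP8 entry point comes from the new quantize module rather than the removed fp8 module. As a usage illustration only, a minimal sketch assuming the context-manager API carries over, with DenseGeneral standing in for any TE Flax layer (constructor details are assumptions, and an FP8-capable GPU is required):

import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral

# Layers applied under fp8_autocast pick up FP8 quantization metadata;
# outside the scope they run in the plain dtype.
x = jnp.ones((8, 64), jnp.bfloat16)
layer = DenseGeneral(features=128)
with te.fp8_autocast(enabled=True):
    params = layer.init(jax.random.PRNGKey(0), x)
    out = layer.apply(params, x)
print(out.shape)  # (8, 128)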
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Activation functions for Transformer Engine in JAX.
This module provides optimized activation functions with quantization support.
"""
from typing import Sequence, Union, Callable, Optional
from functools import partial
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize.tensor import ScaledTensor
from .quantize.quantizer import Quantizer
def activation(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
"""Apply activation functions to input tensor with optional quantization.
This function applies a sequence of activation functions to the input tensor.
It supports string-based activation types (e.g., 'relu', 'gelu', ('gelu', 'linear')).
Args:
x: Input tensor to apply activations to
activation_type: Sequence of activation functions
quantizer: Optional quantizer for quantizing the output
Returns:
Activated output tensor
"""
assert x.shape[-1] % len(activation_type) == 0
output = _activation(x, activation_type, quantizer)
return output
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation(x, activation_type, quantizer):
"""Internal implementation of activation with custom VJP.
This function implements the core activation logic with support for
custom vector-Jacobian product (VJP) for automatic differentiation.
Args:
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
Returns:
Activated tensor
"""
_output, _ = _activation_fwd_rule(x, activation_type, quantizer)
return _output
def _activation_fwd_rule(x, activation_type, quantizer):
"""Forward pass rule for activation function.
Args:
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
Returns:
Tuple of (output, context) for backward pass
"""
fwd_output = tex.act_lu(x, activation_type, quantizer)
if isinstance(fwd_output, ScaledTensor):
fwd_output = fwd_output.dequantize()
return fwd_output, (x, quantizer)
def _activation_bwd_rule(activation_type, ctx, g):
"""Backward pass rule for activation function.
Args:
activation_type: Sequence of activation functions
ctx: Context from forward pass
g: Gradient from upstream
Returns:
Gradient with respect to input
"""
(x, _) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx, None)
_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule)
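A short usage sketch of the API above; the halving of the last dimension for gated pairs is inferred from the divisibility assert in activation, not stated explicitly here:

import jax
import jax.numpy as jnp
from transformer_engine.jax.activation import activation  # module added in this commit

x = jax.random.normal(jax.random.PRNGKey(0), (4, 16, 2048))

# A single activation preserves the shape.
y = activation(x, ("gelu",))

# A gated pair such as ("gelu", "linear") splits the last dimension in two
# (hence the assert that len(activation_type) divides it) and combines the halves.
z = activation(x, ("gelu", "linear"))

# The custom VJP defined above makes both differentiable.
g = jax.grad(lambda t: activation(t, ("gelu",)).sum())(x)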
...@@ -378,6 +378,44 @@ def _mask_to_seqlens_offset(mask, max_segments_per_seq):
     return q_seqlen, q_offset, kv_seqlen, kv_offset


+def _fast_causal_adjust_seqlen_and_offsets(
+    segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
+):
+    # The assumption is that for any segment, tokens respect causal ordering except at the ends
+    # of the segment. This allows us to tweak the length and offset by only looking at the start
+    # and end tokens between segments.
+    is_active_segment = jnp.logical_and(q_len > 0, kv_len > 0)
+
+    q_seq_id_start = jnp.take(segment_pos_q, q_offset[..., :-1], fill_value=-1)
+    kv_seq_id_start = jnp.take(segment_pos_kv, kv_offset[..., :-1], fill_value=-1)
+    skip_start_token = jnp.logical_and(kv_seq_id_start > q_seq_id_start, is_active_segment).astype(
+        jnp.int32
+    )
+    q_len -= skip_start_token
+    q_offset += jnp.insert(skip_start_token, skip_start_token.shape[-1], 0, axis=-1)
+
+    q_seq_id_end = jnp.take(segment_pos_q, q_offset[..., 1:] - 1, fill_value=-1)
+    kv_seq_id_end = jnp.take(segment_pos_kv, kv_offset[..., 1:] - 1, fill_value=-1)
+    skip_end_token = jnp.logical_and(kv_seq_id_end > q_seq_id_end, is_active_segment).astype(
+        jnp.int32
+    )
+    kv_len -= skip_end_token
+
+    return q_len, kv_len, q_offset, kv_offset
+
+
+def _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
+    segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
+):
+    q_len, q_offset = _get_seqlens_and_offsets(segment_ids_q, max_segments_per_seq)
+    kv_len, kv_offset = _get_seqlens_and_offsets(segment_ids_kv, max_segments_per_seq)
+    return _fast_causal_adjust_seqlen_and_offsets(
+        segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
+    )
+
 def _segment_ids_pos_to_seqlens_offsets(
     segment_ids_q,
     segment_ids_kv,
...@@ -387,6 +425,25 @@ def _segment_ids_pos_to_seqlens_offsets(
     window_size,
     max_segments_per_seq,
 ):
+    # TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here.
+    # Computing the full mask is expensive due to quadratic expansion of Q * KV masking.
+
+    # Assumptions for cudnn causal mask correctness:
+    # 1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0]
+    # 2. No intra-segment padding, only inter-segment padding allowed
+    # 3. Only the start or end token within a segment may violate the causal order relationship
+    #
+    #    1 5 9 0 4 8 10 0 4 8
+    #  0 x x
+    #  4 x x x x x
+    #  8 x x x x x x x x
+    #
+    # This fast path avoids expanding the mask to a Q * KV matrix and instead allows us to
+    # examine only O(Q+KV) elements.
+    if attn_mask_type.is_causal() and (window_size is None or window_size == (-1, -1)):
+        return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
+            segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
+        )
+
     # (1 = attend, 0 = masked)
     segment_mask = make_attention_mask(
         segment_ids_q,
...
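To make the fast path concrete: _get_seqlens_and_offsets (defined elsewhere in this file) turns per-token segment ids into per-segment lengths and start offsets, which _fast_causal_adjust_seqlen_and_offsets then nudges at segment boundaries in O(Q+KV). A toy, self-contained replica under the monotonic-segments assumption; the helper and its padding semantics are illustrative, not TE's implementation:

import jax.numpy as jnp

def toy_seqlens_and_offsets(ids, max_segments):
    # Lengths and start offsets of the labeled segments; unused slots
    # get length 0 and offset -1.
    prev = jnp.concatenate([jnp.zeros(1, ids.dtype), ids[:-1]])
    starts = jnp.flatnonzero((ids != prev) & (ids != 0),
                             size=max_segments, fill_value=-1)
    seg_vals = jnp.where(starts >= 0, ids[starts], -1)
    lengths = jnp.sum(ids[None, :] == seg_vals[:, None], axis=-1)
    return lengths, starts

ids = jnp.array([4, 4, 4, 0, 0, 5, 5, 5, 6, 6, 0, 0])  # monotonic ids, 0 = padding
print(toy_seqlens_and_offsets(ids, 3))  # lengths [3 3 2], offsets [0 5 8]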
...@@ -7,4 +7,4 @@ from .attention import *
 from .normalization import *
 from .quantization import *
 from .softmax import *
-from .transpose import *
+from .gemm import *
...@@ -13,8 +13,6 @@ from packaging import version
 import jax
 import jax.numpy as jnp
 from jax import dtypes, lax
-from jax.interpreters import mlir
-from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding

 import transformer_engine_jax

...@@ -29,14 +27,12 @@ from transformer_engine.jax.attention import (
 )

 from .base import BasePrimitive, register_primitive
-from .custom_call import custom_caller, CustomCallArgsWrapper
 from .misc import (
     check_valid_batch_dims,
     jax_dtype_to_te_dtype,
     te_dtype_to_jax_dtype,
     get_padded_spec,
     get_cudnn_version,
-    is_ffi_enabled,
 )
 from ..sharding import (
     global_mesh_resource,
...@@ -227,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
     Fused Attention Forward Primitive
     """

-    name = "te_fused_attn_forward"
+    name = "te_fused_attn_forward_ffi"
     multiple_results = True
     impl_static_args = (13,)
     inner_primitive = None
...@@ -400,90 +396,40 @@ class FusedAttnFwdPrimitive(BasePrimitive):
         *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
         bias_batch = reduce(operator.mul, bias_batch_shape)

-        if is_ffi_enabled():
-            name = "te_fused_attn_forward_ffi"
-            out = ffi.ffi_lowering(name)(
-                ctx,
-                q,
-                k,
-                v,
-                bias,
-                seed,
-                q_cu_seqlen,
-                kv_cu_seqlen,
-                q_seq_offsets,
-                k_seq_offsets,
-                _q_segment_ids,
-                _kv_segment_ids,
-                _q_segment_pos,
-                _kv_segment_pos,  # ffi_lowering needs number of parameters meets primitive.lowering
-                input_batch=input_batch,
-                bias_batch=bias_batch,
-                q_max_seqlen=q_max_seqlen,
-                kv_max_seqlen=kv_max_seqlen,
-                attn_heads=attn_heads,
-                num_gqa_groups=num_gqa_groups,
-                bias_heads=bias_heads,
-                head_dim=head_dim,
-                max_segments_per_seq=config.max_segments_per_seq,
-                scaling_factor=float(config.scaling_factor),
-                dropout_probability=float(config.dropout_probability),
-                bias_type=int(config.attn_bias_type.value),
-                mask_type=int(config.attn_mask_type.value),
-                qkv_layout=int(config.qkv_layout.value),
-                is_training=config.is_training,
-                deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
-                window_size_left=config.window_size[0],
-                window_size_right=config.window_size[1],
-            )
-        else:
-            operands = [
-                q,
-                k,
-                v,
-                bias,
-                seed,
-                q_cu_seqlen,
-                kv_cu_seqlen,
-                q_seq_offsets,
-                k_seq_offsets,
-            ]
-            operand_shapes = map(lambda x: x.type.shape, operands)
-            out_types = [
-                ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
-                for output in ctx.avals_out
-            ]
-            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
-            wkspace_aval = ctx.avals_out[-1]
-            opaque = transformer_engine_jax.pack_fused_attn_descriptor(
-                input_batch,
-                bias_batch,
-                q_max_seqlen,
-                kv_max_seqlen,
-                attn_heads,
-                num_gqa_groups,
-                bias_heads,
-                head_dim,
-                config.max_segments_per_seq,
-                wkspace_aval.size,
-                config.scaling_factor,
-                config.dropout_probability,
-                config.attn_bias_type,
-                config.attn_mask_type,
-                config.qkv_layout,
-                jax_dtype_to_te_dtype(q_aval.dtype),
-                jax_dtype_to_te_dtype(wkspace_aval.dtype),
-                config.is_training,
-                not FusedAttnHelper.is_non_deterministic_allowed(),
-                config.window_size[0],
-                config.window_size[1],
-            )
-            out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
-        return out
+        return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
+            ctx,
+            q,
+            k,
+            v,
+            bias,
+            seed,
+            q_cu_seqlen,
+            kv_cu_seqlen,
+            q_seq_offsets,
+            k_seq_offsets,
+            _q_segment_ids,
+            _kv_segment_ids,
+            _q_segment_pos,
+            _kv_segment_pos,  # ffi_lowering requires the parameter count to match primitive.lowering
+            input_batch=input_batch,
+            bias_batch=bias_batch,
+            q_max_seqlen=q_max_seqlen,
+            kv_max_seqlen=kv_max_seqlen,
+            attn_heads=attn_heads,
+            num_gqa_groups=num_gqa_groups,
+            bias_heads=bias_heads,
+            head_dim=head_dim,
+            max_segments_per_seq=config.max_segments_per_seq,
+            scaling_factor=float(config.scaling_factor),
+            dropout_probability=float(config.dropout_probability),
+            bias_type=int(config.attn_bias_type.value),
+            mask_type=int(config.attn_mask_type.value),
+            qkv_layout=int(config.qkv_layout.value),
+            is_training=config.is_training,
+            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
+            window_size_left=config.window_size[0],
+            window_size_right=config.window_size[1],
+        )
     @staticmethod
     def impl(
...@@ -681,7 +627,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
     Fused Attention Backward Primitive
     """

-    name = "te_fused_attn_backward"
+    name = "te_fused_attn_backward_ffi"
     multiple_results = True
     impl_static_args = (16,)
     inner_primitive = None
...@@ -813,96 +759,43 @@ class FusedAttnBwdPrimitive(BasePrimitive):
         *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
         bias_batch = reduce(operator.mul, bias_batch_shape)

-        if is_ffi_enabled():
-            name = "te_fused_attn_backward_ffi"
-            out = ffi.ffi_lowering(name)(
-                ctx,
-                q,
-                k,
-                v,
-                bias,
-                softmax_aux,
-                rng_state,
-                output,
-                doutput,
-                q_cu_seqlen,
-                kv_cu_seqlen,
-                q_seq_offsets,
-                k_seq_offsets,
-                q_segment_ids,
-                kv_segment_ids,
-                q_segment_pos,
-                kv_segment_pos,  # ffi_lowering needs number of parameters meets primitive.lowering
-                input_batch=input_batch,
-                bias_batch=bias_batch,
-                q_max_seqlen=q_max_seqlen,
-                kv_max_seqlen=kv_max_seqlen,
-                attn_heads=attn_heads,
-                num_gqa_groups=num_gqa_groups,
-                bias_heads=bias_heads,
-                head_dim=head_dim,
-                max_segments_per_seq=config.max_segments_per_seq,
-                scaling_factor=float(config.scaling_factor),
-                dropout_probability=float(config.dropout_probability),
-                bias_type=int(config.attn_bias_type.value),
-                mask_type=int(config.attn_mask_type.value),
-                qkv_layout=int(config.qkv_layout.value),
-                is_training=config.is_training,
-                deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
-                window_size_left=config.window_size[0],
-                window_size_right=config.window_size[1],
-            )
-        else:
-            operands = [
-                q,
-                k,
-                v,
-                bias,
-                softmax_aux,
-                rng_state,
-                output,
-                doutput,
-                q_cu_seqlen,
-                kv_cu_seqlen,
-                q_seq_offsets,
-                k_seq_offsets,
-            ]
-            operand_shapes = map(lambda x: x.type.shape, operands)
-            out_types = [
-                ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
-                for output in ctx.avals_out
-            ]
-            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
-            wkspace_aval = ctx.avals_out[-1]
-            opaque = transformer_engine_jax.pack_fused_attn_descriptor(
-                input_batch,
-                bias_batch,
-                q_max_seqlen,
-                kv_max_seqlen,
-                attn_heads,
-                num_gqa_groups,
-                bias_heads,
-                head_dim,
-                config.max_segments_per_seq,
-                wkspace_aval.size,
-                config.scaling_factor,
-                config.dropout_probability,
-                config.attn_bias_type,
-                config.attn_mask_type,
-                config.qkv_layout,
-                jax_dtype_to_te_dtype(q_aval.dtype),
-                jax_dtype_to_te_dtype(wkspace_aval.dtype),
-                config.is_training,
-                not FusedAttnHelper.is_non_deterministic_allowed(),
-                config.window_size[0],
-                config.window_size[1],
-            )
-            out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
-        return out
+        return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
+            ctx,
+            q,
+            k,
+            v,
+            bias,
+            softmax_aux,
+            rng_state,
+            output,
+            doutput,
+            q_cu_seqlen,
+            kv_cu_seqlen,
+            q_seq_offsets,
+            k_seq_offsets,
+            q_segment_ids,
+            kv_segment_ids,
+            q_segment_pos,
+            kv_segment_pos,  # ffi_lowering requires the parameter count to match primitive.lowering
+            input_batch=input_batch,
+            bias_batch=bias_batch,
+            q_max_seqlen=q_max_seqlen,
+            kv_max_seqlen=kv_max_seqlen,
+            attn_heads=attn_heads,
+            num_gqa_groups=num_gqa_groups,
+            bias_heads=bias_heads,
+            head_dim=head_dim,
+            max_segments_per_seq=config.max_segments_per_seq,
+            scaling_factor=float(config.scaling_factor),
+            dropout_probability=float(config.dropout_probability),
+            bias_type=int(config.attn_bias_type.value),
+            mask_type=int(config.attn_mask_type.value),
+            qkv_layout=int(config.qkv_layout.value),
+            is_training=config.is_training,
+            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
+            window_size_left=config.window_size[0],
+            window_size_right=config.window_size[1],
+        )
     @staticmethod
     def impl(
...
...@@ -6,6 +6,7 @@ import os
 import re
 from abc import ABCMeta, abstractmethod
 from functools import partial
+from packaging import version

 from jax.extend import core
 from jax.interpreters import xla, mlir
...@@ -13,6 +14,14 @@ from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src.interpreters import batching
 from jax._src import dispatch

+import jax
+import transformer_engine_jax
+
+if version.parse(jax.__version__) >= version.parse("0.5.0"):
+    from jax import ffi  # pylint: disable=ungrouped-imports
+else:
+    from jax.extend import ffi  # pylint: disable=ungrouped-imports
+

 class BasePrimitive(metaclass=ABCMeta):
     """
...@@ -120,3 +129,7 @@ def register_primitive(cls):
         outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
     )
     cls.outer_primitive = outer_p
+
+
+for _name, _value in transformer_engine_jax.registrations().items():
+    ffi.register_ffi_target(_name, _value, platform="CUDA")
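register_ffi_target makes each C++ handler returned by transformer_engine_jax.registrations() callable from XLA by name. For orientation, a hedged sketch of invoking a registered target by name through jax.ffi.ffi_call; the target name below is hypothetical, and TE's primitives instead go through ffi.ffi_lowering as shown in the fused-attention diff above:

import jax
import jax.numpy as jnp

# Hypothetical target name; real names come from transformer_engine_jax.registrations().
out_spec = jax.ShapeDtypeStruct((16,), jnp.float32)
call = jax.ffi.ffi_call("te_example_target_ffi", out_spec)
# result = call(jnp.ones((16,), jnp.float32))  # would dispatch to the registered CUDA handler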
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom call"""
from dataclasses import dataclass
from enum import IntEnum
from packaging import version
import jax
from jax.interpreters import mlir
import transformer_engine_jax
from .misc import is_ffi_enabled
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
try:
from jaxlib.hlo_helpers import custom_call
except ImportError:
    # Newer JAX changed its API, but we want to support a few JAX
    # versions, so we still need this import.
pass
class CustomCallAPIVersion(IntEnum):
"""Enum for selecting between old and new custom call registration API"""
OPAQUE = 0
FFI = 1
for _name, _value in transformer_engine_jax.registrations().items():
if _name.endswith("_ffi"):
if is_ffi_enabled():
ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
)
else:
ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
)
@dataclass
class CustomCallArgsWrapper:
"""
wrapper of XLA custom call args
"""
def __init__(
self,
output_types,
operands,
operand_shapes,
operand_specific_layouts=None,
output_specific_layouts=None,
):
self.output_types = output_types
self.operands = operands
self.operand_layouts = CustomCallArgsWrapper.generate_layouts(
operand_shapes, operand_specific_layouts
)
output_shapes = [x.shape for x in output_types]
self.output_layouts = CustomCallArgsWrapper.generate_layouts(
output_shapes, output_specific_layouts
)
@staticmethod
def generate_layouts(shapes, specific_layouts):
"""
setup layouts for XLA custom call
"""
def default_layout(shape):
return range(len(shape) - 1, -1, -1)
if specific_layouts is None:
specific_layouts = {}
layouts = []
for idx, shape in enumerate(shapes):
if idx in specific_layouts:
layouts.append(specific_layouts[idx])
else:
layouts.append(default_layout(shape))
return layouts
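A quick illustration of generate_layouts: without an override each operand gets the row-major default (rank-1, ..., 0), and a specific_layouts dict swaps in a custom minor-to-major order per operand index. The shapes here are arbitrary examples:

# Row-major default for the rank-3 operand; explicit column-major for the rank-2 one.
shapes = [(8, 16, 32), (4, 5)]
layouts = CustomCallArgsWrapper.generate_layouts(shapes, {1: (0, 1)})
print([list(layout) for layout in layouts])  # [[2, 1, 0], [0, 1]]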
def custom_caller(name, args, opaque, has_side_effect, **kwargs):
"""
    XLA custom call wrapper
"""
if hasattr(mlir, "custom_call"):
out = mlir.custom_call(
name,
result_types=args.output_types,
operands=args.operands,
operand_layouts=args.operand_layouts,
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs,
).results
else:
        # Need to disable one pylint error, as the second function
        # parameter name changed recently in JAX. Otherwise we won't be
        # compatible with multiple JAX versions.
out = custom_call( # pylint: disable=too-many-function-args
name,
args.output_types,
operands=args.operands,
operand_layouts=args.operand_layouts,
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs,
)
return out
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "transformer_engine/gemm.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
Error_Type CublasHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets,
Dictionary attrs) {
nvte_cublas_handle_init();
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(CublasHandleInitHandler, CublasHandleInitFFI,
FFI::Bind<FFI_Prepare>().RemainingArgs().RemainingRets().Attrs());
} // namespace jax
} // namespace transformer_engine
...@@ -13,8 +13,9 @@ namespace jax {
 // For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
 DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
   switch (type) {
+    // Using this for E8M0
     case xla::ffi::DataType::U8:
-      return DType::kByte;
+      return DType::kFloat8E8M0;
       break;
     case xla::ffi::DataType::S32:
       return DType::kInt32;
...@@ -37,8 +38,12 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
     case xla::ffi::DataType::F8E4M3FN:
       return DType::kFloat8E4M3;
       break;
+      // case xla::ffi::DataType::F8E8M0FNU:
+      //   return DType::kFloat8E8M0;
+      //   break;
     default:
       auto type_num = static_cast<XLA_FFI_DataType>(type);
+      if (type_num == 33) return DType::kFloat8E8M0;
       NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
                  static_cast<int>(type_num));
       break;
...
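The U8 mapping above is presumably repurposed for MXFP8 block scaling, which stores each scale factor as a single E8M0 byte: a sign-less, mantissa-less format holding only a biased power-of-two exponent. A small decoding sketch; the helper is illustrative, not TE code:

import jax.numpy as jnp

def decode_e8m0(byte):
    # E8M0 stores an 8-bit exponent with bias 127: value = 2**(byte - 127).
    return jnp.exp2(byte.astype(jnp.float32) - 127.0)

print(decode_e8m0(jnp.uint8(127)))  # 1.0
print(decode_e8m0(jnp.uint8(130)))  # 8.0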