Commit a207db1d authored by yuguo's avatar yuguo
Browse files
parents fbee8990 69365f88
......@@ -201,8 +201,9 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
max_fp8 = Quantized_Limits<DType>::max_norm;);
// Update scale
compute_scale_from_amax_kernel<<<1, 1>>>(reinterpret_cast<const float *>(output.amax.dptr),
reinterpret_cast<float *>(output.scale.dptr), max_fp8,
config.force_pow_2_scales, config.amax_epsilon);
compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float *>(output.amax.dptr),
reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
config.amax_epsilon);
NVTE_CHECK_CUDA(cudaGetLastError());
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
"""Transformer Engine bindings for JAX.
This module provides JAX bindings for NVIDIA's Transformer Engine, enabling
high-performance transformer operations with mixed precision and quantization
support. It includes implementations of key transformer components like attention,
linear layers, and layer normalization, optimized for NVIDIA GPUs.
The module exports various transformer operations and utilities:
- Attention mechanisms (self-attention, cross-attention)
- Linear transformations with optional quantization
- Layer normalization operations
- Activation functions
- Softmax operations
- Sharding utilities for distributed training
All operations are designed to work seamlessly with JAX's functional programming
model and support automatic differentiation.
"""
# pylint: disable=wrong-import-position,wrong-import-order
import sys
import logging
import importlib
import importlib.util
import ctypes
from importlib.metadata import version
import sys
from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension
_logger = logging.getLogger(__name__)
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
......@@ -41,7 +55,7 @@ def _load_library():
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
_logger.info(
logging.info(
"Could not find package %s. Install transformer-engine using "
"'pip3 install transformer-engine[jax]==VERSION'",
module_name,
......@@ -67,8 +81,10 @@ def _load_library():
_load_library()
from . import flax
from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME
from . import quantize
from .quantize import fp8_autocast
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType
......@@ -85,10 +101,7 @@ ShardingResource = deprecate_wrapper(
)
__all__ = [
"NVTE_FP8_COLLECTION_NAME",
"fp8_autocast",
"update_collections",
"get_delayed_scaling",
"MeshResource",
"MajorShardingType",
"ShardingResource",
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Activation functions for Transformer Engine in JAX.
This module provides optimized activation functions with quantization support.
"""
from typing import Sequence, Union, Callable, Optional
from functools import partial
import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize.tensor import ScaledTensor
from .quantize.quantizer import Quantizer
def activation(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Apply activation functions to input tensor with optional quantization.

    This function applies a sequence of activation functions to the input tensor.
    It supports string-based activation types (e.g., 'relu', 'gelu', ('gelu', 'linear')).

    Args:
        x: Input tensor to apply activations to. Its last dimension must be
            evenly divisible by ``len(activation_type)``.
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer for quantizing the output

    Returns:
        Activated output tensor

    Raises:
        AssertionError: If the last dimension of ``x`` is not divisible by
            the number of activations.
    """
    # Fail early with a descriptive message instead of a bare assert.
    assert x.shape[-1] % len(activation_type) == 0, (
        f"x.shape[-1] ({x.shape[-1]}) must be divisible by the number of "
        f"activations ({len(activation_type)})"
    )
    return _activation(x, activation_type, quantizer)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation(x, activation_type, quantizer):
    """Core activation entry point wired into JAX's custom VJP machinery.

    The primal computation simply delegates to the forward rule and discards
    the saved context; gradients are supplied by ``_activation_bwd_rule``.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions (non-differentiable arg)
        quantizer: Optional quantizer

    Returns:
        Activated tensor
    """
    result, _ctx = _activation_fwd_rule(x, activation_type, quantizer)
    return result
def _activation_fwd_rule(x, activation_type, quantizer):
    """Forward rule: run the fused activation kernel and stash the residuals.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer

    Returns:
        Tuple of (output, context) where context is ``(x, quantizer)`` for the
        backward pass. A quantized (ScaledTensor) result is dequantized before
        being returned.
    """
    out = tex.act_lu(x, activation_type, quantizer)
    out = out.dequantize() if isinstance(out, ScaledTensor) else out
    residuals = (x, quantizer)
    return out, residuals
def _activation_bwd_rule(activation_type, ctx, g):
    """Backward rule: compute the input gradient for the activation.

    Args:
        activation_type: Sequence of activation functions
        ctx: Residuals saved by the forward rule: ``(x, quantizer)``
        g: Upstream gradient

    Returns:
        One-element-per-diff-arg tuple ``(dx, None)``; the quantizer gets no
        gradient.
    """
    x, _unused_quantizer = ctx
    assert x.dtype == g.dtype
    grad = tex.dact_lu(g, x, activation_type)
    grad = jnp.reshape(grad, x.shape)
    return (grad, None)


_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule)
......@@ -378,6 +378,44 @@ def _mask_to_seqlens_offset(mask, max_segments_per_seq):
return q_seqlen, q_offset, kv_seqlen, kv_offset
def _fast_causal_adjust_seqlen_and_offsets(
    segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
):
    """Adjust per-segment Q/KV lengths and offsets for the causal fast path.

    Shortens a segment's Q length (and advances its offset) by one token when
    the KV segment starts at a strictly larger position, and shortens a
    segment's KV length by one token when the KV segment ends at a strictly
    larger position than the Q segment. Only segments where both q_len and
    kv_len are positive are adjusted.

    Returns:
        Tuple of ``(q_len, kv_len, q_offset, kv_offset)`` with the same shapes
        as the inputs.
    """
    # The assumption is that for any segment tokens respect causal ordering except at the ends
    # of the segment. This allows us to tweak the length and offset by only looking at the start
    # and end tokens between segments.
    # A segment participates only if it is non-empty on both the Q and KV side.
    is_active_segment = jnp.logical_and(q_len > 0, kv_len > 0)

    # Segment position of the first token of each segment; out-of-range offset
    # slots read as -1 via fill_value so padded segments compare equal (-1 > -1
    # is False) and are left untouched.
    q_seq_id_start = jnp.take(segment_pos_q, q_offset[..., :-1], fill_value=-1)
    kv_seq_id_start = jnp.take(segment_pos_kv, kv_offset[..., :-1], fill_value=-1)
    # 1 where the leading Q token must be skipped (KV starts strictly later).
    skip_start_token = jnp.logical_and(kv_seq_id_start > q_seq_id_start, is_active_segment).astype(
        jnp.int32
    )
    q_len -= skip_start_token
    # Advance the affected Q offsets by one; a trailing 0 is inserted so the
    # offsets array keeps its original length.
    q_offset += jnp.insert(skip_start_token, skip_start_token.shape[-1], 0, axis=-1)

    # Segment position of the last token of each segment (computed AFTER the
    # start-token adjustment above — order matters here).
    q_seq_id_end = jnp.take(segment_pos_q, q_offset[..., 1:] - 1, fill_value=-1)
    kv_seq_id_end = jnp.take(segment_pos_kv, kv_offset[..., 1:] - 1, fill_value=-1)
    # 1 where the trailing KV token must be dropped (KV ends strictly later).
    skip_end_token = jnp.logical_and(kv_seq_id_end > q_seq_id_end, is_active_segment).astype(
        jnp.int32
    )
    kv_len -= skip_end_token

    return q_len, kv_len, q_offset, kv_offset
def _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
    segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
):
    """Compute causal-adjusted seqlens/offsets from segment ids and positions.

    Derives the raw per-segment lengths/offsets for Q and KV, then applies the
    fast causal start/end-token adjustment.
    """
    q_seqlen, q_offsets = _get_seqlens_and_offsets(segment_ids_q, max_segments_per_seq)
    kv_seqlen, kv_offsets = _get_seqlens_and_offsets(segment_ids_kv, max_segments_per_seq)
    return _fast_causal_adjust_seqlen_and_offsets(
        segment_pos_q, q_seqlen, q_offsets, segment_pos_kv, kv_seqlen, kv_offsets
    )
def _segment_ids_pos_to_seqlens_offsets(
segment_ids_q,
segment_ids_kv,
......@@ -387,6 +425,25 @@ def _segment_ids_pos_to_seqlens_offsets(
window_size,
max_segments_per_seq,
):
# TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here.
# Computing the full mask is expensive due to quadratic expansion of Q * KV masking.
# Assumptions for cudnn causal mask correctness.
# 1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0]
# 2. No intra-segment padding, only inter-segment paddding allowed
# 3. Only start or end token within a segment may violate the causal order relationship
# 1 5 9 0 4 8 10 0 4 8
# 0 x x
# 4 x x x x x
# 8 x x x x x x x x
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.
if attn_mask_type.is_causal() and window_size is None or window_size == (-1, -1):
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
)
# (1 = attend, 0 = masked)
segment_mask = make_attention_mask(
segment_ids_q,
......
......@@ -7,4 +7,4 @@ from .attention import *
from .normalization import *
from .quantization import *
from .softmax import *
from .transpose import *
from .gemm import *
......@@ -13,8 +13,6 @@ from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes, lax
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine_jax
......@@ -29,14 +27,12 @@ from transformer_engine.jax.attention import (
)
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
check_valid_batch_dims,
jax_dtype_to_te_dtype,
te_dtype_to_jax_dtype,
get_padded_spec,
get_cudnn_version,
is_ffi_enabled,
)
from ..sharding import (
global_mesh_resource,
......@@ -227,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
Fused Attention Forward Primitive
"""
name = "te_fused_attn_forward"
name = "te_fused_attn_forward_ffi"
multiple_results = True
impl_static_args = (13,)
inner_primitive = None
......@@ -400,90 +396,40 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
if is_ffi_enabled():
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
q,
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type.value),
mask_type=int(config.attn_mask_type.value),
qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
return out
return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
ctx,
q,
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type.value),
mask_type=int(config.attn_mask_type.value),
qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
@staticmethod
def impl(
......@@ -681,7 +627,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
Fused Attention Backward Primitive
"""
name = "te_fused_attn_backward"
name = "te_fused_attn_backward_ffi"
multiple_results = True
impl_static_args = (16,)
inner_primitive = None
......@@ -813,96 +759,43 @@ class FusedAttnBwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
if is_ffi_enabled():
name = "te_fused_attn_backward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type.value),
mask_type=int(config.attn_mask_type.value),
qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out
return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
ctx,
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type.value),
mask_type=int(config.attn_mask_type.value),
qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
@staticmethod
def impl(
......
......@@ -6,6 +6,7 @@ import os
import re
from abc import ABCMeta, abstractmethod
from functools import partial
from packaging import version
from jax.extend import core
from jax.interpreters import xla, mlir
......@@ -13,6 +14,14 @@ from jax.experimental.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
from jax._src import dispatch
import jax
import transformer_engine_jax
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
class BasePrimitive(metaclass=ABCMeta):
"""
......@@ -120,3 +129,7 @@ def register_primitive(cls):
outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
)
cls.outer_primitive = outer_p
for _name, _value in transformer_engine_jax.registrations().items():
ffi.register_ffi_target(_name, _value, platform="CUDA")
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom call"""
from dataclasses import dataclass
from enum import IntEnum
from packaging import version
import jax
from jax.interpreters import mlir
import transformer_engine_jax
from .misc import is_ffi_enabled
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
try:
from jaxlib.hlo_helpers import custom_call
except ImportError:
# Newer JAX changed its API. But we want to support a few JAX
# version, so we still need this import.
pass
class CustomCallAPIVersion(IntEnum):
    """Enum for selecting between old and new custom call registration API"""

    # Legacy custom-call API: arguments are passed via an opaque descriptor string.
    OPAQUE = 0
    # New XLA FFI custom-call API.
    FFI = 1
# Register every "_ffi"-suffixed TE custom-call target with XLA for CUDA.
# The API version depends on whether the new FFI path is enabled.
for _name, _value in transformer_engine_jax.registrations().items():
    if not _name.endswith("_ffi"):
        continue
    _api_version = (
        CustomCallAPIVersion.FFI.value
        if is_ffi_enabled()
        else CustomCallAPIVersion.OPAQUE.value
    )
    ffi.register_ffi_target(_name, _value, platform="CUDA", api_version=_api_version)
@dataclass
class CustomCallArgsWrapper:
    """
    wrapper of XLA custom call args
    """

    def __init__(
        self,
        output_types,
        operands,
        operand_shapes,
        operand_specific_layouts=None,
        output_specific_layouts=None,
    ):
        # Keep the raw output types and operands; derive layouts for both sides.
        self.output_types = output_types
        self.operands = operands
        self.operand_layouts = CustomCallArgsWrapper.generate_layouts(
            operand_shapes, operand_specific_layouts
        )
        output_shapes = [out_type.shape for out_type in output_types]
        self.output_layouts = CustomCallArgsWrapper.generate_layouts(
            output_shapes, output_specific_layouts
        )

    @staticmethod
    def generate_layouts(shapes, specific_layouts):
        """
        setup layouts for XLA custom call
        """
        overrides = {} if specific_layouts is None else specific_layouts
        # Default layout enumerates dimensions from last to first (row-major);
        # an explicit per-index override takes precedence when provided.
        return [
            overrides[idx] if idx in overrides else range(len(shape) - 1, -1, -1)
            for idx, shape in enumerate(shapes)
        ]
def custom_caller(name, args, opaque, has_side_effect, **kwargs):
    """
    XLA custom call wrapper.

    Dispatches to whichever custom-call API the installed JAX version
    provides: the newer ``mlir.custom_call`` when available, otherwise the
    older ``jaxlib.hlo_helpers.custom_call``.

    Args:
        name: registered custom-call target name
        args: CustomCallArgsWrapper carrying output types, operands and layouts
        opaque: serialized descriptor passed as backend_config
        has_side_effect: whether XLA must treat the call as side-effecting
        **kwargs: forwarded verbatim to the underlying custom_call API

    Returns:
        The custom call result(s).
    """
    if hasattr(mlir, "custom_call"):
        out = mlir.custom_call(
            name,
            result_types=args.output_types,
            operands=args.operands,
            operand_layouts=args.operand_layouts,
            result_layouts=args.output_layouts,
            backend_config=opaque,
            has_side_effect=has_side_effect,
            **kwargs,
        ).results
    else:
        # Need to disable one pylint error as the second function
        # parameter name changed recently in JAX. Otherwise we won't be
        # compatible with multiple JAX versions.
        out = custom_call(  # pylint: disable=too-many-function-args
            name,
            args.output_types,
            operands=args.operands,
            operand_layouts=args.operand_layouts,
            result_layouts=args.output_layouts,
            backend_config=opaque,
            has_side_effect=has_side_effect,
            **kwargs,
        )
    return out
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "transformer_engine/gemm.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
// FFI entry point that initializes TE's cuBLAS handle via nvte_cublas_handle_init().
// The variadic args/rets and attrs are accepted but unused; the signature only
// needs to match the binding declared below.
Error_Type CublasHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets,
                               Dictionary attrs) {
  nvte_cublas_handle_init();
  // Surface any pending CUDA error to XLA as an FFI error.
  return ffi_with_cuda_error_check();
}

// Bind the handler via FFI::Bind<FFI_Prepare>, accepting any number of
// remaining args/results plus arbitrary attributes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(CublasHandleInitHandler, CublasHandleInitFFI,
                              FFI::Bind<FFI_Prepare>().RemainingArgs().RemainingRets().Attrs());
} // namespace jax
} // namespace transformer_engine
......@@ -13,8 +13,9 @@ namespace jax {
// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
switch (type) {
// Using this for E8M0
case xla::ffi::DataType::U8:
return DType::kByte;
return DType::kFloat8E8M0;
break;
case xla::ffi::DataType::S32:
return DType::kInt32;
......@@ -37,8 +38,12 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
case xla::ffi::DataType::F8E4M3FN:
return DType::kFloat8E4M3;
break;
// case xla::ffi::DataType::F8E8M0FNU:
// return DType::kFloat8E8M0;
// break;
default:
auto type_num = static_cast<XLA_FFI_DataType>(type);
if (type_num == 33) return DType::kFloat8E8M0;
NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
static_cast<int>(type_num));
break;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment