# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
from typing import Tuple

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    is_ffi_enabled,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP

__all__ = ["cast_fp8"]


def _jax_quantize(x, scale, q_dtype):
    """
    Quantize with scale
    """
    compute_dtype = scale.dtype
    dtype_max = (jnp.finfo(q_dtype).max).astype(compute_dtype)
    scaled_x = x.astype(compute_dtype) * scale
    clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
    return clipped_scaled_x.astype(q_dtype)


def _jax_cast_fp8(inputs, scale, amax, out_dtype):
    """
    JAX-native FP8 casting implementation
    """
    casted_output = _jax_quantize(inputs, scale, q_dtype=out_dtype)
    updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))
    return casted_output, updated_amax


class CastFP8Primitive(BasePrimitive):
    """
    Cast Primitive
    """

    name = "te_quantize"
    multiple_results = True
    impl_static_args = (4,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_cast abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return casted_x_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_cast lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_quantize_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
                ctx, x, amax, scale, scale_inv
            )
        else:
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape

            out_types = [
                ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [x, amax, scale, scale_inv]
            operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            opaque = transformer_engine_jax.pack_common_descriptor(
                ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype)
            )

            out = custom_caller(
                CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
            )

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        te_cast implementation
        """
        assert CastFP8Primitive.inner_primitive is not None
        casted_x, updated_amax = CastFP8Primitive.inner_primitive.bind(
            x, amax, scale, scale_inv, out_dtype=out_dtype
        )
        return casted_x, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        check_valid_batch_dims(batch_dims)
        assert CastFP8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return (
            CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_updated_amax = CastFP8Primitive.impl(
                x, amax, scale, scale_inv, out_dtype=out_dtype
            )
            # Reduce amax over every mesh axis except pipeline parallelism so all
            # replicas observe the global maximum.
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)

            return local_cx, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastFP8Primitive)


def cast_fp8(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Cast wrapper

    Returns the casted FP8 tensor and the updated amax.
    """
    if not CastFP8Primitive.enabled():
        return _jax_cast_fp8(x, scale, amax, out_dtype=out_dtype)
    return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
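

# Illustrative usage sketch (an assumption for demonstration, not part of this module's
# tested surface): the shapes, the (1,) amax/scale buffers, and jnp.float8_e4m3fn as the
# target dtype below are hypothetical choices, and the primitive path additionally
# requires the transformer_engine_jax extension to be available.
#
#     key = jax.random.PRNGKey(0)
#     x = jax.random.normal(key, (16, 32), dtype=jnp.bfloat16)
#     amax = jnp.zeros((1,), dtype=jnp.float32)
#     scale = jnp.ones((1,), dtype=jnp.float32)
#     scale_inv = jnp.ones((1,), dtype=jnp.float32)
#     casted_x, updated_amax = cast_fp8(x, amax, scale, scale_inv, out_dtype=jnp.float8_e4m3fn)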