# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE miscellaneous for custom ops"""
import os
import functools
from typing import Tuple
from importlib.metadata import version as get_pkg_version
from packaging.version import Version as PkgVersion

import numpy as np
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import dtype_to_ir_type

from transformer_engine.transformer_engine_jax import DType as TEDType
from transformer_engine import transformer_engine_jax

from ..sharding import get_padded_spec as te_get_padded_spec


def te_dtype_to_jax_dtype(te_dtype):
    """
    Convert a TE dtype to the corresponding JAX dtype.

    Raises:
        ValueError: if `te_dtype` has no JAX equivalent.
    """
    assert isinstance(te_dtype, TEDType)

    converter = {
        TEDType.kFloat32: jnp.float32,
        TEDType.kFloat16: jnp.float16,
        TEDType.kBFloat16: jnp.bfloat16,
        TEDType.kInt32: jnp.int32,
        TEDType.kInt64: jnp.int64,
        TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
        TEDType.kFloat8E5M2: jnp.float8_e5m2,
        TEDType.kByte: jnp.uint8,
    }

    if te_dtype not in converter:
        raise ValueError(f"Unsupported {te_dtype=}")

    return converter[te_dtype]


def te_dtype_to_ir_dtype(te_dtype):
    """
    Convert a TE dtype to an MLIR dtype (via the JAX/NumPy equivalent).
    """
    return dtype_to_ir_type(np.dtype(te_dtype_to_jax_dtype(te_dtype)))


def jax_dtype_to_ir_dtype(jax_dtype):
    """
    Convert a JAX dtype to an MLIR dtype.
    """
    return dtype_to_ir_type(np.dtype(jax_dtype))


def jax_dtype_to_te_dtype(jax_dtype):
    """
    Convert a JAX dtype to the corresponding TE dtype.

    Raises:
        ValueError: if `jax_dtype` has no TE equivalent.
    """
    jax_dtype = dtypes.canonicalize_dtype(jax_dtype)

    converter = {
        jnp.float32.dtype: TEDType.kFloat32,
        jnp.float16.dtype: TEDType.kFloat16,
        jnp.bfloat16.dtype: TEDType.kBFloat16,
        jnp.int32.dtype: TEDType.kInt32,
        jnp.int64.dtype: TEDType.kInt64,
        jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
        jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
        jnp.uint8.dtype: TEDType.kByte,
    }

    if jax_dtype not in converter:
        raise ValueError(f"Unsupported {jax_dtype=}")

    return converter[jax_dtype]


def get_padded_spec(arg_info):
    """
    Get the padded partition spec from an argument's ShapedArray info.

    Falls back to an all-None spec when no sharding is attached.
    """
    if arg_info.sharding is None:
        return te_get_padded_spec(None, arg_info.ndim)
    ndim, spec = arg_info.ndim, arg_info.sharding.spec
    return te_get_padded_spec(spec, ndim)


def check_valid_batch_dims(bdims):
    """
    Assert out non-supported batch dims
    """
    for dim in bdims:
        assert dim in [0, None], f"Currently only support batch_dim in [0, None], but got {dim=}"


def normalize_axis_boundary(axis, ndim):
    """Map a possibly-negative axis index to its non-negative equivalent."""
    return axis if axis >= 0 else ndim + axis


def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
    """
    te_cast_transpose_p multi-dims transpose

    static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be
        involved into transpose, -1 means all axes involve into transpose.
    transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for
        transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary

    examples:
        X in shape (dim0, dim1, dim2, dim3, dim4)

        static_axis_boundary == -1, transpose_axis_boundary == 2
            Xt = (dim2, dim3, dim4, dim0, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 2
            Xt = (dim0, dim2, dim3, dim4, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 3
            Xt = (dim0, dim3, dim4, dim1, dim2)
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes
    # Need at least 2 non-static axes to form a 2D matrix for transpose.
    assert static_axis_boundary < len(shape) - 2
    transpose_start_idx = static_axis_boundary + 1
    transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, len(shape))
    assert transpose_start_idx < transpose_axis_boundary
    return (
        *shape[:transpose_start_idx],
        *shape[transpose_axis_boundary:],
        *shape[transpose_start_idx:transpose_axis_boundary],
    )


@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
    """Runtime cuDNN version (major, minor, patch)"""
    encoded_version = transformer_engine_jax.get_cudnn_version()
    # cuDNN >= 9 widened the major-version field in the encoded integer.
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)


@functools.lru_cache(maxsize=None)
def jax_version_meet_requirement(version: str):
    """
    Helper function checking if required JAX version is available
    """
    jax_version = PkgVersion(get_pkg_version("jax"))
    jax_version_required = PkgVersion(version)
    return jax_version >= jax_version_required


def is_ffi_enabled():
    """
    Helper function checking if XLA Custom Call with FFI is enabled
    """
    is_supported = jax_version_meet_requirement("0.4.35")
    # New APIs with FFI are enabled by default
    is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
    assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
    return is_supported and is_enabled


def get_xla_flag(flag: str, default=None, cast=str):
    """
    Returns the value of a flag/option in XLA_FLAGS environment variable if present
    or returns the default value.
    """
    xla_flags = []
    if xla_flags_env := os.getenv("XLA_FLAGS"):
        xla_flags.extend(xla_flags_env.split())

    for flag_i in sorted(xla_flags):
        if "=" in flag_i:
            # option like --xla_abc=foo
            # maxsplit=1 keeps any '=' inside the value (e.g. --xla_dump_to=a=b)
            # intact; maxsplit=2 would yield 3 parts and break the unpack.
            name, val = flag_i.split("=", 1)
            if name == flag:
                return val if cast is None else cast(val)
        else:
            # flag like --xla_enable_foo
            name, val = flag_i, None
            if name == flag:
                return True
    return default