misc.py 5.93 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.
"""JAX/TE miscellaneous for custom ops"""

6
import os
7
8
import functools
from typing import Tuple
9
10
from importlib.metadata import version as get_pkg_version
from packaging.version import Version as PkgVersion
11

12
import numpy as np
13

14
15
16
17
18
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import dtype_to_ir_type

from transformer_engine.transformer_engine_jax import DType as TEDType
19
from transformer_engine import transformer_engine_jax
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

from ..sharding import get_padded_spec as te_get_padded_spec


def te_dtype_to_jax_dtype(te_dtype):
    """
    convert TE dtype to jax dtype

    te_dtype: TEDType, the Transformer Engine dtype enum member to translate.

    Returns the jax.numpy scalar type registered for te_dtype in the table below.
    Raises ValueError when te_dtype has no entry in the table.
    """
    assert isinstance(te_dtype, TEDType)

    converter = {
        TEDType.kFloat32: jnp.float32,
        TEDType.kFloat16: jnp.float16,
        TEDType.kBFloat16: jnp.bfloat16,
        TEDType.kInt32: jnp.int32,
        TEDType.kInt64: jnp.int64,
        TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
        TEDType.kFloat8E5M2: jnp.float8_e5m2,
        TEDType.kByte: jnp.uint8,
    }

    if te_dtype not in converter:
        raise ValueError(f"Unsupported {te_dtype=}")

    # Membership was verified above, so plain indexing cannot fail here.
    return converter[te_dtype]


def te_dtype_to_ir_dtype(te_dtype):
    """
    Translate a TE dtype into an MLIR type.

    The TE dtype is first mapped to its jax dtype, wrapped in a numpy dtype,
    and then handed to the MLIR type lookup.
    """
    jax_type = te_dtype_to_jax_dtype(te_dtype)
    numpy_dtype = np.dtype(jax_type)
    return dtype_to_ir_type(numpy_dtype)


def jax_dtype_to_ir_dtype(jax_dtype):
    """
    Translate a Jax dtype into an MLIR type.

    jax_dtype is anything accepted by np.dtype(); the resulting numpy dtype
    is passed to the MLIR type lookup.
    """
    numpy_dtype = np.dtype(jax_dtype)
    return dtype_to_ir_type(numpy_dtype)


def jax_dtype_to_te_dtype(jax_dtype):
    """
    Translate a jax dtype into the corresponding TE dtype.

    The input is canonicalized with jax's dtypes.canonicalize_dtype before the
    lookup. Raises ValueError when the canonical dtype has no TE counterpart
    in the table below.
    """
    jax_dtype = dtypes.canonicalize_dtype(jax_dtype)

    # (jax scalar type, TE dtype) pairs; keyed by the scalar type's numpy dtype.
    pairs = (
        (jnp.float32, TEDType.kFloat32),
        (jnp.float16, TEDType.kFloat16),
        (jnp.bfloat16, TEDType.kBFloat16),
        (jnp.int32, TEDType.kInt32),
        (jnp.int64, TEDType.kInt64),
        (jnp.float8_e4m3fn, TEDType.kFloat8E4M3),
        (jnp.float8_e5m2, TEDType.kFloat8E5M2),
        (jnp.uint8, TEDType.kByte),
    )
    lookup = {scalar_type.dtype: te_type for scalar_type, te_type in pairs}

    te_type = lookup.get(jax_dtype)
    if te_type is None:
        raise ValueError(f"Unsupported {jax_dtype=}")
    return te_type


def get_padded_spec(arg_info):
    """
    Build the padded partition spec for an argument.

    arg_info must expose `ndim` and `sharding`. When `sharding` is None the
    spec passed to te_get_padded_spec is None; otherwise it is `sharding.spec`.
    """
    sharding = arg_info.sharding
    spec = None if sharding is None else sharding.spec
    return te_get_padded_spec(spec, arg_info.ndim)


def check_valid_batch_dims(bdims):
    """
    Assert out non-supported batch dims

    bdims: iterable of batch dimensions; every entry must be 0 or None.

    Raises AssertionError on the first entry that is neither 0 nor None.
    """
    for dim in bdims:
        assert dim in [0, None], f"Currently only support batch_dim in [0, None], but got {dim=}"

101
102

def normalize_axis_boundary(axis, ndim):
    """
    Convert a possibly negative axis boundary into its non-negative form.

    axis: int, axis boundary; negative values count from the end.
    ndim: int, number of dimensions the boundary refers to.

    Returns axis unchanged when it is >= 0, otherwise ndim + axis.
    No range check is performed on the result.
    """
    return axis if axis >= 0 else ndim + axis


def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
    """
    te_cast_transpose_p multi-dims transpose

    static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be
        involved into transpose, -1 means all axes involve into transpose.
    transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for
        transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary

    Returns the transposed shape as a tuple.

    examples:
        X in shape (dim0, dim1, dim2, dim3, dim4)

        static_axis_boundary == -1, transpose_axis_boundary == 2
            Xt = (dim2, dim3, dim4, dim0, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 2
            Xt = (dim0, dim2, dim3, dim4, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 3
            Xt = (dim0, dim3, dim4, dim1, dim2)
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes
    assert static_axis_boundary < len(shape) - 2  # at least 2 remaining for transpose.
    transpose_start_idx = static_axis_boundary + 1
    # A negative transpose boundary counts from the end of the shape.
    transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, len(shape))
    assert transpose_start_idx < transpose_axis_boundary
    # Static prefix stays put; the two groups after it swap places.
    return (
        *shape[:transpose_start_idx],
        *shape[transpose_axis_boundary:],
        *shape[transpose_start_idx:transpose_axis_boundary],
    )
139
140
141
142
143
144
145
146
147
148


@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
    """Runtime cuDNN version decoded into (major, minor, patch)"""
    raw = transformer_engine_jax.get_cudnn_version()
    # Encoded values below 90000 carry the major number in the thousands,
    # larger values carry it in the ten-thousands.
    major_scale = 10000 if raw >= 90000 else 1000
    major = raw // major_scale
    remainder = raw % major_scale
    # The remaining digits are minor * 100 + patch.
    return (major, remainder // 100, remainder % 100)
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164


@functools.lru_cache(maxsize=None)
def jax_version_meet_requirement(version: str) -> bool:
    """
    Tell whether the installed JAX package is at least `version`.

    The installed version is read from the "jax" package metadata and both
    versions are compared as parsed package versions. Results are cached per
    `version` string.
    """
    installed = get_pkg_version("jax")
    return PkgVersion(installed) >= PkgVersion(version)


def is_ffi_enabled():
    """
    Helper function checking if XLA Custom Call with FFI is enabled

    FFI is reported as enabled only when the installed JAX is at least 0.4.35
    and the NVTE_JAX_WITH_FFI environment variable (default "1") is 1.

    Returns a bool.
    Raises ValueError when NVTE_JAX_WITH_FFI is not an integer string, and
    AssertionError when it is an integer other than 0 or 1.
    """
    is_supported = jax_version_meet_requirement("0.4.35")
    # New APIs with FFI are enabled by default
    is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
    assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
    # Normalise to bool so callers never receive the raw 0/1 integer.
    return bool(is_supported and is_enabled)
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191


def get_xla_flag(flag: str, default=None, cast=str):
    """
    Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value.

    flag: full flag name including leading dashes, e.g. "--xla_abc".
    default: value returned when the flag does not appear in XLA_FLAGS.
    cast: callable applied to the textual value of an option like "--xla_abc=foo";
        pass None to get the raw string.

    Returns cast(value) for "--name=value" entries, True for bare "--name" entries,
    and `default` when the flag is absent. Entries are examined in sorted order, so
    when a flag is repeated the lexicographically first entry wins.
    """
    xla_flags = os.getenv("XLA_FLAGS", "").split()

    for flag_i in sorted(xla_flags):
        # Split on the first "=" only, so values that contain "=" stay intact.
        name, has_value, val = flag_i.partition("=")
        if name != flag:
            continue
        if not has_value:
            # flag like --xla_enable_foo
            return True
        # option like --xla_abc=foo
        return val if cast is None else cast(val)
    return default