# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE miscellaneous for custom ops"""

import os
import functools
from typing import Tuple

from importlib.metadata import version as get_pkg_version
from packaging.version import Version as PkgVersion

import numpy as np

import jax
from jax import dtypes
import jax.numpy as jnp
from jax.interpreters.mlir import dtype_to_ir_type

import transformer_engine_jax

from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScaledTensorFactory, QuantizeLayout

TEDType = transformer_engine_jax.DType


def te_dtype_to_jax_dtype(te_dtype):
    """
    Convert a TE dtype to the corresponding JAX dtype
    """
    assert isinstance(te_dtype, TEDType)

    converter = {
        TEDType.kFloat32: jnp.float32,
        TEDType.kFloat16: jnp.float16,
        TEDType.kBFloat16: jnp.bfloat16,
        TEDType.kInt32: jnp.int32,
        TEDType.kInt64: jnp.int64,
        TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
        TEDType.kFloat8E5M2: jnp.float8_e5m2,
        TEDType.kByte: jnp.uint8,
    }

    if te_dtype not in converter:
        raise ValueError(f"Unsupported {te_dtype=}")

    return converter.get(te_dtype)


def te_dtype_to_ir_dtype(te_dtype):
    """
    Convert a TE dtype to an MLIR dtype
    """
    return dtype_to_ir_type(np.dtype(te_dtype_to_jax_dtype(te_dtype)))


def jax_dtype_to_ir_dtype(jax_dtype):
    """
    Convert a JAX dtype to an MLIR dtype
    """
    return dtype_to_ir_type(np.dtype(jax_dtype))


def jax_dtype_to_te_dtype(jax_dtype):
    """
    Convert a JAX dtype to the corresponding TE dtype
    """
    jax_dtype = dtypes.canonicalize_dtype(jax_dtype)

    converter = {
        jnp.float32.dtype: TEDType.kFloat32,
        jnp.float16.dtype: TEDType.kFloat16,
        jnp.bfloat16.dtype: TEDType.kBFloat16,
        jnp.int32.dtype: TEDType.kInt32,
        jnp.int64.dtype: TEDType.kInt64,
        jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
        jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
        jnp.uint8.dtype: TEDType.kByte,
    }

    if jax_dtype not in converter:
        raise ValueError(f"Unsupported {jax_dtype=}")

    return converter.get(jax_dtype)
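# Illustrative round trip between TE and JAX dtypes (a sketch based on the tables above,
# not executed at import time):
#     te_dtype_to_jax_dtype(TEDType.kBFloat16)  -> jnp.bfloat16
#     jax_dtype_to_te_dtype(jnp.bfloat16)       -> TEDType.kBFloat16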


def get_padded_spec(arg_info):
    """
    Get padded spec for partitioning from arguments' information
    """
    if arg_info.sharding is None:
        return te_get_padded_spec(None, arg_info.ndim)
    ndim, spec = arg_info.ndim, arg_info.sharding.spec
    return te_get_padded_spec(spec, ndim)


def check_valid_batch_dims(bdims):
    """
    Assert that no unsupported batch dims are present (only 0 and None are supported)
    """
    for dim in bdims:
        assert dim in [0, None], f"Currently only support batch_dim in [0, None], but got {dim=}"


def normalize_axis_boundary(axis, ndim):
    """Convert a possibly negative axis index into its non-negative equivalent."""
    return axis if axis >= 0 else ndim + axis


def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
    """
    Multi-dimensional transpose helper for te_cast_transpose_p.

    static_axis_boundary: int, axes with index <= static_axis_boundary are excluded from the
        transpose; -1 means all axes take part in the transpose.
    transpose_axis: int, indicates where to split the multi-dimensional tensor into a 2D matrix
        for the transpose. Note that transpose_axis must be greater than static_axis_boundary.

    examples:
        X in shape (dim0, dim1, dim2, dim3, dim4)

        static_axis_boundary == -1, transpose_axis == 2
            Xt = (dim2, dim3, dim4, dim0, dim1)

        static_axis_boundary == 0, transpose_axis == 2
            Xt = (dim0, dim2, dim3, dim4, dim1)

        static_axis_boundary == 0, transpose_axis == 3
            Xt = (dim0, dim3, dim4, dim1, dim2)
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes
    assert static_axis_boundary < len(shape) - 2  # at least 2 axes remaining for the transpose
    transpose_start_idx = static_axis_boundary + 1
    transpose_axis = normalize_axis_boundary(transpose_axis, len(shape))
    assert transpose_start_idx < transpose_axis
    return (
        *shape[:transpose_start_idx],
        *shape[transpose_axis:],
        *shape[transpose_start_idx:transpose_axis],
    )
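# Quick check of the docstring example above, with made-up dimension sizes (not executed here):
#     multidim_transpose((2, 3, 4, 5, 6), static_axis_boundary=0, transpose_axis=2)
#     -> (2, 4, 5, 6, 3)   i.e. (dim0, dim2, dim3, dim4, dim1)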


@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
    """Runtime cuDNN version (major, minor, patch)"""
    encoded_version = transformer_engine_jax.get_cudnn_version()
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)
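# Worked example of the decoding above (hypothetical installed versions):
#     cuDNN 9.5.1 is encoded as 90501 -> divmod(90501, 10000) = (9, 501), divmod(501, 100) = (5, 1)
#     cuDNN 8.9.7 is encoded as 8907  -> divmod(8907, 1000)   = (8, 907), divmod(907, 100) = (9, 7)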


@functools.lru_cache(maxsize=None)
def jax_version_meet_requirement(version: str):
    """
    Helper function checking whether the installed JAX version meets the given minimum requirement
    """
    jax_version = PkgVersion(get_pkg_version("jax"))
    jax_version_required = PkgVersion(version)
    return jax_version >= jax_version_required
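# Example usage (illustrative; the version string below is a hypothetical requirement):
#     if jax_version_meet_requirement("0.4.35"):
#         ...  # rely on a newer JAX API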


def get_xla_flag(flag: str, default=None, cast=str):
    """
    Returns the value of a flag/option from the XLA_FLAGS environment variable if present, otherwise returns the default value.
    """
    xla_flags = []
    if xla_flags_env := os.getenv("XLA_FLAGS"):
        xla_flags.extend(xla_flags_env.split())

    for flag_i in sorted(xla_flags):
        if "=" in flag_i:
            # option like --xla_abc=foo
            name, val = flag_i.split("=", 1)  # split on the first '=' only, so values may contain '='
            if name == flag:
                return val if cast is None else cast(val)
        else:
            # flag like --xla_enable_foo
            name, val = flag_i, None
            if name == flag:
                return True
    return default
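# Example usage (illustrative; the flag names below only demonstrate the two forms parsed above):
#     get_xla_flag("--xla_gpu_graph_min_graph_size", default=0, cast=int)  # --name=value option
#     get_xla_flag("--xla_some_boolean_flag")                              # bare flag -> True if set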


def get_min_device_compute_capability():
    """
    Returns the minimum compute capability of all local devices.
    """
    return min(
        transformer_engine_jax.get_device_compute_capability(local_gpu_id)
        for local_gpu_id in range(len(jax.local_devices()))
    )


def get_all_device_compute_capability():
    """
    Returns a tuple of the compute capabilities of all local devices.
    """
    return tuple(
        transformer_engine_jax.get_device_compute_capability(local_gpu_id)
        for local_gpu_id in range(len(jax.local_devices()))
    )
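# Illustrative values, assuming a hypothetical homogeneous node with 8 sm_90 GPUs:
#     get_all_device_compute_capability()  -> (90, 90, 90, 90, 90, 90, 90, 90)
#     get_min_device_compute_capability()  -> 90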


def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None):
    """
    Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
    calculate dbias separately. This function checks if the workaround should be applied.
    """
    if quantizer is None:
        return False

    arch_l_100 = False
    for local_gpu_id in range(len(jax.local_devices())):
        if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100:
            arch_l_100 = True
            break
    # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE,
    # but this fails when bias fusion is turned on with arch < 100.
    force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
    return (
        (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
        and arch_l_100
        and is_dbias
    )
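# Sketch of how a caller might apply the workaround above (helper names are hypothetical):
#     if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
#         casted = quantize_without_fused_dbias(x, quantizer)   # hypothetical 1x quantize call
#         dbias = jnp.sum(x, axis=tuple(range(x.ndim - 1)))     # dbias reduced separately in JAX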


def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, **kwargs):
    """
    Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling.
    It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result.

    If 'f' returns a tuple, the first output must be the only ScaledTensor output.

    @param f: function to call
    @param args: positional arguments to pass to 'f'
    @param quantizer: quantizer to use
    @param kwargs: keyword arguments to pass to 'f'
    @return: the output of 'f' with the colwise output calculated
    """
    should_apply_war = (
        quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
    )
    if not should_apply_war:
        return None

    # 2x is not supported by TE kernels for delayed scaling
    # so revert to 1x and transpose in JAX
    quantizer.q_layout = QuantizeLayout.ROWWISE
    rowwise = f(*args, **kwargs, quantizer=quantizer)
    other_outputs = None
    if isinstance(rowwise, tuple):
        other_outputs = rowwise[1:]
        rowwise = rowwise[0]
    quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE
    if flatten_axis < 0:
        flatten_axis += rowwise.data.ndim
    assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds"
    colwise_data = jnp.transpose(
        rowwise.data, (*range(flatten_axis, rowwise.data.ndim), *range(flatten_axis))
    )
    output_2x = ScaledTensorFactory.create(
        data=rowwise.data,
        scale_inv=rowwise.scale_inv,
        colwise_data=colwise_data,
        colwise_scale_inv=rowwise.scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=rowwise.dq_dtype,
        q_layout=QuantizeLayout.ROWWISE_COLWISE,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    if other_outputs is not None:
        return (output_2x,) + other_outputs
    return output_2x
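# Example usage (a sketch; `quantize_func` and its arguments are hypothetical):
#     out = try_apply_delayed_scaling_2x_war(quantize_func, x, quantizer=quantizer, flatten_axis=-1)
#     if out is None:
#         out = quantize_func(x, quantizer=quantizer)  # workaround not needed; call directly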


class NamedSharding(jax.sharding.NamedSharding):
    """
    Wrapper around jax.sharding.NamedSharding that adds a string description field as metadata for easier debugging.
    """

    def __init__(self, *args, desc: str = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.desc = desc

    def __repr__(self):
        return f"NamedSharding({self.mesh}, {self.spec}, desc={self.desc})"

    def duplicate_with_new_description(self, desc: str):
        """
        Create a new NamedSharding with the same mesh and spec but with a new description.
        """
        return NamedSharding(self.mesh, self.spec, desc=desc)
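# Example (illustrative; the mesh and axis names are hypothetical):
#     sharding = NamedSharding(mesh, jax.sharding.PartitionSpec("data", None), desc="demo.input")
#     resharded = sharding.duplicate_with_new_description("demo.weight")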


@functools.lru_cache(maxsize=1)
def is_all_reduce_in_float32():
    """
    Check whether all-reduces should be performed in float32 (enabled via NVTE_JAX_ALL_REDUCE_IN_FP32=1)
    """
    return os.getenv("NVTE_JAX_ALL_REDUCE_IN_FP32", "0") == "1"