# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE miscellaneous for custom ops"""

import os
import functools
from typing import Tuple

import numpy as np

import jax
from jax import dtypes
import jax.numpy as jnp
from jax.interpreters.mlir import dtype_to_ir_type

import transformer_engine_jax

from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScaledTensorFactory, QuantizeLayout

TEDType = transformer_engine_jax.DType


def te_dtype_to_jax_dtype(te_dtype):
    """
    convert TE dtype to jax dtype
    """
    assert isinstance(te_dtype, TEDType)

    converter = {
        TEDType.kFloat32: jnp.float32,
        TEDType.kFloat16: jnp.float16,
        TEDType.kBFloat16: jnp.bfloat16,
        TEDType.kInt32: jnp.int32,
        TEDType.kInt64: jnp.int64,
        TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
        TEDType.kFloat8E5M2: jnp.float8_e5m2,
        TEDType.kFloat8E8M0: jnp.float8_e8m0fnu,
        TEDType.kFloat4E2M1: jnp.float4_e2m1fn,
        TEDType.kByte: jnp.uint8,
    }

    if te_dtype not in converter:
        raise ValueError(f"Unsupported {te_dtype=}")

    return converter.get(te_dtype)


def te_dtype_to_ir_dtype(te_dtype):
    """
    convert TE dtype to MLIR dtype
    """
    return dtype_to_ir_type(np.dtype(te_dtype_to_jax_dtype(te_dtype)))


def jax_dtype_to_ir_dtype(jax_dtype):
    """
    convert Jax dtype to MLIR dtype
    """
    return dtype_to_ir_type(np.dtype(jax_dtype))


def jax_dtype_to_te_dtype(jax_dtype):
    """
    convert jax dtype to TE dtype
    """
    jax_dtype = dtypes.canonicalize_dtype(jax_dtype)

    converter = {
        jnp.float32.dtype: TEDType.kFloat32,
        jnp.float16.dtype: TEDType.kFloat16,
        jnp.bfloat16.dtype: TEDType.kBFloat16,
        jnp.int32.dtype: TEDType.kInt32,
        jnp.int64.dtype: TEDType.kInt64,
        jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
        jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
        jnp.float8_e8m0fnu.dtype: TEDType.kFloat8E8M0,
        jnp.float4_e2m1fn.dtype: TEDType.kFloat4E2M1,
    }

    if jax_dtype not in converter:
        raise ValueError(f"Unsupported {jax_dtype=}")

    return converter.get(jax_dtype)
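
# For illustration, the two converters above are inverses on their shared dtypes:
# jax_dtype_to_te_dtype(jnp.bfloat16) gives TEDType.kBFloat16, and
# te_dtype_to_jax_dtype(TEDType.kBFloat16) gives jnp.bfloat16 back.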


def get_padded_spec(arg_info):
    """
    Get padded spec for partitioning from arguments' information
    """
    if arg_info.sharding is None:
        return te_get_padded_spec(None, arg_info.ndim)
    ndim, spec = arg_info.ndim, arg_info.sharding.spec
    return te_get_padded_spec(spec, ndim)


def check_valid_batch_dims(bdims):
    """
    Assert out non-supported batch dims
    """
    for dim in bdims:
        assert dim in [0, None], f"Currently only support batch_dim in [0, None], but got {dim=}"


def normalize_axis_boundary(axis, ndim):
    """Normalize a possibly negative axis index into the range [0, ndim)"""
    return axis if axis >= 0 else ndim + axis


def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
    """
    te_cast_transpose_p multi-dims transpose

    static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be
        involved into transpose, -1 means all axes involve into transpose.
115
116
    transpose_axis: int, Indicate how to split multi-dimensions tensors to 2D matrix for
        transpose. Note, transpose_axis should be greater than static_axis_boundary
117
118
119
120

    examples:
        X in shape (dim0, dim1, dim2, dim3, dim4)

121
        static_axis_boundary == -1, transpose_axis == 2
122
123
            Xt = (dim2, dim3, dim4, dim0, dim1)

124
        static_axis_boundary == 0, transpose_axis == 2
125
126
            Xt = (dim0, dim2, dim3, dim4, dim1)

127
        static_axis_boundary == 0, transpose_axis == 3
128
129
130
            Xt = (dim0, dim3, dim4, dim1. dim2)
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes
    assert static_axis_boundary < len(shape) - 2  # at least 2 remaining for transpose.
    transpose_start_idx = static_axis_boundary + 1
    transpose_axis = normalize_axis_boundary(transpose_axis, len(shape))
    assert transpose_start_idx < transpose_axis
    return (
        *shape[:transpose_start_idx],
        *shape[transpose_axis:],
        *shape[transpose_start_idx:transpose_axis],
    )
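
# Shape-only example: multidim_transpose((2, 3, 4, 5, 6), static_axis_boundary=0, transpose_axis=3)
# keeps dim0 static and returns (2, 5, 6, 3, 4); the matching data movement would be a
# jnp.transpose with the same axis permutation.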


@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
    """Runtime cuDNN version (major, minor, patch)"""
    encoded_version = transformer_engine_jax.get_cudnn_version()
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)
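
# Decoding example (as implemented above): cuDNN 9.x encodes the version as
# major*10000 + minor*100 + patch, so 90700 -> (9, 7, 0); pre-9.0 builds use
# major*1000 + minor*100 + patch, so 8907 -> (8, 9, 7).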


def get_xla_flag(flag: str, default=None, cast=str):
    """
    Returns the value of a flag/option in the XLA_FLAGS environment variable if present, otherwise returns the default value.
    """
    xla_flags = []
    if xla_flags_env := os.getenv("XLA_FLAGS"):
        xla_flags.extend(xla_flags_env.split())

    for flag_i in sorted(xla_flags):
        if "=" in flag_i:
            # option like --xla_abc=foo
            name, val = flag_i.split("=", 1)
            if name == flag:
                return val if cast is None else cast(val)
        else:
            # flag like --xla_enable_foo
            name, val = flag_i, None
            if name == flag:
                return True
    return default
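
# Usage sketch: with XLA_FLAGS="--xla_abc=2 --xla_enable_foo" (placeholder flag names),
# get_xla_flag("--xla_abc", default=0, cast=int) returns 2, get_xla_flag("--xla_enable_foo")
# returns True, and get_xla_flag("--xla_missing", default=0) falls back to 0.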


def get_min_device_compute_capability():
    """
    Returns the minimum compute capability of all local devices.
    """
    return min(
        transformer_engine_jax.get_device_compute_capability(local_gpu_id)
        for local_gpu_id in range(len(jax.local_devices()))
    )


def get_all_device_compute_capability():
    """
    Returns a tuple of the compute capabilities of all local devices.
    """
    return tuple(
        transformer_engine_jax.get_device_compute_capability(local_gpu_id)
        for local_gpu_id in range(len(jax.local_devices()))
    )
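
# The compute capability is reported as an integer (roughly major * 10 + minor), e.g. 90 for
# Hopper and 100+ for Blackwell, which is what the arch < 100 checks below rely on.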


def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None):
    """
    Fused dbias is not supported with 1x quantization on arch < 100, so a workaround
    calculates dbias separately. This function checks whether that workaround should be applied.
    """
    if quantizer is None:
        return False

    arch_l_100 = False
    for local_gpu_id in range(len(jax.local_devices())):
        if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100:
            arch_l_100 = True
            break
    # _quantize_dbias_impl forces 1x quantization for tensor scaling by switching q_layout to ROWWISE,
    # but this fails when bias fusion is turned on with arch < 100.
    force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
    return (
        (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
        and arch_l_100
        and is_dbias
    )
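
# Example: on a Hopper (arch 90 < 100) device with a 2x tensor-scaling quantizer and
# is_dbias=True, this returns True, so callers compute dbias in a separate step instead
# of relying on the fused-dbias kernel.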


def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, **kwargs):
    """
    Applies a workaround for 2x delayed scaling; it can be used while the TE common kernels do not yet support 2x delayed scaling.
    It calls the given function 'f' with the given arguments and the quantizer set to 1x, then derives the colwise output by transposing the rowwise result.

    If 'f' returns a tuple, the first output must be the only ScaledTensor output.

    @param f: function to call
    @param args: positional arguments to pass to 'f'
    @param quantizer: quantizer to use
    @param flatten_axis: axis at which the tensor is flattened to 2D for the transpose
    @param kwargs: keyword arguments to pass to 'f'
    @return: None if the workaround does not apply, otherwise the output of 'f' with the colwise output filled in
    """
    should_apply_war = (
        quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
    )
    if not should_apply_war:
        return None

    # 2x is not supported by TE kernels for delayed scaling
    # so revert to 1x and transpose in JAX
    quantizer.q_layout = QuantizeLayout.ROWWISE
    rowwise = f(*args, **kwargs, quantizer=quantizer)
    other_outputs = None
    if isinstance(rowwise, tuple):
        other_outputs = rowwise[1:]
        rowwise = rowwise[0]
    quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE
    if flatten_axis < 0:
        flatten_axis += rowwise.data.ndim
    assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds"
    colwise_data = jnp.transpose(
        rowwise.data, (*range(flatten_axis, rowwise.data.ndim), *range(flatten_axis))
    )
    output_2x = ScaledTensorFactory.create(
        data=rowwise.data,
        scale_inv=rowwise.scale_inv,
        colwise_data=colwise_data,
        colwise_scale_inv=rowwise.scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=rowwise.dq_dtype,
        q_layout=QuantizeLayout.ROWWISE_COLWISE,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    if other_outputs is not None:
        return (output_2x,) + other_outputs
    return output_2x
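
# Calling pattern sketch (hypothetical caller; _quantize_impl and x are placeholders):
# a custom-op wrapper can try the workaround first and only fall back to its normal 2x
# path when it returns None, e.g.
#   out = try_apply_delayed_scaling_2x_war(_quantize_impl, x, quantizer=quantizer)
#   if out is None:
#       out = _quantize_impl(x, quantizer=quantizer)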


class NamedSharding(jax.sharding.NamedSharding):
    """
    Wrapper around jax.sharding.NamedSharding that adds a string description field as metadata for easier debugging.
    """

    def __init__(self, *args, desc: str = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.desc = desc

    def __repr__(self):
        return f"NamedSharding({self.mesh}, {self.spec}, desc={self.desc})"

    def duplicate_with_new_description(self, desc: str):
        """
        Create a new NamedSharding with the same mesh and spec but with a new description.
        """
        return NamedSharding(self.mesh, self.spec, desc=desc)
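
# Usage sketch (mesh and axis names are illustrative): NamedSharding(mesh,
# jax.sharding.PartitionSpec("data", None), desc="x_rowwise") behaves like
# jax.sharding.NamedSharding but carries the "x_rowwise" tag for debugging.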


@functools.lru_cache(maxsize=1)
def is_all_reduce_in_float32():
    """
    Check whether all-reduce should be performed in float32 (enabled by setting NVTE_JAX_ALL_REDUCE_IN_FP32=1)
    """
    return os.getenv("NVTE_JAX_ALL_REDUCE_IN_FP32", "0") == "1"