# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""

from typing import List, Tuple, Sequence
from functools import partial

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .fp8 import FP8Helper, FP8MetaPackage


Precision = jax.lax.Precision

def type_safe_dot_general(
    x,
    kernel,
    fp8_meta_pkg: "FP8MetaPackage | None" = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
) -> jnp.ndarray:
    """
    Type safe dot_general, including FP8.

    Args:
        x: left-hand-side operand.
        kernel: right-hand-side operand.
        fp8_meta_pkg: FP8 metadata (amax/scale lists). When ``None``, a plain
            ``jax.lax.dot_general`` in the operands' (matching) dtype is used;
            otherwise the custom-VJP FP8 dot is dispatched.
        contracting_dims: ``(lhs_dims, rhs_dims)`` contracting dimensions, in
            the same format as ``jax.lax.dot_general`` dimension numbers.

    Returns:
        The contraction result as a jnp.ndarray.
    """
    if fp8_meta_pkg is None:
        # Non-FP8 path: insist the dtypes agree so no silent up-cast happens.
        assert x.dtype == kernel.dtype, f"lhs dtype = {x.dtype}, rhs dtype = {kernel.dtype}"
        return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ())))

    # FP8 path: unpack quantization metadata and use the custom-VJP FP8 dot.
    amax_list = fp8_meta_pkg.amax_list
    scale_list = fp8_meta_pkg.scale_list
    fwd_dtype = FP8Helper.FWD_DTYPE
    bwd_dtype = FP8Helper.BWD_DTYPE
    return _fp8_dot(x, kernel, amax_list, scale_list, fwd_dtype, bwd_dtype, contracting_dims)


def quantize(x, q_dtype, scale):
    """
    Quantize with scale.

    Scales ``x`` by ``scale``, clips the result into the representable range
    of ``q_dtype`` and casts it. Also returns the freshly observed amax,
    i.e. ``max(|x|)`` of the *unscaled* input, in ``scale``'s dtype.
    """
    new_amax = jnp.abs(x).max().astype(scale.dtype)
    limit = jnp.finfo(q_dtype).max.astype(x.dtype)
    scaled = x * scale.astype(x.dtype)
    quantized = jnp.clip(scaled, -limit, limit).astype(q_dtype)
    return quantized, new_amax


49
50
51
52
53
54
55
56
def dequantize(x, dq_dtype, scale_inv):
    """
    Dequantize with scale_inv.

    Casts both the quantized tensor and the inverse scale to ``dq_dtype``
    first, so the multiply happens in the target precision.
    """
    upcast = x.astype(dq_dtype)
    inv = scale_inv.astype(dq_dtype)
    return upcast * inv


# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(4, 5, 6))
def fp8_dot_impl(
    q_lhs: jnp.ndarray,
    q_rhs: jnp.ndarray,
    lhs_scale_inv: jnp.ndarray,
    rhs_scale_inv: jnp.ndarray,
    ctype: jnp.dtype,  # computing type
    contracting_dims: Tuple[Sequence[int], Sequence[int]],
    precision: Precision = None,
):
    """
    FP8 GEMM for XLA pattern match
    """
    # Up-cast each quantized operand back to the compute dtype via its inverse
    # scale, then contract. The dequantize-then-dot sequence is the shape XLA's
    # FP8 GEMM pattern matching expects (see module comments above).
    lhs = q_lhs.astype(ctype) * lhs_scale_inv.astype(ctype)
    rhs = q_rhs.astype(ctype) * rhs_scale_inv.astype(ctype)
    dim_nums = (contracting_dims, ((), ()))
    return jax.lax.dot_general(lhs, rhs, dim_nums, precision=precision)


def get_precision_of_fp8_dot(enable_2xACC: bool):
    """
    Get Precision of FP8 DOT.

    Maps the 2x-accumulation flag to a ``jax.lax.Precision``: HIGHEST when
    enabled, DEFAULT otherwise.
    """
    if enable_2xACC:
        return jax.lax.Precision.HIGHEST
    return jax.lax.Precision.DEFAULT


85
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6))
def _fp8_dot(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    amax_list: List[jnp.ndarray],
    scale_list: List[jnp.ndarray],
    fwd_dtype: jnp.dtype,
    bwd_dtype: jnp.dtype,
    contracting_dims: Tuple[Sequence[int], Sequence[int]],
):
    """
    FP8 dot with a custom VJP.

    The primal evaluation simply reuses the forward rule and discards the
    residuals it saves for the backward pass.
    """
    primal_out, _residuals = _fp8_dot_fwd_rule(
        x, kernel, amax_list, scale_list, fwd_dtype, bwd_dtype, contracting_dims
    )
    return primal_out


def _fp8_dot_fwd_rule(
    x,
    kernel,
    amax_list,
    scale_list,
    fwd_dtype,
    bwd_dtype,  # pylint: disable=unused-argument
    contracting_dims,
):
    """
    Forward rule of the FP8 dot custom VJP.

    Updates the FP8 scales from the amax history, quantizes both operands to
    ``fwd_dtype``, runs the pattern-matched FP8 GEMM, and packs everything
    the backward rule needs into the context tuple.
    """
    # FP8 meta tensors may arrive in the fm32 wrapper dtype; normalize them to
    # fp32 for arithmetic and keep the inverse converter for the backward pass.
    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair(
        *amax_list, *scale_list
    )
    amax_list = maybe_fm32_to_fp32(*amax_list)
    scale_list = maybe_fm32_to_fp32(*scale_list)

    lhs_contracting_dims, rhs_contracting_dims = contracting_dims

    # The trailing dims of x being contracted must line up exactly with the
    # leading dims of the kernel.
    lhs_tail_shape = x.shape[min(lhs_contracting_dims) :]
    rhs_head_shape = kernel.shape[: max(rhs_contracting_dims) + 1]
    assert lhs_tail_shape == rhs_head_shape

    # One FP8 dtype per meta slot: input, weight, grad (in that order).
    meta_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
    scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(
        amax_list, scale_list, meta_dtype_list
    )
    amax_list = FP8MetaPackage.update_amax_list(amax_list)

    input_scale = scale_list[FP8MetaPackage.INPUT_IDX]
    input_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
    # Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_x, updated_x_amax = quantize(x, fwd_dtype, input_scale)

    weight_scale = scale_list[FP8MetaPackage.WEIGHT_IDX]
    weight_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
    # Same native-cast reasoning as for x above.
    casted_kernel, updated_kernel_amax = quantize(kernel, fwd_dtype, weight_scale)

    output = fp8_dot_impl(
        casted_x,
        casted_kernel,
        input_scale_inv,
        weight_scale_inv,
        x.dtype,
        (lhs_contracting_dims, rhs_contracting_dims),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP),
    )

    # Residuals consumed positionally by _fp8_dot_bwd_rule — keep this order.
    ctx = (
        casted_x,
        casted_kernel,
        amax_list,
        scale_list,
        scale_inv_list,
        updated_x_amax,
        updated_kernel_amax,
        x.shape,
        kernel.shape,
        maybe_fp32_to_fm32,
    )
    return output, ctx


166
167
168
def _fp8_dot_bwd_rule(
    fwd_dtype, bwd_dtype, contracting_dims, ctx, grad
):  # pylint: disable=unused-argument
    """
    Backward rule of the FP8 dot custom VJP.

    Quantizes the incoming cotangent to ``bwd_dtype`` (together with its
    transpose), computes the weight and input gradients as FP8 GEMMs, folds
    the newly observed amaxes into the amax histories, and returns
    cotangents for (x, kernel, amax_list, scale_list).
    """
    lhs_contracting_dims, rhs_contracting_dims = contracting_dims

    # Unpack the residuals in the exact order the forward rule packed them.
    (
        casted_x,
        casted_kernel,
        amax_list,
        scale_list,
        scale_inv_list,
        updated_x_amax,
        updated_kernel_amax,
        x_shape,
        kernel_shape,
        maybe_fp32_to_fm32,
    ) = ctx

    grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1]
    grad_scale = scale_list[FP8MetaPackage.GRAD_IDX]
    grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]

    # Produce both the casted cotangent and its transpose in one fused op;
    # the transpose boundary matches the lhs contraction split.
    casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose(
        grad,
        grad_amax,
        grad_scale,
        grad_scale_inv,
        bwd_dtype,
        static_axis_boundary=-1,
        transpose_axis_boundary=min(lhs_contracting_dims),
    )

    # wgrad: contract x's batch/leading dims against the matching trailing
    # dims of the transposed cotangent.
    x_contracting_dim = tuple(range(0, len(x_shape) - len(lhs_contracting_dims)))
    gt_contracting_dim = tuple(range(grad.ndim - len(x_contracting_dim), grad.ndim))
    x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
    wgrad = fp8_dot_impl(
        casted_x,
        casted_grad_t,
        x_scale_inv,
        grad_scale_inv,
        grad.dtype,
        (x_contracting_dim, gt_contracting_dim),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD),
    )

    # dgrad: contract the cotangent's trailing dims against the kernel's
    # non-contracted (output) dims.
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim)
    )
    k_contracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
    kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
    dgrad = fp8_dot_impl(
        casted_grad,
        casted_kernel,
        grad_scale_inv,
        kernel_scale_inv,
        grad.dtype,
        (g_contracting_dim, k_contracting_dim),
        get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD),
    )

    # Fold the freshly observed amaxes into slot 0 of each running history.
    amax_list[FP8MetaPackage.INPUT_IDX] = (
        amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax)
    )
    amax_list[FP8MetaPackage.WEIGHT_IDX] = (
        amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax)
    )
    amax_list[FP8MetaPackage.GRAD_IDX] = (
        amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0])
    )

    # Convert the metadata back to the wrapper dtype it arrived in.
    amax_list = maybe_fp32_to_fm32(*amax_list)
    scale_list = maybe_fp32_to_fm32(*scale_list)

    return dgrad, wgrad, amax_list, scale_list


242
# Register the forward/backward rules with the custom-VJP wrapper above.
_fp8_dot.defvjp(_fp8_dot_fwd_rule, _fp8_dot_bwd_rule)