# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""

from functools import partial
from typing import List, Tuple, Sequence

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .fp8 import FP8Helper, FP8MetaPackage

# Short alias used in the annotations of the FP8 GEMM helpers in this module.
Precision = jax.lax.Precision

def type_safe_dot_general(
    x,
    kernel,
    fp8_meta_pkg: FP8MetaPackage = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,))
) -> jnp.ndarray:
    """
    Type safe dot_general, including FP8.

    Args:
        x: left-hand operand.
        kernel: right-hand operand.
        fp8_meta_pkg: optional FP8 meta package. When ``None`` the plain
            (non-FP8) path is taken.
        contracting_dims: ``(lhs_dims, rhs_dims)`` to contract; there are
            never any batch dimensions.

    Returns:
        The contraction result.

    Non-FP8 path: ``kernel`` is cast to ``x.dtype`` and passed to
    ``jax.lax.dot_general`` together with ``x``.

    FP8 path: the package's ``amax_list`` / ``scale_list`` and
    ``FP8Helper.FWD_DTYPE`` / ``FP8Helper.BWD_DTYPE`` are forwarded to the
    custom-VJP ``_fp8_dot``.
    """
    if fp8_meta_pkg is not None:
        return _fp8_dot(x, kernel, fp8_meta_pkg.amax_list, fp8_meta_pkg.scale_list,
                        FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE, contracting_dims)

    no_batch_dims = ((), ())
    kernel_in_x_dtype = jnp.asarray(kernel, x.dtype)
    return jax.lax.dot_general(x, kernel_in_x_dtype, (contracting_dims, no_batch_dims))


def quantize(x, q_dtype, scale):
    """
    Quantize ``x`` to ``q_dtype`` with ``scale``.

    ``x`` is multiplied by ``scale`` (cast to ``x.dtype``), then clipped to
    ``[-finfo(q_dtype).max, finfo(q_dtype).max]``, then cast to ``q_dtype``.

    Returns:
        ``(quantized_x, amax)``, where ``amax`` is ``max(|x|)`` of the
        *unscaled* input, cast to ``scale.dtype``.
    """
    # Measure the amax on the raw input, in the scale's original dtype.
    amax = jnp.max(jnp.abs(x)).astype(scale.dtype)

    bound = jnp.finfo(q_dtype).max.astype(x.dtype)
    scaled = x * scale.astype(x.dtype)
    # Saturate instead of overflowing when casting to the narrow q_dtype.
    saturated = jnp.clip(scaled, -bound, bound)
    return saturated.astype(q_dtype), amax


def dequantize(x, dq_dtype, scale_inv):
    """
    Dequantize ``x`` with ``scale_inv``.

    Both ``x`` and ``scale_inv`` are cast to ``dq_dtype`` before they are
    multiplied, so the result is in ``dq_dtype``.
    """
    widened = x.astype(dq_dtype)
    return widened * scale_inv.astype(dq_dtype)


# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(4, 5, 6))
def fp8_dot_impl(
        q_lhs: jnp.ndarray,
        q_rhs: jnp.ndarray,
        lhs_scale_inv: jnp.ndarray,
        rhs_scale_inv: jnp.ndarray,
        ctype: jnp.dtype,    # computing type
        contracting_dims: Tuple[Sequence[int], Sequence[int]],
        precision: Precision = None):
    """
    FP8 GEMM for XLA pattern match.

    Each quantized operand is dequantized to ``ctype`` (cast, then multiplied
    by its inverse scale). The two results are then contracted with
    ``jax.lax.dot_general`` over ``contracting_dims``, with no batch dims,
    at the given ``precision``. The dequantize-then-dot sequence is kept
    inside one jitted function, per the comment above.

    ``ctype``, ``contracting_dims`` and ``precision`` are static jit
    arguments (``static_argnums=(4, 5, 6)``), so they must be hashable,
    e.g. tuples rather than lists.
    """
    lhs_hp = dequantize(q_lhs, ctype, lhs_scale_inv)
    rhs_hp = dequantize(q_rhs, ctype, rhs_scale_inv)
    return jax.lax.dot_general(lhs_hp, rhs_hp, (contracting_dims, ((), ())),
                               precision=precision)


def get_precision_of_fp8_dot(enable_2xACC: bool):
    """
    Get Precision of FP8 DOT.

    Returns ``jax.lax.Precision.HIGHEST`` when ``enable_2xACC`` is truthy,
    otherwise ``jax.lax.Precision.DEFAULT``.
    """
    if enable_2xACC:
        return jax.lax.Precision.HIGHEST
    return jax.lax.Precision.DEFAULT


@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6))
def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, amax_list: List[jnp.ndarray],
             scale_list: List[jnp.ndarray], fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
             contracting_dims: Tuple[Sequence[int], Sequence[int]]):
    """
    FP8 GEMM with a custom VJP (primal function).

    ``fwd_dtype``, ``bwd_dtype`` and ``contracting_dims`` (positions 4-6) are
    declared non-differentiable. The primal value is the output half of the
    forward rule's ``(output, residuals)`` pair; the residuals are discarded
    here.
    """
    primal_out, _residuals = _fp8_dot_fwd_rule(x, kernel, amax_list, scale_list,
                                               fwd_dtype, bwd_dtype, contracting_dims)
    return primal_out


def _fp8_dot_fwd_rule(
        x,
        kernel,
        amax_list,
        scale_list,
        fwd_dtype,
        bwd_dtype,
        contracting_dims):
    """
    Forward rule of ``_fp8_dot``'s custom VJP.

    Steps:
      1. Convert the amax/scale entries with the converter pair returned by
         ``FP8Helper.generate_fp8_meta_dtype_converter_pair``. The inverse
         converter is stored in the residuals so the backward rule can
         convert back.
      2. Check that ``x``'s trailing dims, starting at
         ``min(lhs_contracting_dims)``, equal ``kernel``'s leading dims up to
         ``max(rhs_contracting_dims)``. The backward rule's dimension
         bookkeeping assumes exactly this layout.
      3. Refresh the scales and inverse scales via
         ``FP8MetaPackage.update_fp8_scale``. The input and weight slots use
         ``fwd_dtype``; the grad slot uses ``bwd_dtype``. Refresh the amax
         list via ``FP8MetaPackage.update_amax_list``.
      4. Quantize ``x`` and ``kernel`` to ``fwd_dtype``, recording their
         amaxes. Run the FP8 GEMM in ``x.dtype`` with the FPROP precision
         setting.

    Returns:
        ``(output, ctx)``, where ``ctx`` is the residual tuple unpacked by
        ``_fp8_dot_bwd_rule``.

    Raises:
        AssertionError: if the contracting-dim shapes of ``x`` and ``kernel``
            do not line up as described in step 2.
    """
    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
        FP8Helper.generate_fp8_meta_dtype_converter_pair(*amax_list, *scale_list)
    amax_list = maybe_fm32_to_fp32(*amax_list)
    scale_list = maybe_fm32_to_fp32(*scale_list)

    lhs_contracting_dims, rhs_contracting_dims = contracting_dims

    x_shape_suf = x.shape[min(lhs_contracting_dims):]
    kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
    assert x_shape_suf == kernel_shape_pre, (
        "FP8 dot expects the trailing contracting dims of x to match the leading "
        f"contracting dims of kernel, got x.shape={x.shape}, kernel.shape={kernel.shape}, "
        f"contracting_dims={contracting_dims}")

    # Slot order: input, weight, grad. The grad slot is quantized with
    # bwd_dtype in the backward rule.
    fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype]
    scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale(amax_list, scale_list,
                                                                 fp8_dtype_list)
    amax_list = FP8MetaPackage.update_amax_list(amax_list)

    x_scale = scale_list[FP8MetaPackage.INPUT_IDX]
    x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX]
    # Note (Ming Huang): Use a native cast so XLA can handle the transpose itself,
    # avoiding an unnecessary copy that would break FP8 GEMM pattern matching.
    casted_x, updated_x_amax = quantize(x, fwd_dtype, x_scale)

    kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX]
    kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX]
    # Note (Ming Huang): Use a native cast so XLA can handle the transpose itself,
    # avoiding an unnecessary copy that would break FP8 GEMM pattern matching.
    casted_kernel, updated_kernel_amax = quantize(kernel, fwd_dtype, kernel_scale)

    # Rebuild contracting_dims as a tuple: fp8_dot_impl takes it as a static
    # jit argument.
    output = fp8_dot_impl(casted_x, casted_kernel, x_scale_inv, kernel_scale_inv, x.dtype,
                          (lhs_contracting_dims, rhs_contracting_dims),
                          get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    ctx = (casted_x, casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax,
           updated_kernel_amax, x.shape, kernel.shape, maybe_fp32_to_fm32)
    return output, ctx


def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad):    # pylint: disable=unused-argument
    """
    Backward rule of ``_fp8_dot``'s custom VJP.

    ``fwd_dtype``, ``bwd_dtype`` and ``contracting_dims`` are the
    non-differentiable arguments of ``_fp8_dot``; ``fwd_dtype`` is unused
    here. ``ctx`` holds the residuals produced by ``_fp8_dot_fwd_rule``, and
    ``grad`` is the cotangent of the output.

    The output cotangent is cast to ``bwd_dtype`` and transposed by
    ``tex.cast_transpose``. Two FP8 GEMMs then produce:

    * wgrad: contracts the leading (non-contracting) dims of ``x`` with the
      trailing dims of the transposed gradient;
    * dgrad: contracts the trailing dims of the gradient with the
      non-contracting dims of ``kernel``.

    Returns:
        ``(dgrad, wgrad, amax_list, scale_list)``, in the cotangent positions
        of ``x``, ``kernel``, ``amax_list`` and ``scale_list``.

        In ``amax_list``, index 0 of the input, weight and grad entries is
        overwritten with the amaxes observed in this step. Both lists are
        passed back through ``maybe_fp32_to_fm32``.

        NOTE(review): presumably the caller treats these returned "gradients"
        as the updated FP8 meta state; confirm against the training loop.
    """
    (casted_x, casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax,
     updated_kernel_amax, x_shape, kernel_shape, maybe_fp32_to_fm32) = ctx
    lhs_contracting_dims, rhs_contracting_dims = contracting_dims

    grad_idx = FP8MetaPackage.GRAD_IDX
    grad_amax = amax_list[grad_idx][0:1]
    grad_scale = scale_list[grad_idx]
    grad_scale_inv = scale_inv_list[grad_idx]

    casted_grad, casted_grad_t, updated_grad_amax = tex.cast_transpose(
        grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
        static_axis_boundary=-1, transpose_axis_boundary=min(lhs_contracting_dims))

    # wgrad: contract x's leading dims with the same number of trailing dims
    # of the transposed gradient.
    x_wgrad_dims = tuple(range(0, len(x_shape) - len(lhs_contracting_dims)))
    grad_t_wgrad_dims = tuple(range(grad.ndim - len(x_wgrad_dims), grad.ndim))
    wgrad = fp8_dot_impl(casted_x, casted_grad_t, scale_inv_list[FP8MetaPackage.INPUT_IDX],
                         grad_scale_inv, grad.dtype, (x_wgrad_dims, grad_t_wgrad_dims),
                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # dgrad: contract the gradient's trailing dims with kernel's
    # non-contracting dims.
    num_kernel_free_dims = len(kernel_shape) - len(rhs_contracting_dims)
    grad_dgrad_dims = tuple(range(grad.ndim - num_kernel_free_dims, grad.ndim))
    kernel_dgrad_dims = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
    dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv,
                         scale_inv_list[FP8MetaPackage.WEIGHT_IDX], grad.dtype,
                         (grad_dgrad_dims, kernel_dgrad_dims),
                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    # Record this step's amaxes at index 0, in input -> weight -> grad order.
    for meta_idx, new_amax in ((FP8MetaPackage.INPUT_IDX, updated_x_amax),
                               (FP8MetaPackage.WEIGHT_IDX, updated_kernel_amax),
                               (grad_idx, updated_grad_amax[0])):
        amax_list[meta_idx] = amax_list[meta_idx].at[0].set(new_amax)

    return (dgrad, wgrad, maybe_fp32_to_fm32(*amax_list), maybe_fp32_to_fm32(*scale_list))


# Register the forward/backward rules that implement _fp8_dot's custom VJP.
_fp8_dot.defvjp(fwd=_fp8_dot_fwd_rule, bwd=_fp8_dot_bwd_rule)