# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""

from typing import Tuple, Sequence
7
from functools import partial
8
9
10
import jax
import jax.numpy as jnp

11
12
from .cpp_extensions import cast_transpose
from .fp8 import FP8Helper, FP8MetaPackage
13

14
15
# Short alias for JAX's matmul precision enum, used in signatures below.
Precision = jax.lax.Precision
def type_safe_dot_general(
    x,
    kernel,
    fp8_meta_pkg: FP8MetaPackage = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,))
) -> jnp.ndarray:
    """
    Type safe dot_general, including FP8.
    """
    # Without an FP8 meta package, fall back to a plain dot_general with the
    # kernel cast to x's dtype.
    if fp8_meta_pkg is None:
        return jax.lax.dot_general(x, jnp.asarray(kernel, x.dtype),
                                   (contracting_dims, ((), ())))

    # FP8 path: unpack the meta tensors and dispatch to the custom-VJP dot.
    return _fp8_dot(x, kernel,
                    fp8_meta_pkg.fp8_max,
                    fp8_meta_pkg.amax,
                    fp8_meta_pkg.scale,
                    fp8_meta_pkg.scale_inv,
                    FP8Helper.FWD_DTYPE,
                    FP8Helper.BWD_DTYPE,
                    contracting_dims)


def quantize(x, q_dtype, scale):
    """
    Quantize with scale.
    """
    # Fresh amax observation for this tensor, reported in the scale's dtype.
    updated_amax = jnp.abs(x).max().astype(scale.dtype)

    # Scale in the input dtype, then clip to the target dtype's finite range
    # before the narrowing cast.
    finite_max = jnp.asarray(jnp.finfo(q_dtype).max, dtype=x.dtype)
    scaled = x * scale.astype(x.dtype)
    clipped = jnp.clip(scaled, -finite_max, finite_max)
    return clipped.astype(q_dtype), updated_amax
51


52
53
54
55
56
57
58
59
def dequantize(x, dq_dtype, scale_inv):
    """
    Dequantize with scale_inv.
    """
    # Undo the quantization scaling, computing in the requested output dtype.
    return scale_inv.astype(dq_dtype) * x.astype(dq_dtype)


# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(4, 5, 6))
def fp8_dot_impl(
        q_lhs: jnp.ndarray,
        q_rhs: jnp.ndarray,
        lhs_scale_inv: jnp.ndarray,
        rhs_scale_inv: jnp.ndarray,
        ctype: jnp.dtype,    # computing type
        contracting_dims: Tuple[Sequence[int], Sequence[int]],
        precision: Precision = None):
    """
    FP8 GEMM for XLA pattern match
    """
    # Dequantize both operands into the computing dtype (cast then multiply by
    # the inverse scale), then contract along the requested dimensions with no
    # batch dimensions.
    lhs = q_lhs.astype(ctype) * lhs_scale_inv.astype(ctype)
    rhs = q_rhs.astype(ctype) * rhs_scale_inv.astype(ctype)
    return jax.lax.dot_general(lhs, rhs, (contracting_dims, ((), ())),
                               precision=precision)


def get_precision_of_fp8_dot(enable_2xACC: bool):
    """
    Get Precision of FP8 DOT.
    """
    # 2x accumulation requests the highest-precision accumulation path.
    if enable_2xACC:
        return jax.lax.Precision.HIGHEST
    return jax.lax.Precision.DEFAULT


@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
             contracting_dims: Tuple[Sequence[int], Sequence[int]]):
    """FP8 dot with a custom VJP; the primal path reuses the forward rule."""
    primal_out, _ = _fp8_dot_fwd_rule(x, kernel, fp8_max, amax, scale, scale_inv,
                                      fwd_dtype, bwd_dtype, contracting_dims)
    return primal_out


def _fp8_dot_fwd_rule(
        x,
        kernel,
        fp8_max,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        contracting_dims):
    """Forward rule: quantize both operands to FP8 and run the FP8 GEMM.

    Returns the GEMM output and the residual ctx consumed by the backward rule.
    """
    # Move FP8 meta tensors from fm32 to fp32 for arithmetic; the inverse
    # converter is saved in ctx so the backward rule can restore the dtype.
    maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \
        FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv)
    fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv)

    lhs_contracting_dims, rhs_contracting_dims = contracting_dims

    # The contracted trailing dims of x must match the leading dims of kernel.
    assert x.shape[min(lhs_contracting_dims):] == kernel.shape[:max(rhs_contracting_dims) + 1]

    # Refresh scaling factors from the amax history, then roll the history.
    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)
    amax = FP8Helper.update_amax_history(amax)

    gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    # Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_x, updated_x_amax = quantize(x, fwd_dtype, scale[gemm_x_idx])

    # Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
    # unnecessary copy to break FP8 GEMM pattern matching.
    casted_kernel, updated_kernel_amax = quantize(kernel, fwd_dtype, scale[gemm_kernel_idx])

    output = fp8_dot_impl(casted_x, casted_kernel,
                          scale_inv[gemm_x_idx], scale_inv[gemm_kernel_idx],
                          x.dtype,
                          (lhs_contracting_dims, rhs_contracting_dims),
                          get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP))

    # Residuals for the backward rule (order must match its unpacking).
    ctx = (casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
           updated_kernel_amax, x.shape, kernel.shape, maybe_fp32_to_fm32)
    return output, ctx


143
144
145
def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad):    # pylint: disable=unused-argument
    """Backward rule: FP8 GEMMs for wgrad and dgrad, plus amax bookkeeping."""
    lhs_contracting_dims, rhs_contracting_dims = contracting_dims

    casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, \
        updated_x_amax, updated_kernel_amax, x_shape, kernel_shape, \
        maybe_fp32_to_fm32 = ctx

    gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    grad_amax = amax[gemm_grad_idx, 0:1]
    grad_scale = scale[gemm_grad_idx]
    grad_scale_inv = scale_inv[gemm_grad_idx]

    # Cast the incoming gradient to FP8, producing both the natural and the
    # transposed layout in one fused call.
    casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv,
                       bwd_dtype, static_axis_boundary=-1,
                       transpose_axis_boundary=min(lhs_contracting_dims))

    # wgrad: contract the non-contracted (batch-like) dims of x against the
    # matching trailing dims of the transposed gradient.
    x_constracting_dim = tuple(range(0, len(x_shape) - len(lhs_contracting_dims)))
    gt_constracting_dim = tuple(range(grad.ndim - len(x_constracting_dim), grad.ndim))
    wgrad = fp8_dot_impl(casted_x, casted_grad_t, scale_inv[gemm_x_idx], grad_scale_inv,
                         grad.dtype, (x_constracting_dim, gt_constracting_dim),
                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD))

    # dgrad: contract the gradient against the kernel's output dims.
    g_constracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim))
    k_constracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
    dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, scale_inv[gemm_kernel_idx],
                         grad.dtype, (g_constracting_dim, k_constracting_dim),
                         get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD))

    # Record the fresh amax observations in slot 0 of each history row.
    amax = amax.at[gemm_x_idx, 0].set(updated_x_amax)
    amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax)
    amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])

    # Convert FP8 meta tensors back to their original (possibly fm32) dtype.
    fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv)

    return dgrad, wgrad, fp8_max, amax, scale, scale_inv


185
# Wire the forward/backward rules into the custom-VJP primitive defined above.
_fp8_dot.defvjp(_fp8_dot_fwd_rule, _fp8_dot_bwd_rule)