# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""

from typing import Tuple, Sequence
7
from functools import partial
8
9
10
import jax
import jax.numpy as jnp

11
12
from .cpp_extensions import cast_transpose
from .fp8 import FP8Helper, FP8MetaPackage
13

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

def type_safe_dot_general(
    x,
    kernel,
    fp8_meta_pkg: FP8MetaPackage = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,))
) -> jnp.ndarray:
    """
    Dot-general entry point that dispatches between the plain and the FP8 path.

    When `fp8_meta_pkg` is None, `kernel` is converted to the dtype of `x` and a regular
    `jax.lax.dot_general` without batch dimensions is performed. Otherwise the FP8 metadata
    (fp8_max, amax, scale, scale_inv) is unpacked from the package and forwarded, together
    with the forward/backward FP8 dtypes from `FP8Helper`, to the custom-VJP FP8 dot.
    """
    use_fp8 = fp8_meta_pkg is not None
    if use_fp8:
        return _fp8_dot(
            x,
            kernel,
            fp8_meta_pkg.fp8_max,
            fp8_meta_pkg.amax,
            fp8_meta_pkg.scale,
            fp8_meta_pkg.scale_inv,
            FP8Helper.FWD_DTYPE,
            FP8Helper.BWD_DTYPE,
            contracting_dims,
        )

    # Non-FP8 path: align the kernel dtype with the input before the GEMM.
    no_batch_dims = ((), ())
    same_dtype_kernel = jnp.asarray(kernel, x.dtype)
    return jax.lax.dot_general(x, same_dtype_kernel, (contracting_dims, no_batch_dims))


def quantize(x, q_dtype, scale):
    """
    Scale `x`, saturate it to the finite range of `q_dtype`, and cast it to `q_dtype`.
    """
    compute_dtype = x.dtype
    # Largest finite value of the target dtype, expressed in the input dtype.
    limit = jnp.finfo(q_dtype).max.astype(compute_dtype)
    scaled = x * scale.astype(compute_dtype)
    saturated = jnp.clip(scaled, -limit, limit)
    return saturated.astype(q_dtype)
47
48


49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def dequantize(x, dq_dtype, scale_inv):
    """
    Cast `x` to `dq_dtype` and multiply it by `scale_inv` (also cast to `dq_dtype`).
    """
    widened = x.astype(dq_dtype)
    inv_scale = scale_inv.astype(dq_dtype)
    return widened * inv_scale


# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(4, 5))
def fp8_dot_impl(
        q_lhs: jnp.ndarray,
        q_rhs: jnp.ndarray,
        lhs_scale_inv: jnp.ndarray,
        rhs_scale_inv: jnp.ndarray,
        ctype: jnp.dtype,    # computing type
        contracting_dims: Tuple[Sequence[int], Sequence[int]]):
    """
    FP8 GEMM for XLA pattern match

    Both quantized operands are dequantized to `ctype` with their respective inverse
    scales and then contracted with `jax.lax.dot_general` using no batch dimensions.

    `ctype` and `contracting_dims` are static jit arguments (positions 4 and 5), so they
    must be hashable, e.g. `contracting_dims` must be a tuple of tuples rather than lists.

    Returns the result of the dot_general over the dequantized operands.
    """
    # Only contracting dimensions are used; the batch dimensions are left empty.
    dim_nums = (contracting_dims, ((), ()))

    lhs = dequantize(q_lhs, ctype, lhs_scale_inv)
    rhs = dequantize(q_rhs, ctype, rhs_scale_inv)

    return jax.lax.dot_general(lhs, rhs, dim_nums)


@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8))
def _fp8_dot(x: jnp.ndarray, kernel: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray,
             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype,
             contracting_dims: Tuple[Sequence[int], Sequence[int]]):
    """
    Primal function of the FP8 dot custom VJP.

    Runs the forward rule and returns only its output, dropping the residuals that the
    forward rule produces for the backward pass. `fwd_dtype`, `bwd_dtype` and
    `contracting_dims` (positions 6, 7, 8) are non-differentiable arguments.
    """
    primal_out = _fp8_dot_fwd_rule(
        x, kernel, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype, contracting_dims)[0]
    return primal_out


def _fp8_dot_fwd_rule(
        x,
        kernel,
        fp8_max,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        contracting_dims):
    """
    Forward rule of the FP8 dot custom VJP.

    Casts `x` and `kernel` to `fwd_dtype` via `cast_transpose`, runs the FP8 GEMM on the
    casted input and the casted-transposed kernel, and collects the residuals needed by
    the backward rule.

    Returns:
        A pair `(output, ctx)`. `output` is the GEMM result computed in `x.dtype`. `ctx` is
        the tuple `(casted_xt, casted_kernel, fp8_max, amax, scale, scale_inv,
        updated_x_amax, updated_kernel_amax, x.shape, kernel.shape)`, where `amax` is the
        value returned by `FP8Helper.update_amax_history`.

    Raises:
        AssertionError: if the trailing dims of `x` (from the first contracting dim) do not
            match the leading dims of `kernel` (up to the last contracting dim).
    """
    lhs_contracting_dims, rhs_contracting_dims = contracting_dims
    # The contracting dims are forwarded to fp8_dot_impl as a static jit argument, which
    # must be hashable; normalize so that list inputs are accepted as well.
    lhs_contracting_dims = tuple(lhs_contracting_dims)
    rhs_contracting_dims = tuple(rhs_contracting_dims)

    x_shape_suf = x.shape[min(lhs_contracting_dims):]
    kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
    assert x_shape_suf == kernel_shape_pre, \
        f"Contracting shapes mismatch: x suffix {x_shape_suf} vs kernel prefix {kernel_shape_pre}"

    amax = FP8Helper.update_amax_history(amax)

    # Rows of the FP8 metadata that belong to the input and the kernel of GEMM 0.
    gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    x_amax = amax[gemm_x_idx, 0:1]
    x_scale = scale[gemm_x_idx]
    x_scale_inv = scale_inv[gemm_x_idx]

    # NOTE(review): cast_transpose is unpacked here as (casted, casted-and-transposed,
    # updated amax); confirm against cpp_extensions.
    casted_x, casted_xt, updated_x_amax = \
        cast_transpose(x, x_amax, x_scale, x_scale_inv, fwd_dtype, static_axis_boundary=-1,
                       transpose_axis_boundary=min(lhs_contracting_dims))

    kernel_amax = amax[gemm_kernel_idx, 0:1]
    kernel_scale = scale[gemm_kernel_idx]
    kernel_scale_inv = scale_inv[gemm_kernel_idx]

    casted_kernel, casted_kernel_t, updated_kernel_amax = \
        cast_transpose(kernel, kernel_amax, kernel_scale, kernel_scale_inv,
                       fwd_dtype, static_axis_boundary=-1,
                       transpose_axis_boundary=(max(rhs_contracting_dims) + 1))

    # The GEMM consumes the transposed kernel, so its contracting dims are the trailing
    # len(rhs_contracting_dims) axes.
    rhs_t_contracting_dims = tuple(range(kernel.ndim - len(rhs_contracting_dims), kernel.ndim))
    output = fp8_dot_impl(casted_x, casted_kernel_t, x_scale_inv, kernel_scale_inv, x.dtype,
                          (lhs_contracting_dims, rhs_t_contracting_dims))

    ctx = (casted_xt, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax,
           updated_kernel_amax, x.shape, kernel.shape)
    return output, ctx
129
130


131
132
133
134
135
136
137
def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad):    # pylint: disable=unused-argument
    """
    Backward rule of the FP8 dot custom VJP.

    Casts `grad` to `bwd_dtype` via `cast_transpose`, computes the kernel gradient (wgrad)
    and the input gradient (dgrad) with FP8 GEMMs in `grad.dtype`, writes the amax values
    observed for x, kernel and grad into column 0 of `amax`, and recomputes `scale` and
    `scale_inv` with `FP8Helper.update_fp8_scale`.

    Args:
        fwd_dtype, bwd_dtype, contracting_dims: the non-differentiable arguments of
            `_fp8_dot`; `fwd_dtype` is unused here.
        ctx: residuals produced by `_fp8_dot_fwd_rule`.
        grad: incoming gradient with respect to the GEMM output.

    Returns:
        `(dgrad, wgrad, fp8_max, amax, scale, scale_inv)`, one entry per differentiable
        input of `_fp8_dot`; the last three carry the updated FP8 metadata.
    """
    lhs_contracting_dims, rhs_contracting_dims = contracting_dims

    casted_xt, casted_kernel, fp8_max, amax, scale, scale_inv, \
        updated_x_amax, updated_kernel_amax, x_shape, kernel_shape = ctx

    gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    grad_amax = amax[gemm_grad_idx, 0:1]
    grad_scale = scale[gemm_grad_idx]
    grad_scale_inv = scale_inv[gemm_grad_idx]

    # NOTE(review): cast_transpose is unpacked here as (casted, casted-and-transposed,
    # updated amax); confirm against cpp_extensions.
    casted_grad, casted_grad_t, updated_grad_amax = \
        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv,
                       bwd_dtype, static_axis_boundary=-1,
                       transpose_axis_boundary=min(lhs_contracting_dims))

    # wgrad: contract the transposed x with the transposed grad over their trailing axes.
    xt_contracting_dim = tuple(range(len(lhs_contracting_dims), len(x_shape)))
    gt_contracting_dim = tuple(range(grad.ndim - len(xt_contracting_dim), grad.ndim))
    x_scale_inv = scale_inv[gemm_x_idx]
    wgrad = fp8_dot_impl(casted_xt, casted_grad_t, x_scale_inv, grad_scale_inv, grad.dtype,
                         (xt_contracting_dim, gt_contracting_dim))

    # dgrad: contract the trailing axes of grad with the non-contracting axes of the kernel.
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(rhs_contracting_dims), grad.ndim))
    k_contracting_dim = tuple(range(len(rhs_contracting_dims), len(kernel_shape)))
    kernel_scale_inv = scale_inv[gemm_kernel_idx]
    dgrad = fp8_dot_impl(casted_grad, casted_kernel, grad_scale_inv, kernel_scale_inv, grad.dtype,
                         (g_contracting_dim, k_contracting_dim))

    # Record the amax observed in this step for each of the three GEMM operands.
    amax = amax.at[gemm_x_idx, 0].set(updated_x_amax[0])
    amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0])
    amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0])

    scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale)

    return dgrad, wgrad, fp8_max, amax, scale, scale_inv
168
169


170
# Register the forward and backward rules as the custom VJP of _fp8_dot.
_fp8_dot.defvjp(_fp8_dot_fwd_rule, _fp8_dot_bwd_rule)