# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""

from typing import Tuple, Sequence
from functools import partial, reduce
import operator
import jax
import jax.numpy as jnp

from transformer_engine_jax import DType as TEDType
from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner, extend_fsdp_sharding_meta

# xmap_runner (used for the sharded path below) requires xmap's SPMD lowering.
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)


def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
            fwd_dtype: TEDType,
            bwd_dtype: TEDType,
            contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
            sharding_type: ShardingType = ShardingType.SINGLE,
            dp_dim_index: int = 0) -> jnp.ndarray:
    """
    FP8 dot wrapper
    """
    assert fp8_gemm_pkg.num_of_gemm == 1
    inputs = fp8_gemm_pkg.inputs
    kernel = fp8_gemm_pkg.kernels[0]
    fp8_max = fp8_gemm_pkg.fp8_max
    amax = fp8_gemm_pkg.amax
    scale = fp8_gemm_pkg.scale
    scale_inv = fp8_gemm_pkg.scale_inv

    if sharding_type is ShardingType.SINGLE:
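        # Single-device path: run the custom-VJP dot directly; no named axes are needed.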
        res = _fp8_dot(inputs,
                       kernel,
                       fp8_max,
                       amax,
                       scale,
                       scale_inv,
                       fwd_dtype=fwd_dtype,
                       bwd_dtype=bwd_dtype,
                       contracting_dims=contracting_dims,
                       sharding_type=sharding_type,
                       dp_axis_name="",
                       tp_axis_name="",
                       fsdp_axis_name="")
    else:
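        # Sharded path: reshape the operands to the shapes expected by the sharding
        # meta and run the custom-VJP dot under xmap, so the collectives in the
        # forward/backward rules see the named axes.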
        dp_axis_name = "batch"
        tp_axis_name = "model"
        kernel_tp_index = None
        # TODO (Ming Huang): Should we add a new argument to support general sharding to kernel? # pylint: disable=fixme
        if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
            kernel_tp_index = len(kernel.shape) - 1
        elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
            kernel_tp_index = 0

        input_tp_index = len(inputs.shape) - 1
        sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
                                              dp_dim_index, input_tp_index, kernel_tp_index,
                                              contracting_dims, dp_axis_name, tp_axis_name)
        sharding_meta, fsdp_axis_name = extend_fsdp_sharding_meta(sharding_meta, {0: dp_dim_index})
        inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0])    # 0 for input
        kernel_ = jnp.reshape(kernel, sharding_meta.input_shapes[1])    # 1 for kernel

        num_of_fp8_meta_kind = 4    # fp8_max, amax, scale, scale_inv
        fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind,
                                                       dp_axis_name, tp_axis_name)

        axis_resources = merge_axis_resources(
            [sharding_meta.axis_resources, fp8_sharding_meta.axis_resources])

        partial_fp8_dot = partial(_fp8_dot,
                                  fwd_dtype=fwd_dtype,
                                  bwd_dtype=bwd_dtype,
                                  contracting_dims=contracting_dims,
                                  sharding_type=sharding_type,
                                  dp_axis_name=dp_axis_name,
                                  tp_axis_name=tp_axis_name,
                                  fsdp_axis_name=fsdp_axis_name)
        res = xmap_runner(partial_fp8_dot, (*sharding_meta.in_axes, *fp8_sharding_meta.in_axes),
                          sharding_meta.out_axes, axis_resources,
                          (inputs_, kernel_, fp8_max, amax, scale, scale_inv))

        res = jnp.reshape(res, sharding_meta.output_shapes[0])

    return res


@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
             scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType,
             contracting_dims: Tuple[Sequence[int], Sequence[int]], sharding_type: ShardingType,
             dp_axis_name: str, tp_axis_name: str, fsdp_axis_name: str):
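    # The primal computation reuses the forward rule and discards the residuals;
    # the differentiation logic lives in _fp8_dot_fwd/_fp8_dot_bwd.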
    res, _ = _fp8_dot_fwd(inputs,
                          kernel,
                          fp8_maxs,
                          amax,
                          scale,
                          scale_inv,
                          fwd_dtype,
                          bwd_dtype,
                          contracting_dims=contracting_dims,
                          sharding_type=sharding_type,
                          dp_axis_name=dp_axis_name,
                          tp_axis_name=tp_axis_name,
                          fsdp_axis_name=fsdp_axis_name)
    return res


def _fp8_dot_fwd(
        inputs,
        kernel,
        fp8_maxs,
        amax,
        scale,
        scale_inv,
        fwd_dtype,
        bwd_dtype,    # pylint: disable=unused-argument
        contracting_dims,
        sharding_type,
        dp_axis_name,    # pylint: disable=unused-argument
        tp_axis_name,
        fsdp_axis_name):    # pylint: disable=unused-argument
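    # Collapse both operands to 2D so the contraction becomes a single GEMM: for
    # `inputs`, every axis from the first contracting dim onward is folded into the
    # contracting side; for `kernel`, every axis up to the last contracting dim is.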
    lhs_contracting_dims, rhs_contracting_dims = contracting_dims
    input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
    input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
    kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
    kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:]
    input_contracting_size = reduce(operator.mul, input_shape_suf)
    kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
    assert input_contracting_size == kernel_contracting_size
    inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
    kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))

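    # Update the rolling amax history before this step's amax values are measured.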
    amax = FP8Helper.update_amax_history(amax)

    gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)

    input_amax = amax[gemm_input_idx, 0:1]
    input_scale = scale[gemm_input_idx]
    input_scale_inv = scale_inv[gemm_input_idx]
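    # Cast the flattened input to FP8 and produce its transpose in one pass; the
    # transposed copy is kept for the wgrad GEMM in the backward rule.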
    input_cast, input_cast_trans, input_amax = cast_transpose(inputs_, input_amax, input_scale,
                                                              input_scale_inv, fwd_dtype)

    kernel_amax = amax[gemm_kernel_idx, 0:1]
    kernel_scale = scale[gemm_kernel_idx]
    kernel_scale_inv = scale_inv[gemm_kernel_idx]
    kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
                                                                 kernel_scale_inv, fwd_dtype)
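    # Forward FP8 GEMM of the casted input against the casted kernel; the output is
    # produced in the original input dtype.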
    res = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, input_cast, input_scale_inv,
               fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)

    if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
        res = jax.lax.psum(res, tp_axis_name)

    # (input_shape_pre, input_shape_suf)
    # x (kernel_shape_pre, kernel_shape_suf)
    # = (input_shape_pre, kernel_shape_suf)
    output_shape = input_shape_pre + kernel_shape_suf
    res = jnp.reshape(res, output_shape)

    ctx = (input_cast_trans, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax,
           inputs.shape, kernel.shape)
    return res, ctx


def _fp8_dot_bwd(
        fwd_dtype,
        bwd_dtype,
        contracting_dims,    # pylint: disable=unused-argument
        sharding_type,
        dp_axis_name,
        tp_axis_name,
        fsdp_axis_name,
        ctx,
        g):
    input_cast_trans, kernel_cast, \
    fp8_maxs, amax, scale, scale_inv, \
    input_amax, kernel_amax, \
    inputs_shape, kernel_shape = ctx

    gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)

    grad_amax = amax[gemm_grad_idx, 0:1]
    grad_scale = scale[gemm_grad_idx]
    grad_scale_inv = scale_inv[gemm_grad_idx]
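    # Flatten the incoming gradient to 2D and cast it to the backward FP8 dtype;
    # the plain cast feeds dgrad and the transposed cast feeds wgrad.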
    g = jnp.reshape(g, (input_cast_trans.shape[1], -1))
    grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
                                                           bwd_dtype)

    input_scale_inv = scale_inv[gemm_input_idx]
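    # Weight gradient: FP8 GEMM of the transposed gradient with the transposed
    # input saved in the forward pass.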
    wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype,
                 True, input_cast_trans, input_scale_inv, fwd_dtype, False,
                 jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)

    kernel_scale_inv = scale_inv[gemm_kernel_idx]
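    # Activation gradient: FP8 GEMM of the gradient with the kernel casted in the
    # forward pass.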
    dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
                 bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)

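    # Write the amax values measured during the casts back into slot 0 of the history.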
    amax = amax.at[gemm_input_idx, 0].set(input_amax[0])
    amax = amax.at[gemm_kernel_idx, 0].set(kernel_amax[0])
    amax = amax.at[gemm_grad_idx, 0].set(grad_amax[0])

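    # Under data parallelism, sum the weight gradient and take the max amax across
    # the DP axis; the same is done for the FSDP axis below.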
    if is_dp_enabled(sharding_type.value[0]):
        wgrad = jax.lax.psum(wgrad, dp_axis_name)
        amax = jax.lax.pmax(amax, dp_axis_name)

    if len(fsdp_axis_name) > 0:
        wgrad = jax.lax.psum(wgrad, fsdp_axis_name)
        amax = jax.lax.pmax(amax, fsdp_axis_name)

    if is_tp_enabled(sharding_type.value[0]):
        amax = jax.lax.pmax(amax, tp_axis_name)

    if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
        dgrad = jax.lax.psum(dgrad, tp_axis_name)

    dgrad = jnp.reshape(dgrad, inputs_shape)
    wgrad = jnp.reshape(wgrad, kernel_shape)
    return dgrad, wgrad, fp8_maxs, amax, scale, scale_inv


_fp8_dot.defvjp(_fp8_dot_fwd, _fp8_dot_bwd)