"tests/jax/test_recipe_characteristics.py" did not exist on "df6f347fa1039125f9777400c3e9ce4c461d9eda"
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Dense layer transformation operations for Transformer Engine in JAX.

This module provides optimized dense layer transformation operations for transformer
architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""

from typing import Tuple, Sequence
from functools import partial
import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .quantize import (
    ScaledTensorFactory,
    ScalingMode,
    QuantizeLayout,
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
    is_fp8_gemm_with_all_layouts_supported,
    TensorUsage,
    get_quantize_config,
)


def _all_gather_kernel(kernel, mesh_axis, axis_idx):
    assert mesh_axis is not None
    assert 0 < axis_idx < len(kernel.shape)

    # TODO(Ming Huang): Add a condition branch for with/without shmap.
    kernel_shape = kernel.shape
    kernel_whole_shape = (*kernel_shape[:axis_idx], -1, *kernel_shape[axis_idx + 1 :])
    global_kernel = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx)
    global_kernel = global_kernel.reshape(*kernel_whole_shape)
    return global_kernel


def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx):
    assert mesh_axis is not None
    assert 0 < axis_idx < len(scattered_kernel_shape)

    # TODO(Ming Huang): Add a condition branch for with/without shmap.
    kernel = kernel.reshape(
        *scattered_kernel_shape[:axis_idx],
        -1,
        scattered_kernel_shape[axis_idx],
        *scattered_kernel_shape[axis_idx + 1 :],
    )
    kernel = jax.lax.psum_scatter(kernel, mesh_axis, scatter_dimension=axis_idx)
    kernel = kernel.reshape(scattered_kernel_shape)
    return kernel
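
# Illustrative shape bookkeeping for the two helpers above (the mesh axis name
# "fsdp" and the sharding factor 4 are assumptions for this example): with a
# kernel FSDP-sharded along its K dim over "fsdp", _all_gather_kernel maps a
# local (G, K/4, N) shard to the full (G, K, N) kernel, while
# _psum_scatter_kernel reduce-scatters a full-shape (G, K, N) gradient back to
# the local (G, K/4, N) shard, summing contributions across "fsdp".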


def dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    bias: jnp.ndarray = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    input_axes: Tuple[str, ...] = None,
    kernel_axes: Tuple[str, ...] = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
):
    """Perform dense layer transformation with optional quantization.

    This function implements matrix multiplication with optional bias addition,
    supporting quantization and custom contracting dimensions. It's optimized
    for transformer architectures and supports automatic differentiation.

    Args:
        x: Input tensor
        kernel: Weight matrix for the dense layer transformation
        bias: Optional bias tensor to add after the transformation
        contracting_dims: Tuple of sequences specifying which dimensions to contract
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    if not get_quantize_config().is_fp8_enabled():
        input_dtype = x.dtype
        kernel = kernel.astype(input_dtype)

    output = _dense(
        x,
        kernel,
        bias,
        contracting_dims,
        input_axes,
        kernel_axes,
        quantizer_set,
    )
    return output
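
# Minimal usage sketch for `dense` (illustrative; the shapes, dtype, and module
# path are assumptions, not part of the documented contract):
#
#     import jax.numpy as jnp
#     from transformer_engine.jax.dense import dense
#
#     x = jnp.zeros((32, 128), jnp.bfloat16)        # (batch, hidden_in)
#     kernel = jnp.zeros((128, 256), jnp.bfloat16)  # (hidden_in, hidden_out)
#     bias = jnp.zeros((256,), jnp.bfloat16)
#     out = dense(x, kernel, bias)                  # -> (32, 256)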


@partial(
    jax.custom_vjp,
    nondiff_argnums=(
        3,
        4,
        5,
    ),
)
def _dense(
    x,
    kernel,
    bias,
    contracting_dims,
    input_axes,
    kernel_axes,
    quantizer_set,
):
    """Internal implementation of dense layer transformation with custom VJP.

    This function implements the core dense layer transformation logic with support
    for custom vector-Jacobian product (VJP) for automatic differentiation.

    Args:
        x: Input tensor
        kernel: Weight matrix
        bias: Optional bias tensor
        contracting_dims: Contracting dimensions specification
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    output, _ = _dense_fwd_rule(
        x,
        kernel,
        bias,
        contracting_dims,
        input_axes,
        kernel_axes,
        quantizer_set,
    )
    return output


def _dense_fwd_rule(
    x,
    kernel,
    bias,
    contracting_dims,
    input_axes,
    kernel_axes,
    quantizer_set,
):
    """Forward pass rule for dense layer transformation.

    Returns:
        Tuple of (output, context) for backward pass
    """
    x_contracting_dims, k_contracting_dims = map(
        tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims
    )

    # Check supported input layout
    x_is_transposed = x.ndim - 1 not in x_contracting_dims
    k_is_transposed = kernel.ndim - 1 in k_contracting_dims
    assert (
        not x_is_transposed and not k_is_transposed
    ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."

    flatten_axis_x = -len(x_contracting_dims)
    flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
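
    # Worked example (illustrative): for x of shape (B, S, H) with
    # x_contracting_dims == (2,), flatten_axis_x == -1, i.e. x is treated as a
    # (B*S, H) matrix; for kernel of shape (H, F) with k_contracting_dims ==
    # (0,), flatten_axis_k == 1 - 2 == -1. Both mark the boundary between
    # contracting and non-contracting dims used when quantizing.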

    casted_x = tex.quantize(
        x,
        flatten_axis=flatten_axis_x,
        quantizer=quantizer_set.x,
    )
    casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)

    casted_kernel = tex.quantize(
        kernel,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.kernel,
    )
    casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

    # GEMM NN
    use_bias = bias is not None
    output = tex.gemm(
        casted_x.get_tensor(usage=TensorUsage.LHS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        bias=bias if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
    )

    if use_bias and tex.gemm_uses_jax_dot():
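        # For example (illustrative): a bias of shape (N,) added to an output
        # of shape (M, N) is reshaped to (1, N) so it broadcasts over the
        # leading dims.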
        bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
        output += jnp.reshape(bias, bias_new_shape)

    ctx = (
        casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
        x.shape,
        kernel.shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx


def _dense_bwd_rule(
    contracting_dims, input_axes, kernel_axes, ctx, grad
):  # pylint: disable=unused-argument
    """Backward pass rule for dense layer transformation.

    Returns:
        Tuple of gradients with respect to inputs
    """
    (
        casted_x_lhs,
        casted_kernel_rhs,
        x_shape,
        kernel_shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    fwd_x_contracting_dims, fwd_k_contracting_dims = map(
        tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
    )

    casted_grad, dbias = tex.quantize_dbias(
        grad,
        is_dbias=use_bias,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.dgrad,
    )

    # GEMM NT
    # grad contracts over its trailing dims, which line up with kernel's
    # non-contracting dims; the offset accounts for grad.ndim vs kernel.ndim
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dim = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
    )
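
    # Worked example (illustrative): for x (M, K), kernel (K, N), and grad
    # (M, N) with fwd_k_contracting_dims == (0,), this yields
    # g_contracting_dim == (1,) and k_contracting_dim == (1,), so
    # dgrad = grad @ kernel^T has shape (M, K), matching x.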

    dgrad = tex.gemm(
        casted_grad.get_tensor(usage=TensorUsage.LHS),
        casted_kernel_rhs,
        contracting_dims=(g_contracting_dim, k_contracting_dim),
    )
    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)

    # GEMM TN
    # x_non_contracting_dims
    g_contracting_dim = x_contracting_dim = tuple(
        range(0, len(x_shape) - len(fwd_x_contracting_dims))
    )
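
    # Worked example (illustrative): for x (M, K) and grad (M, N) with
    # fwd_x_contracting_dims == (1,), both contract over dim (0,) (the M dim),
    # so wgrad = x^T @ grad has shape (K, N), matching the kernel.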

    wgrad = tex.gemm(
        casted_x_lhs,
        casted_grad.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(x_contracting_dim, g_contracting_dim),
    )
    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

    return dgrad, wgrad, dbias, quantizer_set


_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)


def grouped_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    group_sizes: jnp.ndarray,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
    bias: jnp.ndarray = None,
    kernel_amax: jnp.ndarray = None,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    preferred_element_type: jnp.dtype = None,
    group_offset: jnp.array = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
    kernel_fsdp_info: Tuple[str, int] = (None, -1),
):
    """
    Perform grouped dense (linear) layer transformation with optional quantization.

    Args:
        x: Input tensor of shape (M, K)
        kernel: Weight matrix of shape (G, K, N)
        group_sizes: 1D array of shape (G,) specifying the size of each group
        contracting_dims: Tuple of sequences specifying which dimensions to contract
                          (currently only supports ((1,), (1,)))
        bias: Bias tensor of shape (G, N)
        kernel_amax: The amax values of weight matrix of shape (G,)
        precision: JAX precision for the GEMM operation
        preferred_element_type: Preferred data type for the output tensor
        group_offset: 1D array containing offsets for each group (not yet implemented)
        quantizer_set: Set of quantizers for FP8 quantization of the input and output
        kernel_fsdp_info: A tuple containing FSDP-related information for a weight matrix
                          represented in the format (str, int). The first element is the
                          FSDP mesh axis, and the second element is the dimension along
                          which the weight is sharded.

    Returns:
        A jnp.ndarray containing the result of the grouped linear operation
    """
    output = _grouped_dense(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
    return output
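
# Minimal usage sketch for `grouped_dense` (illustrative; the sizes are
# assumptions chosen to satisfy the documented (M, K) / (G, K, N) contract):
#
#     M, K, N, G = 16, 64, 32, 4
#     x = jnp.zeros((M, K), jnp.bfloat16)
#     kernel = jnp.zeros((G, K, N), jnp.bfloat16)
#     group_sizes = jnp.array([4, 4, 4, 4], jnp.int32)  # rows per group, sums to M
#     out = grouped_dense(x, kernel, group_sizes)       # -> (M, N)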


@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10))
def _grouped_dense(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
    output, _ = _grouped_dense_fwd_rule(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
    return output


def _grouped_dense_fwd_rule(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
    use_bias = bias is not None
    is_noop_quantizer_set = quantizer_set == noop_quantizer_set

    kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
    kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None
    # Keep the local (possibly FSDP-sharded) kernel shape for the backward
    # pass; the all-gather below may replace `kernel` with the full kernel.
    kernel_shape = kernel.shape

    if is_noop_quantizer_set:
        # Gather the FSDP-sharded kernel first so that both the GEMM input and
        # the saved context see the full kernel.
        if kernel_fsdp_enabled:
            kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx)
        grouped_gemm_x = x
        grouped_gemm_kernel = kernel
        ctx_x = x
        ctx_kernel = kernel
        flatten_axis_k = None
    else:
        original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout

        x_contracting_dims, k_contracting_dims = contracting_dims
        flatten_axis_x = -len(x_contracting_dims)
        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for G axis
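        # Worked example (illustrative): x (M, K) with x_contracting_dims ==
        # (1,) gives flatten_axis_x == -1; kernel (G, K, N) with
        # k_contracting_dims == (1,) gives flatten_axis_k == 1 - 3 + 1 == -1.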

        assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
        assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
        # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
            "grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
            "and k_contracting_dims=(1,) for now, "
            f"got {x_contracting_dims=} and {k_contracting_dims=}"
        )

        casted_x = tex.grouped_quantize(
            x,
            quantizer_set.x,
            group_sizes,
            flatten_axis=flatten_axis_x,
        )

        ctx_kernel_usage = TensorUsage.RHS_TRANS
        if kernel_fsdp_enabled:
            assert quantizer_set.kernel.scaling_mode in [
                ScalingMode.CURRENT_TENSOR_SCALING,
                ScalingMode.DELAYED_TENSOR_SCALING,
            ]
            # Perform `cast` only
            ctx_kernel_usage = TensorUsage.LHS
            quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE

        casted_kernel = tex.grouped_quantize(
            kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k
        )
        contracting_dims = (x_contracting_dims, k_contracting_dims)

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
        # rowwise_casted_x.original_shape == (M, K)
        # colwise_casted_kernel.original_shape == (G, N, K)
        grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
        ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
        ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage)

        if kernel_fsdp_enabled:
            ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape)
            global_ctx_kernel_data = _all_gather_kernel(
                ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
            )
            kernel_shape = global_ctx_kernel_data.shape

            ctx_kernel = ScaledTensorFactory.create_1x(
                global_ctx_kernel_data.reshape(-1),
                ctx_kernel.scale_inv,
                ctx_kernel.scaling_mode,
                dq_dtype=ctx_kernel.dq_dtype,
                is_colwise=False,
                data_layout="N",
                flatten_axis=ctx_kernel.flatten_axis,
                group_sizes=ctx_kernel.group_sizes,
                original_shape=kernel_shape,
                group_axis=ctx_kernel.group_axis,
            )

            if is_fp8_gemm_with_all_layouts_supported():
                grouped_gemm_kernel = ctx_kernel
            else:
                grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1)
                grouped_gemm_kernel = ScaledTensorFactory.create_1x(
                    grouped_gemm_kernel_data.reshape(-1),
                    ctx_kernel.scale_inv,
                    ctx_kernel.scaling_mode,
                    dq_dtype=ctx_kernel.dq_dtype,
                    is_colwise=True,
                    data_layout="T",
                    flatten_axis=ctx_kernel.flatten_axis,
                    group_sizes=ctx_kernel.group_sizes,
                    original_shape=kernel_shape,
                    group_axis=ctx_kernel.group_axis,
                )
        else:
            grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)

        # Restore quantizer_set.kernel.q_layout so the returned PyTree structure
        # matches the quantizer_set that was passed in. This matters especially
        # when kernel_fsdp_enabled == True and FP8 is enabled.
        quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout

    output = tex.grouped_gemm(
        grouped_gemm_x,
        grouped_gemm_kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
    )

    ctx = (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x.shape,
        kernel_shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx


def _grouped_dense_bwd_rule(
    contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad
):
    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims

    (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x_shape,
        kernel_shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    if is_noop_quantizer_set:
        # The leading 1 in range() excludes the group dimension; for the
        # supported layout these reduce to the hardcoded values below:
        # g_contracting_dim = (1, )
        # k_contracting_dim = (2, )
        g_contracting_dim = tuple(
            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
        )
        k_contracting_dim = tuple(
            dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
        )
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = grad
        dgrad_kernel_T = ctx_kernel

        # g_contracting_dim = (0, )
        # x_contracting_dim = (0, )
        g_contracting_dim = x_contracting_dim = tuple(
            range(0, len(x_shape) - len(fwd_x_contracting_dims))
        )
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = grad
    else:
        casted_grad = tex.grouped_quantize(
            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
        )

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
        # extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (1,)
        k_contracting_dim = (2,)
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS)
        dgrad_kernel_T = ctx_kernel

        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
        # after the extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (0,)
        x_contracting_dim = (0,)
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS)

    dgrad = tex.grouped_gemm(
        dgrad_grad,
        dgrad_kernel_T,
        group_sizes,
        dgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    wgrad = tex.grouped_gemm(
        wgrad_x_T,
        wgrad_grad,
        group_sizes,
        wgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )
    kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
    if kernel_fsdp_mesh_axis is not None:
        wgrad = _psum_scatter_kernel(
            wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
        )

    group_sizes_grad = None
    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None
    dkernel_amax = None

    return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set


_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)