dense.py 21.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Dense layer transformation operations for Transformer Engine in JAX.

This module provides optimized dense layer transformation operations for transformer
architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""
11

12
13
from typing import Tuple, Sequence
from functools import partial
Phuong Nguyen's avatar
Phuong Nguyen committed
14
import warnings
15
16
17
18
import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
19
from .cpp_extensions.amax import AmaxScope
20
from .quantize import (
21
    ScaledTensorFactory,
22
    ScaledTensor,
23
24
    ScalingMode,
    QuantizeLayout,
25
26
27
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
28
    is_fp8_gemm_with_all_layouts_supported,
29
    TensorUsage,
30
    get_quantize_config,
31
)
32

Alp Dener's avatar
Alp Dener committed
33

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def _all_gather_kernel(kernel, mesh_axis, axis_idx):
    """All-gather an FSDP-sharded kernel along one mesh axis into its full form.

    Args:
        kernel: Locally sharded kernel array.
        mesh_axis: Name of the mesh axis to gather over; must not be None.
        axis_idx: Kernel dimension that is sharded. Must be > 0 (dim 0 is
            assumed to be the group/leading dimension and is never sharded here).

    Returns:
        The kernel with the sharded dimension restored to its global size.
    """
    assert mesh_axis is not None
    assert 0 < axis_idx < len(kernel.shape)

    # TODO(Ming Huang): Add a condition branch for with/without shmap.
    local_shape = kernel.shape
    # -1 lets reshape absorb the gathered shards back into the sharded dim.
    full_shape = (*local_shape[:axis_idx], -1, *local_shape[axis_idx + 1 :])
    gathered = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx)
    return gathered.reshape(*full_shape)


def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx):
    """Reduce-scatter a full (replicated) kernel gradient back to its sharded form.

    Inverse of :func:`_all_gather_kernel` for the backward pass: sums the
    per-device contributions and leaves each device with its own shard.

    Args:
        kernel: Full-size kernel (gradient) array to reduce and scatter.
        scattered_kernel_shape: The target *local* (sharded) kernel shape.
        mesh_axis: Name of the mesh axis to scatter over; must not be None.
        axis_idx: Dimension of the scattered shape that is sharded (> 0).

    Returns:
        The reduce-scattered kernel with shape ``scattered_kernel_shape``.
    """
    assert mesh_axis is not None
    assert 0 < axis_idx < len(scattered_kernel_shape)

    # TODO(Ming Huang): Add a condition branch for with/without shmap.
    leading = scattered_kernel_shape[:axis_idx]
    trailing = scattered_kernel_shape[axis_idx + 1 :]
    # Split the sharded dim into (num_shards, local_size) so psum_scatter can
    # distribute one shard per device along `axis_idx`.
    staged = kernel.reshape(*leading, -1, scattered_kernel_shape[axis_idx], *trailing)
    reduced = jax.lax.psum_scatter(staged, mesh_axis, scatter_dimension=axis_idx)
    return reduced.reshape(scattered_kernel_shape)


62
63
64
65
66
def dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    bias: jnp.ndarray = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    transpose_batch_sequence: bool = False,
    input_axes: Tuple[str, ...] = None,
    kernel_axes: Tuple[str, ...] = None,
    output_axes: Tuple[str, ...] = None,
    collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set,
    quantizer_set: QuantizerSet = noop_quantizer_set,
):
    """Perform dense layer transformation with optional quantization.

    Implements matrix multiplication with optional bias addition, supporting
    quantization and custom contracting dimensions. Differentiation flows
    through the custom VJP defined on ``_dense``.

    Args:
        x: Input tensor
        kernel: Weight matrix for the dense layer transformation
        bias: Optional bias tensor to add after the transformation
        contracting_dims: Tuple of sequences specifying which dimensions to contract
        transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        output_axes: Logical axes for sharding the output
        collective_op_set: A set of CollectiveOp objects for forward and backward passes.
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    if transpose_batch_sequence:
        warnings.warn("transpose_batch_sequence is not well tested, use with caution!")

    # Without FP8 quantization, run the GEMM in the activation's dtype so the
    # two operands agree.
    if not get_quantize_config().is_fp8_enabled():
        kernel = kernel.astype(x.dtype)

    return _dense(
        x,
        kernel,
        bias,
        contracting_dims,
        transpose_batch_sequence,
        input_axes,
        kernel_axes,
        output_axes,
        collective_op_set,
        quantizer_set,
    )


117
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8))
def _dense(
    x,
    kernel,
    bias,
    contracting_dims,
    transpose_batch_sequence,
    input_axes,
    kernel_axes,
    output_axes,
    collective_op_set,
    quantizer_set,  # need to be a diff_arg for DelayedScaling state management
):
    """Internal implementation of dense layer transformation with custom VJP.

    The primal computation simply reuses the forward rule and discards its
    saved context; the VJP pair is registered via ``defvjp`` below.

    Args:
        x: Input tensor
        kernel: Weight matrix
        bias: Optional bias tensor
        contracting_dims: Contracting dimensions specification
        transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        output_axes: Logical axes for sharding the output
        collective_op_set: A set of CollectiveOp objects for forward and backward passes.
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    primal_out, _ = _dense_fwd_rule(
        x,
        kernel,
        bias,
        contracting_dims,
        transpose_batch_sequence,
        input_axes,
        kernel_axes,
        output_axes,
        collective_op_set,
        quantizer_set,
    )
    return primal_out


Alp Dener's avatar
Alp Dener committed
165
def _dense_fwd_rule(
    x,
    kernel,
    bias,
    contracting_dims,
    transpose_batch_sequence,
    input_axes,
    kernel_axes,
    output_axes,
    collective_op_set,
    quantizer_set,
):
    """Forward pass rule for dense layer transformation.

    Quantizes the activation and the kernel, performs the NN-layout GEMM
    (optionally fusing the bias), and saves the transposed casts plus shape
    metadata for the backward pass.

    Returns:
        Tuple of (output, context) for backward pass
    """
    x_cdims, k_cdims = map(tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims)

    # Check supported input layout: x's last dim must be contracted and the
    # kernel's last dim must not be (i.e. `NN` layout only).
    x_is_transposed = x.ndim - 1 not in x_cdims
    k_is_transposed = kernel.ndim - 1 in k_cdims
    assert (
        not x_is_transposed and not k_is_transposed
    ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."

    flatten_axis_x = -len(x_cdims)
    flatten_axis_k = len(k_cdims) - len(kernel.shape)

    casted_x = tex.quantize(
        x,
        flatten_axis=flatten_axis_x,
        quantizer=quantizer_set.x,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)

    casted_kernel = tex.quantize(
        kernel,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
    )
    casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

    # GEMM NN. The bias is fused into the GEMM unless the jax.dot fallback is
    # in use, in which case it is added manually below.
    use_bias = bias is not None
    uses_jax_dot = tex.gemm_uses_jax_dot()
    output = tex.gemm(
        casted_x.get_tensor(usage=TensorUsage.LHS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(x_cdims, k_cdims),
        transpose_batch_sequence=transpose_batch_sequence,
        bias=None if uses_jax_dot else bias,
        fuse_bias=False if uses_jax_dot else use_bias,
        collective_op=collective_op_set.forward,
    )
    output = with_sharding_constraint_by_logical_axes(output, output_axes)

    if use_bias and uses_jax_dot:
        # Broadcast the bias over the leading (batch-like) output dims.
        output += jnp.reshape(bias, (1,) * (output.ndim - bias.ndim) + bias.shape)

    ctx = (
        casted_x.get_tensor(usage=TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
        casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
        x.shape,
        kernel.shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx


242
def _dense_bwd_rule(
    contracting_dims,
    transpose_batch_sequence,
    input_axes,
    kernel_axes,
    output_axes,
    collective_op_set,
    ctx,
    grad,
):
    """Backward pass rule for dense layer transformation.

    Computes the input gradient via a GEMM NT (grad x kernel non-contracting
    dims) and the weight gradient via a GEMM TN (x non-contracting dims x
    grad). The bias gradient, if any, is fused into the grad quantization.

    Returns:
        Tuple of gradients with respect to inputs
    """
    (
        casted_x_lhs,
        casted_kernel_rhs,
        x_shape,
        kernel_shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    ) = ctx
    grad = with_sharding_constraint_by_logical_axes(grad, output_axes)

    fwd_x_cdims, fwd_k_cdims = map(
        tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
    )

    casted_grad, dbias = tex.quantize_dbias(
        grad,
        is_dbias=use_bias,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.dgrad,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    # GEMM NT.
    # grad's trailing dims (calibrated by the rank difference grad vs kernel)
    # contract against the kernel's non-contracting dims.
    grad_cdims = tuple(range(grad.ndim - len(kernel_shape) + len(fwd_k_cdims), grad.ndim))
    kernel_ncdims = tuple(d for d in range(len(kernel_shape)) if d not in fwd_k_cdims)

    dgrad = tex.gemm(
        casted_grad.get_tensor(usage=TensorUsage.LHS),
        casted_kernel_rhs,
        contracting_dims=(grad_cdims, kernel_ncdims),
        transpose_batch_sequence=transpose_batch_sequence,
        collective_op=collective_op_set.backward,
    )

    # GEMM TN: contract over x's non-contracting (leading, batch-like) dims,
    # which line up with the same leading dims of grad.
    batch_dims = tuple(range(0, len(x_shape) - len(fwd_x_cdims)))

    wgrad = tex.gemm(
        casted_x_lhs,
        casted_grad.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(batch_dims, batch_dims),
        transpose_batch_sequence=transpose_batch_sequence,
    )

    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

    return dgrad, wgrad, dbias, quantizer_set


_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)


def grouped_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    group_sizes: jnp.ndarray,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
    bias: jnp.ndarray = None,
    kernel_amax: jnp.ndarray = None,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    preferred_element_type: jnp.dtype = None,
    group_offset: jnp.array = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
    kernel_fsdp_info: Tuple[str, int] = (None, -1),
):
    """
    Perform grouped dense (linear) layer transformation with optional quantization.

    Args:
        x: Input tensor of shape (M, K)
        kernel: Weight matrix of shape (G, K, N)
        group_sizes: 1D array of shape (G,) specifying the size of each group
        contracting_dims: Tuple of sequences specifying which dimensions to contract
                          (currently only supports ((1,), (1,)))
        bias: Bias tensor of shape (G, N)
        kernel_amax: The amax values of weight matrix of shape (G,)
        precision: JAX precision for the GEMM operation
        preferred_element_type: Preferred data type for the output tensor
        group_offset: 1D array containing offsets for each group (not yet implemented)
        quantizer_set: Set of quantizers for FP8 quantization of the input and output
        kernel_fsdp_info: A tuple containing FSDP-related information for a weight matrix
                          represented in the format (str, int). The first element is the
                          FSDP mesh axis, and the second element is the dimension along
                          which the weight is sharded.

    Returns:
        A jnp.ndarray containing the result of the grouped linear operation
    """
    return _grouped_dense(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
371
372


373
@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10))
def _grouped_dense(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
    """Custom-VJP core of grouped_dense.

    The primal computation delegates to the forward rule and drops the saved
    context; the VJP pair is registered via ``defvjp`` below.
    """
    primal_out, _ = _grouped_dense_fwd_rule(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
    return primal_out
401
402
403


def _grouped_dense_fwd_rule(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
    """Forward pass rule for grouped dense layer transformation.

    Two paths are handled:
      * no-op quantizer set: run the grouped GEMM directly in the input dtype;
      * FP8 path: group-quantize x and kernel first, honoring the extra
        transpose constraints of grouped_gemm.

    When the kernel is FSDP-sharded (``kernel_fsdp_info``), it is all-gathered
    here so both the GEMM and the saved backward context see the full kernel;
    the matching reduce-scatter of wgrad happens in the backward rule, which
    relies on ``kernel.shape`` in ctx being the *local* (sharded) shape.

    Returns:
        Tuple of (output, context) for backward pass
    """
    use_bias = bias is not None
    is_noop_quantizer_set = quantizer_set == noop_quantizer_set

    kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
    kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None

    if is_noop_quantizer_set:
        if kernel_fsdp_enabled:
            # BUG FIX: the all-gather used to run *after* grouped_gemm_kernel /
            # ctx_kernel were bound (and rebound `kernel` itself), so the
            # gathered kernel was never used by the GEMM or the backward dgrad,
            # and ctx recorded the global shape instead of the local one.
            # Gather into a separate variable: the GEMM and ctx get the full
            # kernel, while `kernel` (and thus `kernel.shape` saved below)
            # stays local as the backward wgrad reduce-scatter expects.
            full_kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx)
        else:
            full_kernel = kernel
        grouped_gemm_x = x
        grouped_gemm_kernel = full_kernel
        ctx_x = x
        ctx_kernel = full_kernel
        flatten_axis_k = None
    else:
        # Saved so the quantizer PyTree structure can be restored below.
        original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout

        x_contracting_dims, k_contracting_dims = contracting_dims
        flatten_axis_x = -len(x_contracting_dims)
        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for G axis

        assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
        assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
        # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
            "grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
            "and k_contracting_dims=(1,) for now, "
            f"got {x_contracting_dims=} and {k_contracting_dims=}"
        )

        casted_x = tex.grouped_quantize(
            x,
            quantizer_set.x,
            group_sizes,
            flatten_axis=flatten_axis_x,
        )

        ctx_kernel_usage = TensorUsage.RHS_TRANS
        if kernel_fsdp_enabled:
            # Only per-tensor scaling modes are supported for FSDP kernels:
            # the scale is shard-independent, so gathering raw data suffices.
            assert quantizer_set.kernel.scaling_mode in [
                ScalingMode.CURRENT_TENSOR_SCALING,
                ScalingMode.DELAYED_TENSOR_SCALING,
            ]
            # Perform `cast` only (rowwise), then gather and build the needed
            # layouts by hand below.
            ctx_kernel_usage = TensorUsage.LHS
            quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE

        casted_kernel = tex.grouped_quantize(
            kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k
        )
        contracting_dims = (x_contracting_dims, k_contracting_dims)

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
        # rowwise_casted_x.original_shape == (M, K)
        # colwise_casted_kernel.original_shape == (G, N, K)
        grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
        ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
        ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage)

        if kernel_fsdp_enabled:
            # Gather the quantized kernel data into its global shape, then
            # rebuild ScaledTensors for ctx (N layout) and the GEMM (T layout
            # when the hardware needs the extra transpose).
            ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape)
            global_ctx_kernel_data = _all_gather_kernel(
                ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
            )
            kernel_shape = global_ctx_kernel_data.shape

            ctx_kernel = ScaledTensorFactory.create_1x(
                global_ctx_kernel_data.reshape(-1),
                ctx_kernel.scale_inv,
                scaling_mode=ctx_kernel.scaling_mode,
                dq_dtype=ctx_kernel.dq_dtype,
                is_colwise=False,
                data_layout="N",
                flatten_axis=ctx_kernel.flatten_axis,
                group_sizes=ctx_kernel.group_sizes,
                original_shape=kernel_shape,
                group_axis=ctx_kernel.group_axis,
            )

            if is_fp8_gemm_with_all_layouts_supported():
                grouped_gemm_kernel = ctx_kernel
            else:
                grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1)
                grouped_gemm_kernel = ScaledTensorFactory.create_1x(
                    grouped_gemm_kernel_data.reshape(-1),
                    ctx_kernel.scale_inv,
                    scaling_mode=ctx_kernel.scaling_mode,
                    dq_dtype=ctx_kernel.dq_dtype,
                    is_colwise=True,
                    data_layout="T",
                    flatten_axis=ctx_kernel.flatten_axis,
                    group_sizes=ctx_kernel.group_sizes,
                    original_shape=kernel_shape,
                    group_axis=ctx_kernel.group_axis,
                )
        else:
            grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)

        # Reset quantizer_set.kernel.q_layout to align the PyTree as the given one.
        # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled.
        quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout

    output = tex.grouped_gemm(
        grouped_gemm_x,
        grouped_gemm_kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
    )

    ctx = (
        group_sizes,
        ctx_x.checkpoint(quantizer_set.x) if isinstance(ctx_x, ScaledTensor) else ctx_x,
        (
            ctx_kernel.checkpoint(quantizer_set.kernel)
            if isinstance(ctx_kernel, ScaledTensor)
            else ctx_kernel
        ),
        x.shape,
        kernel.shape,  # local (sharded) shape when FSDP is enabled — see docstring
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx
547
548


549
def _grouped_dense_bwd_rule(
    contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad
):
    """Backward pass rule for grouped dense layer transformation.

    Builds dgrad and wgrad via two grouped GEMMs, choosing contracting dims
    per path (no-op vs FP8, which requires the extra-transpose layout), then
    reduce-scatters wgrad when the kernel is FSDP-sharded.

    Returns:
        Tuple of gradients with respect to the differentiable inputs.
    """
    fwd_x_cdims, fwd_k_cdims = contracting_dims

    (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x_shape,
        kernel_shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    if is_noop_quantizer_set:
        # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?)
        # g_contracting_dim = (1, )
        # k_contracting_dim = (2, )
        dgrad_g_cdims = tuple(
            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_cdims), grad.ndim)
        )
        dgrad_k_cdims = tuple(d for d in range(1, len(kernel_shape)) if d not in fwd_k_cdims)
        dgrad_contracting_dims = (dgrad_g_cdims, dgrad_k_cdims)
        dgrad_grad = grad
        dgrad_kernel_T = ctx_kernel

        # g_contracting_dim = (0, )
        # x_contracting_dim = (0, )
        wgrad_leading_dims = tuple(range(0, len(x_shape) - len(fwd_x_cdims)))
        wgrad_contracting_dims = (wgrad_leading_dims, wgrad_leading_dims)
        wgrad_x_T = ctx_x
        wgrad_grad = grad
    else:
        casted_grad = tex.grouped_quantize(
            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
        )

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
        # extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        dgrad_contracting_dims = ((1,), (2,))
        dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS)
        dgrad_kernel_T = ctx_kernel

        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
        # after the extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        wgrad_contracting_dims = ((0,), (0,))
        wgrad_x_T = ctx_x
        wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS)

    dgrad = tex.grouped_gemm(
        dgrad_grad,
        dgrad_kernel_T,
        group_sizes,
        dgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    wgrad = tex.grouped_gemm(
        wgrad_x_T,
        wgrad_grad,
        group_sizes,
        wgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    # If the kernel was FSDP-sharded, reduce-scatter wgrad back to local shards.
    fsdp_mesh_axis, fsdp_axis_idx = kernel_fsdp_info
    if fsdp_mesh_axis is not None:
        wgrad = _psum_scatter_kernel(wgrad, kernel_shape, fsdp_mesh_axis, fsdp_axis_idx)

    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None

    # group_sizes and kernel_amax are data inputs, not differentiable parameters.
    group_sizes_grad = None
    dkernel_amax = None

    return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set
642
643
644


_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)