"vllm/vscode:/vscode.git/clone" did not exist on "7d761fe3c12e87df37383467c43c97dec2bb8470"
dense.py 21.5 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
9
10
#
# See LICENSE for license information.
"""Dense layer transformation operations for Transformer Engine in JAX.

This module provides optimized dense layer transformation operations for transformer
architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""
11

12
13
from typing import Tuple, Sequence
from functools import partial
Phuong Nguyen's avatar
Phuong Nguyen committed
14
import warnings
15
16
17
18
import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
19
from .cpp_extensions.amax import AmaxScope
20
from .quantize import (
21
    ScaledTensorFactory,
22
    ScaledTensor,
23
    ScalingMode,
24
25
26
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
27
    is_fp8_gemm_with_all_layouts_supported,
28
    TensorUsage,
Paweł Gadziński's avatar
Paweł Gadziński committed
29
    QuantizeLayout,
30
)
31

Alp Dener's avatar
Alp Dener committed
32

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def _all_gather_kernel(kernel, mesh_axis, axis_idx):
    assert mesh_axis is not None
    assert 0 < axis_idx < len(kernel.shape)

    # TODO(Ming Hunag): Add a condition branch for with/without shmap.
    kernel_shape = kernel.shape
    kernel_whole_shape = (*kernel_shape[:axis_idx], -1, *kernel_shape[axis_idx + 1 :])
    global_kernel = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx)
    global_kernel = global_kernel.reshape(*kernel_whole_shape)
    return global_kernel


def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx):
    assert mesh_axis is not None
    assert 0 < axis_idx < len(scattered_kernel_shape)

    # TODO(Ming Hunag): Add a condition branch for with/without shmap.
    kernel = kernel.reshape(
        *scattered_kernel_shape[:axis_idx],
        -1,
        scattered_kernel_shape[axis_idx],
        *scattered_kernel_shape[axis_idx + 1 :],
    )
    kernel = jax.lax.psum_scatter(kernel, mesh_axis, scatter_dimension=axis_idx)
    kernel = kernel.reshape(scattered_kernel_shape)
    return kernel


61
62
63
64
65
def dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    bias: jnp.ndarray = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    transpose_batch_sequence: bool = False,
    input_axes: Tuple[str, ...] = None,
    kernel_axes: Tuple[str, ...] = None,
    output_axes: Tuple[str, ...] = None,
    collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set,
    quantizer_set: QuantizerSet = noop_quantizer_set,
):
    """Perform dense layer transformation with optional quantization.

    Implements matrix multiplication with optional bias addition, supporting
    quantization and custom contracting dimensions. Optimized for transformer
    architectures and differentiable via a custom VJP.

    Args:
        x: Input tensor
        kernel: Weight matrix for the dense layer transformation
        bias: Optional bias tensor to add after the transformation
        contracting_dims: Tuple of sequences specifying which dimensions to contract
        transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        output_axes: Logical axes for sharding the output
        collective_op_set: A set of CollectiveOp objects for forward and backward passes.
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    if transpose_batch_sequence:
        warnings.warn("transpose_batch_sequence is not well tested, use with caution!")

    # Without quantizers the GEMM runs in the activation dtype, so align the
    # kernel dtype with the input up front.
    if quantizer_set == noop_quantizer_set:
        kernel = kernel.astype(x.dtype)

    return _dense(
        x,
        kernel,
        bias,
        contracting_dims,
        transpose_batch_sequence,
        input_axes,
        kernel_axes,
        output_axes,
        collective_op_set,
        quantizer_set,
    )


116
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8))
def _dense(
    x,
    kernel,
    bias,
    contracting_dims,
    transpose_batch_sequence,
    input_axes,
    kernel_axes,
    output_axes,
    collective_op_set,
    quantizer_set,  # need to be a diff_arg for DelayedScaling state management
):
    """Custom-VJP core of :func:`dense`.

    The primal path evaluates the forward rule and discards the saved
    context; differentiation uses the ``_dense_fwd_rule``/``_dense_bwd_rule``
    pair registered with ``defvjp`` below.

    Args:
        x: Input tensor
        kernel: Weight matrix
        bias: Optional bias tensor
        contracting_dims: Contracting dimensions specification
        transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor.
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        output_axes: Logical axes for sharding the output
        collective_op_set: A set of CollectiveOp objects for forward and backward passes.
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    primal_out, _ = _dense_fwd_rule(
        x,
        kernel,
        bias,
        contracting_dims,
        transpose_batch_sequence,
        input_axes,
        kernel_axes,
        output_axes,
        collective_op_set,
        quantizer_set,
    )
    return primal_out


Alp Dener's avatar
Alp Dener committed
164
def _dense_fwd_rule(
    x,
    kernel,
    bias,
    contracting_dims,
    transpose_batch_sequence,
    input_axes,
    kernel_axes,
    output_axes,
    collective_op_set,
    quantizer_set,
):
    """Forward pass rule for dense layer transformation.

    Quantizes the activation and kernel, runs a single NN-layout GEMM
    (optionally with fused bias), and saves the transposed-usage tensors plus
    shapes/flags needed by the backward rule.

    Returns:
        Tuple of (output, context) for backward pass
    """
    # Normalize possibly-negative contracting dims against each operand's rank.
    x_contracting_dims, k_contracting_dims = map(
        tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims
    )

    # Check supported input layout
    x_is_transposed = x.ndim - 1 not in x_contracting_dims
    k_is_transposed = kernel.ndim - 1 in k_contracting_dims
    assert (
        not x_is_transposed and not k_is_transposed
    ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."

    # flatten_axis marks where the contracted tail of each operand begins
    # (negative index), used by the quantizers to flatten multi-dim operands.
    flatten_axis_x = -len(x_contracting_dims)
    flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)

    # Quantize the activation; amax is reduced over the TP/SP scope.
    casted_x = tex.quantize(
        x,
        flatten_axis=flatten_axis_x,
        quantizer=quantizer_set.x,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)

    # Quantize the kernel; amax is reduced over the FSDP scope.
    casted_kernel = tex.quantize(
        kernel,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
    )
    casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

    # GEMM NN
    use_bias = bias is not None
    # Bias fusion is only available on the custom GEMM path; the jax.dot
    # fallback adds the bias explicitly below.
    output = tex.gemm(
        casted_x.get_tensor(usage=TensorUsage.LHS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=transpose_batch_sequence,
        bias=bias if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
        collective_op=collective_op_set.forward,
    )
    output = with_sharding_constraint_by_logical_axes(output, output_axes)

    if use_bias and tex.gemm_uses_jax_dot():
        # Left-pad bias shape with singleton dims so it broadcasts over output.
        bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
        output += jnp.reshape(bias, bias_new_shape)

    # Save the *_TRANS usages for the backward GEMMs; checkpoint() ties them
    # to the quantizers for DelayedScaling state management.
    ctx = (
        casted_x.get_tensor(usage=TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
        casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
        x.shape,
        kernel.shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx


241
def _dense_bwd_rule(
    contracting_dims,
    transpose_batch_sequence,
    input_axes,
    kernel_axes,
    output_axes,
    collective_op_set,
    ctx,
    grad,
):
    """Backward pass rule for dense layer transformation.

    Computes dgrad (w.r.t. x) via a NT GEMM of grad against the kernel and
    wgrad (w.r.t. kernel) via a TN GEMM of x against grad; dbias comes fused
    out of the grad quantization when a bias was used.

    Returns:
        Tuple of gradients with respect to inputs
    """
    (
        casted_x_lhs,
        casted_kernel_rhs,
        x_shape,
        kernel_shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    ) = ctx
    grad = with_sharding_constraint_by_logical_axes(grad, output_axes)

    # Re-normalize the (nondiff) contracting dims against the saved tensors.
    fwd_x_contracting_dims, fwd_k_contracting_dims = map(
        tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
    )

    # Quantize grad once; also produces dbias as a fused reduction if needed.
    casted_grad, dbias = tex.quantize_dbias(
        grad,
        is_dbias=use_bias,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.dgrad,
        amax_scope=AmaxScope.TPSP,
        transpose_batch_sequence=transpose_batch_sequence,
    )

    # GEMM NT
    # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dim = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
    )

    # dgrad = grad @ kernel^T (contract grad's trailing dims with kernel's
    # non-contracting dims).
    dgrad = tex.gemm(
        casted_grad.get_tensor(usage=TensorUsage.LHS),
        casted_kernel_rhs,
        contracting_dims=(g_contracting_dim, k_contracting_dim),
        transpose_batch_sequence=transpose_batch_sequence,
        collective_op=collective_op_set.backward,
    )

    # GEMM TN
    # x_non_contracting_dims
    g_contracting_dim = x_contracting_dim = tuple(
        range(0, len(x_shape) - len(fwd_x_contracting_dims))
    )

    # wgrad = x^T @ grad (contract both over the batch/leading dims).
    wgrad = tex.gemm(
        casted_x_lhs,
        casted_grad.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(x_contracting_dim, g_contracting_dim),
        transpose_batch_sequence=transpose_batch_sequence,
    )

    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

    # quantizer_set is returned as the gradient slot of the diff-arg
    # quantizer_set (carries DelayedScaling state updates).
    return dgrad, wgrad, dbias, quantizer_set


_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)


def grouped_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    group_sizes: jnp.ndarray,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
    bias: jnp.ndarray = None,
    kernel_amax: jnp.ndarray = None,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    preferred_element_type: jnp.dtype = None,
    group_offset: jnp.array = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
    kernel_fsdp_info: Tuple[str, int] = (None, -1),
):
    """
    Perform grouped dense (linear) layer transformation with optional quantization.

    Args:
        x: Input tensor of shape (M, K)
        kernel: Weight matrix of shape (G, K, N)
        group_sizes: 1D array of shape (G,) specifying the size of each group
        contracting_dims: Tuple of sequences specifying which dimensions to contract
                          (currently only supports ((1,), (1,)))
        bias: Bias tensor of shape (G, N)
        kernel_amax: The amax values of weight matrix of shape (G,)
        precision: JAX precision for the GEMM operation
        preferred_element_type: Preferred data type for the output tensor
        group_offset: 1D array containing offsets for each group (not yet implemented)
        quantizer_set: Set of quantizers for FP8 quantization of the input and output
        kernel_fsdp_info: A tuple containing FSDP-related information for a weight matrix
                          represented in the format (str, int). The first element is the
                          FSDP mesh axis, and the second element is the dimension along
                          which the weight is sharded.

    Returns:
        A jnp.ndarray containing the result of the grouped linear operation
    """
    return _grouped_dense(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
370
371


372
@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10))
def _grouped_dense(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
    """Custom-VJP core of :func:`grouped_dense`.

    The primal path evaluates the forward rule and drops the saved context;
    differentiation uses the forward/backward rules registered with
    ``defvjp`` below.
    """
    primal_out, _ = _grouped_dense_fwd_rule(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
    return primal_out
400
401
402


def _grouped_dense_fwd_rule(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
    """Forward pass rule for grouped dense layer transformation.

    Handles three kernel paths: unquantized, quantized, and quantized with an
    FSDP-sharded kernel that must be all-gathered before the grouped GEMM.

    Returns:
        Tuple of (output, context) for the backward pass.
    """
    use_bias = bias is not None
    is_noop_quantizer_set = quantizer_set == noop_quantizer_set

    kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
    kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None

    if is_noop_quantizer_set:
        # FIX: gather the FSDP-sharded kernel BEFORE aliasing it for the GEMM
        # and the saved context. Previously the gather ran after the aliases,
        # so the gathered (global) kernel was never used: the GEMM and the
        # backward rule consumed the local shard.
        if kernel_fsdp_enabled:
            kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx)

        grouped_gemm_x = x
        grouped_gemm_kernel = kernel
        ctx_x = x
        ctx_kernel = kernel
        flatten_axis_k = None
    else:
        # Saved so the quantizer PyTree can be restored at the end (see below).
        original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout

        x_contracting_dims, k_contracting_dims = contracting_dims
        flatten_axis_x = -len(x_contracting_dims)
        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for G axis

        assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
        assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
        # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
            "grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
            "and k_contracting_dims=(1,) for now, "
            f"got {x_contracting_dims=} and {k_contracting_dims=}"
        )

        casted_x = tex.grouped_quantize(
            x,
            quantizer_set.x,
            group_sizes,
            flatten_axis=flatten_axis_x,
        )

        ctx_kernel_usage = TensorUsage.RHS_TRANS
        if kernel_fsdp_enabled:
            # Only plain tensor-scaling modes are supported for the FSDP path.
            assert quantizer_set.kernel.scaling_mode in [
                ScalingMode.CURRENT_TENSOR_SCALING,
                ScalingMode.DELAYED_TENSOR_SCALING,
            ]
            # Perform `cast` only
            ctx_kernel_usage = TensorUsage.LHS
            quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE

        casted_kernel = tex.grouped_quantize(
            kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k
        )
        contracting_dims = (x_contracting_dims, k_contracting_dims)

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
        # rowwise_casted_x.original_shape == (M, K)
        # colwise_casted_kernel.original_shape == (G, N, K)
        grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
        ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
        ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage)

        if kernel_fsdp_enabled:
            # Gather the quantized shards into the global kernel, then rebuild
            # ScaledTensor views for the context (rowwise, layout "N") and for
            # the GEMM (colwise/transposed when required by the FP8 backend).
            ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape)
            global_ctx_kernel_data = _all_gather_kernel(
                ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
            )
            kernel_shape = global_ctx_kernel_data.shape

            ctx_kernel = ScaledTensorFactory.create_1x(
                global_ctx_kernel_data.reshape(-1),
                ctx_kernel.scale_inv,
                scaling_mode=ctx_kernel.scaling_mode,
                dq_dtype=ctx_kernel.dq_dtype,
                is_colwise=False,
                data_layout="N",
                flatten_axis=ctx_kernel.flatten_axis,
                group_sizes=ctx_kernel.group_sizes,
                original_shape=kernel_shape,
                group_axis=ctx_kernel.group_axis,
            )

            if is_fp8_gemm_with_all_layouts_supported():
                grouped_gemm_kernel = ctx_kernel
            else:
                grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1)
                grouped_gemm_kernel = ScaledTensorFactory.create_1x(
                    grouped_gemm_kernel_data.reshape(-1),
                    ctx_kernel.scale_inv,
                    scaling_mode=ctx_kernel.scaling_mode,
                    dq_dtype=ctx_kernel.dq_dtype,
                    is_colwise=True,
                    data_layout="T",
                    flatten_axis=ctx_kernel.flatten_axis,
                    group_sizes=ctx_kernel.group_sizes,
                    original_shape=kernel_shape,
                    group_axis=ctx_kernel.group_axis,
                )
        else:
            grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)

        # Reset quantizer_set.kernel.q_layout to align the PyTree as the given one.
        # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled.
        quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout

    output = tex.grouped_gemm(
        grouped_gemm_x,
        grouped_gemm_kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
    )

    # checkpoint() ties ScaledTensors to their quantizers (DelayedScaling state);
    # plain arrays (noop path) are saved as-is.
    ctx = (
        group_sizes,
        ctx_x.checkpoint(quantizer_set.x) if isinstance(ctx_x, ScaledTensor) else ctx_x,
        (
            ctx_kernel.checkpoint(quantizer_set.kernel)
            if isinstance(ctx_kernel, ScaledTensor)
            else ctx_kernel
        ),
        x.shape,
        kernel.shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx
546
547


548
def _grouped_dense_bwd_rule(
    contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad
):
    """Backward pass rule for grouped dense layer transformation.

    Computes dgrad (w.r.t. x) and wgrad (w.r.t. kernel) via grouped GEMMs,
    plus dbias via a grouped reduction when a bias was used.

    Returns:
        Tuple (dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax,
        quantizer_set) matching the diff args of `_grouped_dense`.
    """
    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims

    (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x_shape,
        kernel_shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    if is_noop_quantizer_set:
        # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?)
        # g_contracting_dim = (1, )
        # k_contracting_dim = (2, )
        g_contracting_dim = tuple(
            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
        )
        k_contracting_dim = tuple(
            dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
        )
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = grad
        dgrad_kernel_T = ctx_kernel

        # g_contracting_dim = (0, )
        # x_contracting_dim = (0, )
        g_contracting_dim = x_contracting_dim = tuple(
            range(0, len(x_shape) - len(fwd_x_contracting_dims))
        )
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = grad
    else:
        # Quantize the incoming gradient once; the LHS usage feeds the dgrad
        # GEMM and the RHS usage feeds the wgrad GEMM.
        casted_grad = tex.grouped_quantize(
            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
        )

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
        # extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (1,)
        k_contracting_dim = (2,)
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS)
        dgrad_kernel_T = ctx_kernel

        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
        # after the extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (0,)
        x_contracting_dim = (0,)
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS)

    # dgrad = grad @ kernel^T (grouped)
    dgrad = tex.grouped_gemm(
        dgrad_grad,
        dgrad_kernel_T,
        group_sizes,
        dgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    # wgrad = x^T @ grad (grouped)
    wgrad = tex.grouped_gemm(
        wgrad_x_T,
        wgrad_grad,
        group_sizes,
        wgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    # When the kernel is FSDP-sharded, reduce-scatter wgrad back to the local
    # shard shape saved in kernel_shape.
    kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
    if kernel_fsdp_mesh_axis is not None:
        wgrad = _psum_scatter_kernel(
            wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
        )

    # group_sizes is integer metadata and kernel_amax is a statistics input —
    # neither receives a gradient.
    group_sizes_grad = None
    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None
    dkernel_amax = None

    return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set
641
642
643


_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)