# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Dense layer transformation operations for Transformer Engine in JAX.

This module provides optimized dense layer transformation operations for transformer
architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""

from typing import Tuple, Sequence
from functools import partial
import warnings
import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .quantize import (
    ScaledTensorFactory,
    ScalingMode,
    QuantizeLayout,
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
    is_fp8_gemm_with_all_layouts_supported,
    TensorUsage,
    get_quantize_config,
)


def _all_gather_kernel(kernel, mesh_axis, axis_idx):
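    """Gather an FSDP-sharded kernel along ``axis_idx`` across ``mesh_axis``.

    The gathered shards are folded back into ``axis_idx``, so the result has the
    kernel's full (unsharded) extent along that dimension.
    """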
    assert mesh_axis is not None
    assert 0 < axis_idx < len(kernel.shape)

    # TODO(Ming Huang): Add a conditional branch for with/without shmap.
    kernel_shape = kernel.shape
    kernel_whole_shape = (*kernel_shape[:axis_idx], -1, *kernel_shape[axis_idx + 1 :])
    global_kernel = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx)
    global_kernel = global_kernel.reshape(*kernel_whole_shape)
    return global_kernel


def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx):
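    """Reduce-scatter a fully-gathered, kernel-shaped tensor back to its FSDP-sharded shape.

    The input is reshaped so that ``axis_idx`` splits into (num_shards, local_dim),
    reduce-scattered across ``mesh_axis`` along the shard dimension, and reshaped to
    ``scattered_kernel_shape``.
    """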
    assert mesh_axis is not None
    assert 0 < axis_idx < len(scattered_kernel_shape)

    # TODO(Ming Huang): Add a conditional branch for with/without shmap.
    kernel = kernel.reshape(
        *scattered_kernel_shape[:axis_idx],
        -1,
        scattered_kernel_shape[axis_idx],
        *scattered_kernel_shape[axis_idx + 1 :],
    )
    kernel = jax.lax.psum_scatter(kernel, mesh_axis, scatter_dimension=axis_idx)
    kernel = kernel.reshape(scattered_kernel_shape)
    return kernel


def dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    bias: jnp.ndarray = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    batch_sequence_transpose: bool = False,
    input_axes: Tuple[str, ...] = None,
    kernel_axes: Tuple[str, ...] = None,
    output_axes: Tuple[str, ...] = None,
    using_global_amax_of_x: bool = False,
    collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set,
    quantizer_set: QuantizerSet = noop_quantizer_set,
):
    """Perform dense layer transformation with optional quantization.

    This function implements matrix multiplication with optional bias addition,
    supporting quantization and custom contracting dimensions. It's optimized
    for transformer architectures and supports automatic differentiation.

    Args:
        x: Input tensor
        kernel: Weight matrix for the dense layer transformation
        bias: Optional bias tensor to add after the transformation
        contracting_dims: Tuple of sequences specifying which dimensions to contract
        batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        output_axes: Logical axes for sharding the output
        using_global_amax_of_x: Indicates whether to use a global amax for x. Only works with current scaling. Default is False.
        collective_op_set: A set of CollectiveOp objects for forward and backward passes.
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
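
    Example:
        A minimal sketch of the expected call pattern; the shapes and dtypes below are
        illustrative only, and quantization is controlled by the optional ``quantizer_set``::

            x = jnp.zeros((32, 128), dtype=jnp.bfloat16)         # (batch, hidden_in)
            kernel = jnp.zeros((128, 256), dtype=jnp.bfloat16)   # (hidden_in, hidden_out)
            bias = jnp.zeros((256,), dtype=jnp.bfloat16)
            # Contract the last dim of x with the first dim of kernel (the default).
            out = dense(x, kernel, bias, contracting_dims=((1,), (0,)))  # -> (32, 256)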
    """
    if batch_sequence_transpose:
        warnings.warn("batch_sequence_transpose is not well tested, use with caution!")

    if not get_quantize_config().is_fp8_enabled():
        input_dtype = x.dtype
        kernel = kernel.astype(input_dtype)

    output = _dense(
        x,
        kernel,
        bias,
        contracting_dims,
        batch_sequence_transpose,
        input_axes,
        kernel_axes,
        output_axes,
        using_global_amax_of_x,
        collective_op_set,
        quantizer_set,
    )
    return output


@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
def _dense(
    x,
    kernel,
    bias,
    contracting_dims,
    batch_sequence_transpose,
    input_axes,
    kernel_axes,
    output_axes,
    using_global_amax_of_x,
    collective_op_set,
    quantizer_set,  # needs to be a diff_arg for DelayedScaling state management
):
    """Internal implementation of dense layer transformation with custom VJP.

    This function implements the core dense layer transformation logic with support
    for custom vector-Jacobian product (VJP) for automatic differentiation.

    Args:
        x: Input tensor
        kernel: Weight matrix
        bias: Optional bias tensor
        contracting_dims: Contracting dimensions specification
        batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        output_axes: Logical axes for sharding the output
        using_global_amax_of_x: Indicates whether to use a global amax for x. Only works with current scaling. Default is False.
        collective_op_set: A set of CollectiveOp objects for forward and backward passes.
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    output, _ = _dense_fwd_rule(
        x,
        kernel,
        bias,
        contracting_dims,
        batch_sequence_transpose,
        input_axes,
        kernel_axes,
        output_axes,
        using_global_amax_of_x,
        collective_op_set,
        quantizer_set,
    )
    return output


def _dense_fwd_rule(
    x,
    kernel,
    bias,
    contracting_dims,
    batch_sequence_transpose,
    input_axes,
    kernel_axes,
    output_axes,
    using_global_amax_of_x,
    collective_op_set,
    quantizer_set,
):
    """Forward pass rule for dense layer transformation.

    Returns:
        Tuple of (output, context) for backward pass
    """
    x_contracting_dims, k_contracting_dims = map(
        tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims
    )

    # Check supported input layout
    x_is_transposed = x.ndim - 1 not in x_contracting_dims
    k_is_transposed = kernel.ndim - 1 in k_contracting_dims
    assert (
        not x_is_transposed and not k_is_transposed
    ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."

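    # The flatten axis marks the boundary between contracting and non-contracting
    # dimensions, i.e. where each tensor is logically folded to 2D for quantization
    # and the GEMM.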
    flatten_axis_x = -len(x_contracting_dims)
    flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)

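    # Quantize the activation input. The amax scope controls how widely the amax is
    # reduced: per-device (LOCAL) by default, or across the tensor-/sequence-parallel
    # mesh (TPSP) when using_global_amax_of_x is set.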
    casted_x = tex.quantize(
        x,
        flatten_axis=flatten_axis_x,
        quantizer=quantizer_set.x,
        amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL,
    )
    casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)

    casted_kernel = tex.quantize(
        kernel,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
    )
    casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

    # GEMM NN
    use_bias = bias is not None
    output = tex.gemm(
        casted_x.get_tensor(usage=TensorUsage.LHS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        transpose_batch_sequence=batch_sequence_transpose,
        bias=bias if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
        collective_op=collective_op_set.forward,
    )
    output = with_sharding_constraint_by_logical_axes(output, output_axes)

    if use_bias and tex.gemm_uses_jax_dot():
        bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
        output += jnp.reshape(bias, bias_new_shape)

    ctx = (
        casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
        x.shape,
        kernel.shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx


def _dense_bwd_rule(
    contracting_dims,
    batch_sequence_transpose,
    input_axes,
    kernel_axes,
    output_axes,
    using_global_amax_of_x,
    collective_op_set,
    ctx,
    grad,
):
    """Backward pass rule for dense layer transformation.

    Returns:
        Tuple of gradients with respect to inputs
    """
    (
        casted_x_lhs,
        casted_kernel_rhs,
        x_shape,
        kernel_shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    ) = ctx
    grad = with_sharding_constraint_by_logical_axes(grad, output_axes)

    fwd_x_contracting_dims, fwd_k_contracting_dims = map(
        tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
    )

    casted_grad, dbias = tex.quantize_dbias(
        grad,
        is_dbias=use_bias,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.dgrad,
        amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL,
    )

    # GEMM NT
    # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
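    # Example (2D case): x=(M, K), kernel=(K, N), grad=(M, N) with the forward contraction
    # on K gives g_contracting_dim=(1,) and k_contracting_dim=(1,), i.e. dgrad = grad @ kernel^T.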
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dim = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
    )

    dgrad = tex.gemm(
        casted_grad.get_tensor(usage=TensorUsage.LHS),
        casted_kernel_rhs,
        contracting_dims=(g_contracting_dim, k_contracting_dim),
        transpose_batch_sequence=batch_sequence_transpose,
        collective_op=collective_op_set.backward,
    )

    # GEMM TN
    # x_non_contracting_dims
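    # Example (2D case): both x and grad contract over their leading (M) dims, so
    # x_contracting_dim = g_contracting_dim = (0,), i.e. wgrad = x^T @ grad with shape (K, N).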
    g_contracting_dim = x_contracting_dim = tuple(
        range(0, len(x_shape) - len(fwd_x_contracting_dims))
    )

    wgrad = tex.gemm(
        casted_x_lhs,
        casted_grad.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(x_contracting_dim, g_contracting_dim),
        transpose_batch_sequence=batch_sequence_transpose,
    )

    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

    return dgrad, wgrad, dbias, quantizer_set


_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)


def grouped_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    group_sizes: jnp.ndarray,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
    bias: jnp.ndarray = None,
    kernel_amax: jnp.ndarray = None,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    preferred_element_type: jnp.dtype = None,
    group_offset: jnp.array = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
    kernel_fsdp_info: Tuple[str, int] = (None, -1),
):
    """
    Perform grouped dense (linear) layer transformation with optional quantization.

    Args:
        x: Input tensor of shape (M, K)
        kernel: Weight matrix of shape (G, K, N)
        group_sizes: 1D array of shape (G,) specifying the size of each group
        contracting_dims: Tuple of sequences specifying which dimensions to contract
                          (currently only supports ((1,), (1,)))
        bias: Bias tensor of shape (G, N)
        kernel_amax: The amax values of the weight matrix, of shape (G,)
        precision: JAX precision for the GEMM operation
        preferred_element_type: Preferred data type for the output tensor
        group_offset: 1D array containing offsets for each group (not yet implemented)
        quantizer_set: Set of quantizers for FP8 quantization of the input and output
        kernel_fsdp_info: A tuple containing FSDP-related information for a weight matrix
                          represented in the format (str, int). The first element is the
                          FSDP mesh axis, and the second element is the dimension along
                          which the weight is sharded.

    Returns:
        A jnp.ndarray containing the result of the grouped linear operation
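
    Example:
        A minimal sketch of the expected call pattern; the sizes below are illustrative only::

            G, M, K, N = 4, 64, 32, 16
            x = jnp.zeros((M, K), dtype=jnp.bfloat16)
            kernel = jnp.zeros((G, K, N), dtype=jnp.bfloat16)
            # Rows of x are assigned to G consecutive groups; the sizes must sum to M.
            group_sizes = jnp.array([16, 16, 16, 16], dtype=jnp.int32)
            out = grouped_dense(x, kernel, group_sizes)  # -> (M, N)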
    """
    output = _grouped_dense(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
    return output


@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10))
def _grouped_dense(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
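    """Internal implementation of grouped dense layer transformation with custom VJP.

    Thin wrapper that dispatches to the forward rule; gradients are provided through
    ``_grouped_dense_fwd_rule`` / ``_grouped_dense_bwd_rule`` via ``jax.custom_vjp``.
    """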
    output, _ = _grouped_dense_fwd_rule(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
    return output


def _grouped_dense_fwd_rule(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
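    """Forward pass rule for grouped dense layer transformation.

    Quantizes x and kernel when a non-noop quantizer_set is given (optionally
    all-gathering an FSDP-sharded kernel), runs the grouped GEMM, and stores the
    tensors needed for the backward pass.

    Returns:
        Tuple of (output, context) for backward pass
    """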
    use_bias = bias is not None
    is_noop_quantizer_set = quantizer_set == noop_quantizer_set

    kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
    kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None

    if is_noop_quantizer_set:
        if kernel_fsdp_enabled:
            kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx)

        grouped_gemm_x = x
        grouped_gemm_kernel = kernel
        ctx_x = x
        ctx_kernel = kernel
        flatten_axis_k = None
    else:
        original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout

        x_contracting_dims, k_contracting_dims = contracting_dims
        flatten_axis_x = -len(x_contracting_dims)
        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for G axis

        assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
        assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
        # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
            "grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
            "and k_contracting_dims=(1,) for now, "
            f"got {x_contracting_dims=} and {k_contracting_dims=}"
        )

        casted_x = tex.grouped_quantize(
            x,
            quantizer_set.x,
            group_sizes,
            flatten_axis=flatten_axis_x,
        )

        ctx_kernel_usage = TensorUsage.RHS_TRANS
        if kernel_fsdp_enabled:
            assert quantizer_set.kernel.scaling_mode in [
                ScalingMode.CURRENT_TENSOR_SCALING,
                ScalingMode.DELAYED_TENSOR_SCALING,
            ]
            # Perform `cast` only
            ctx_kernel_usage = TensorUsage.LHS
            quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE

        casted_kernel = tex.grouped_quantize(
            kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k
        )
        contracting_dims = (x_contracting_dims, k_contracting_dims)

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
        # rowwise_casted_x.original_shape == (M, K)
        # colwise_casted_kernel.original_shape == (G, N, K)
        grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
        ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
        ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage)

        if kernel_fsdp_enabled:
            ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape)
            global_ctx_kernel_data = _all_gather_kernel(
                ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
            )
            kernel_shape = global_ctx_kernel_data.shape

            ctx_kernel = ScaledTensorFactory.create_1x(
                global_ctx_kernel_data.reshape(-1),
                ctx_kernel.scale_inv,
                scaling_mode=ctx_kernel.scaling_mode,
                dq_dtype=ctx_kernel.dq_dtype,
                is_colwise=False,
                data_layout="N",
                flatten_axis=ctx_kernel.flatten_axis,
                group_sizes=ctx_kernel.group_sizes,
                original_shape=kernel_shape,
                group_axis=ctx_kernel.group_axis,
            )

            if is_fp8_gemm_with_all_layouts_supported():
                grouped_gemm_kernel = ctx_kernel
            else:
                grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1)
                grouped_gemm_kernel = ScaledTensorFactory.create_1x(
                    grouped_gemm_kernel_data.reshape(-1),
                    ctx_kernel.scale_inv,
                    scaling_mode=ctx_kernel.scaling_mode,
                    dq_dtype=ctx_kernel.dq_dtype,
                    is_colwise=True,
                    data_layout="T",
                    flatten_axis=ctx_kernel.flatten_axis,
                    group_sizes=ctx_kernel.group_sizes,
                    original_shape=kernel_shape,
                    group_axis=ctx_kernel.group_axis,
                )
        else:
            grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)

        # Reset quantizer_set.kernel.q_layout so the returned PyTree structure matches
        # the one that was passed in. This matters especially when
        # kernel_fsdp_enabled == True and FP8 is enabled.
        quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout

    output = tex.grouped_gemm(
        grouped_gemm_x,
        grouped_gemm_kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
    )

    ctx = (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x.shape,
        kernel.shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx


def _grouped_dense_bwd_rule(
    contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad
):
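    """Backward pass rule for grouped dense layer transformation.

    Returns:
        Tuple of gradients with respect to the differentiable inputs:
        (dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set).
    """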
    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims

    (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x_shape,
        kernel_shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    if is_noop_quantizer_set:
        # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?)
        # g_contracting_dim = (1, )
        # k_contracting_dim = (2, )
        g_contracting_dim = tuple(
            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
        )
        k_contracting_dim = tuple(
            dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
        )
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = grad
        dgrad_kernel_T = ctx_kernel

        # g_contracting_dim = (0, )
        # x_contracting_dim = (0, )
        g_contracting_dim = x_contracting_dim = tuple(
            range(0, len(x_shape) - len(fwd_x_contracting_dims))
        )
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = grad
    else:
        casted_grad = tex.grouped_quantize(
            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
        )

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
        # extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (1,)
        k_contracting_dim = (2,)
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS)
        dgrad_kernel_T = ctx_kernel

        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
        # after the extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (0,)
        x_contracting_dim = (0,)
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS)

    dgrad = tex.grouped_gemm(
        dgrad_grad,
        dgrad_kernel_T,
        group_sizes,
        dgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    wgrad = tex.grouped_gemm(
        wgrad_x_T,
        wgrad_grad,
        group_sizes,
        wgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )
    kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
    if kernel_fsdp_mesh_axis is not None:
        wgrad = _psum_scatter_kernel(
            wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
        )

    group_sizes_grad = None
    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None
    dkernel_amax = None

    return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set


_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)