# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Dense layer transformation operations for Transformer Engine in JAX.

This module provides optimized dense layer transformation operations for transformer
architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""

from typing import Tuple, Sequence
from functools import partial
import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .cpp_extensions.quantization import AmaxScope
from .quantize import (
    ScaledTensorFactory,
    ScalingMode,
    QuantizeLayout,
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
    is_fp8_gemm_with_all_layouts_supported,
    TensorUsage,
    get_quantize_config,
)


def _all_gather_kernel(kernel, mesh_axis, axis_idx):
    """All-gather the sharded kernel along `mesh_axis` and merge the gathered shards
    back into dimension `axis_idx`."""
    assert mesh_axis is not None
    assert 0 < axis_idx < len(kernel.shape)

    # TODO(Ming Huang): Add a condition branch for with/without shmap.
    kernel_shape = kernel.shape
    kernel_whole_shape = (*kernel_shape[:axis_idx], -1, *kernel_shape[axis_idx + 1 :])
    global_kernel = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx)
    global_kernel = global_kernel.reshape(*kernel_whole_shape)
    return global_kernel


def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx):
    """Reduce-scatter the full kernel gradient along `mesh_axis` so that each device
    keeps its shard of shape `scattered_kernel_shape`."""
    assert mesh_axis is not None
    assert 0 < axis_idx < len(scattered_kernel_shape)

    # TODO(Ming Huang): Add a condition branch for with/without shmap.
    kernel = kernel.reshape(
        *scattered_kernel_shape[:axis_idx],
        -1,
        scattered_kernel_shape[axis_idx],
        *scattered_kernel_shape[axis_idx + 1 :],
    )
    kernel = jax.lax.psum_scatter(kernel, mesh_axis, scatter_dimension=axis_idx)
    kernel = kernel.reshape(scattered_kernel_shape)
    return kernel


def dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    bias: jnp.ndarray = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    input_axes: Tuple[str, ...] = None,
    kernel_axes: Tuple[str, ...] = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
    using_global_amax_of_x: bool = False,
):
    """Perform dense layer transformation with optional quantization.

    This function implements matrix multiplication with optional bias addition,
    supporting quantization and custom contracting dimensions. It's optimized
    for transformer architectures and supports automatic differentiation.

    Args:
        x: Input tensor
        kernel: Weight matrix for the dense layer transformation
        bias: Optional bias tensor to add after the transformation
        contracting_dims: Tuple of sequences specifying which dimensions to contract
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: QuantizerSet which contains quantizers for different tensor types
        using_global_amax_of_x: Whether to use the global amax for x. Only effective when using current scaling. Default is False.

    Returns:
        Transformed output tensor
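
    Example:
        A minimal usage sketch (shapes are illustrative; the import path assumes this
        module is exposed as ``transformer_engine.jax.dense``)::

            import jax.numpy as jnp
            from transformer_engine.jax.dense import dense

            x = jnp.ones((32, 128), dtype=jnp.bfloat16)        # (batch, hidden_in)
            kernel = jnp.ones((128, 256), dtype=jnp.bfloat16)  # (hidden_in, hidden_out)
            bias = jnp.zeros((256,), dtype=jnp.bfloat16)

            # With the default noop_quantizer_set no quantization is applied, so this is
            # a standard matmul plus bias with output shape (32, 256).
            out = dense(x, kernel, bias=bias)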
    """
    if not get_quantize_config().is_fp8_enabled():
        input_dtype = x.dtype
        kernel = kernel.astype(input_dtype)

    output = _dense(
        x,
        kernel,
        bias,
        contracting_dims,
        input_axes,
        kernel_axes,
        quantizer_set,
        using_global_amax_of_x,
    )
    return output


@partial(
    jax.custom_vjp,
    nondiff_argnums=(
        3,
        4,
        5,
        7,
    ),
)
def _dense(
    x,
    kernel,
    bias,
    contracting_dims,
    input_axes,
    kernel_axes,
    quantizer_set,
    using_global_amax_of_x,
):
    """Internal implementation of dense layer transformation with custom VJP.

    This function implements the core dense layer transformation logic with support
    for custom vector-Jacobian product (VJP) for automatic differentiation.

    Args:
        x: Input tensor
        kernel: Weight matrix
        bias: Optional bias tensor
        contracting_dims: Contracting dimensions specification
        input_axes: Logical axes for sharding the activation input
        kernel_axes: Logical axes for sharding the weight matrix
        quantizer_set: QuantizerSet which contains quantizers for different tensor types
        using_global_amax_of_x: Whether to use the global amax for x. Only effective when using current scaling. Default is False.

    Returns:
        Transformed output tensor
    """
    output, _ = _dense_fwd_rule(
        x,
        kernel,
        bias,
        contracting_dims,
        input_axes,
        kernel_axes,
        quantizer_set,
        using_global_amax_of_x,
    )
    return output


def _dense_fwd_rule(
    x,
    kernel,
    bias,
    contracting_dims,
    input_axes,
    kernel_axes,
    quantizer_set,
    using_global_amax_of_x,
):
    """Forward pass rule for dense layer transformation.

    Returns:
        Tuple of (output, context) for backward pass
    """
    x_contracting_dims, k_contracting_dims = map(
        tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims
    )

    # Check supported input layout
    x_is_transposed = x.ndim - 1 not in x_contracting_dims
    k_is_transposed = kernel.ndim - 1 in k_contracting_dims
    assert (
        not x_is_transposed and not k_is_transposed
    ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."

    flatten_axis_x = -len(x_contracting_dims)
    flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)

    casted_x = tex.quantize(
        x,
        flatten_axis=flatten_axis_x,
        quantizer=quantizer_set.x,
        amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL,
    )
    casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)

    casted_kernel = tex.quantize(
        kernel,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.kernel,
        amax_scope=AmaxScope.FSDP,
    )
    casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)

    # GEMM NN
    use_bias = bias is not None
    output = tex.gemm(
        casted_x.get_tensor(usage=TensorUsage.LHS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(x_contracting_dims, k_contracting_dims),
        bias=bias if not tex.gemm_uses_jax_dot() else None,
        fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
    )

    if use_bias and tex.gemm_uses_jax_dot():
        bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
        output += jnp.reshape(bias, bias_new_shape)

    ctx = (
        casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
        casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
        x.shape,
        kernel.shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx


def _dense_bwd_rule(
    contracting_dims, input_axes, kernel_axes, using_global_amax_of_x, ctx, grad
):  # pylint: disable=unused-argument
    """Backward pass rule for dense layer transformation.

    Returns:
        Tuple of gradients with respect to inputs
    """
    (
        casted_x_lhs,
        casted_kernel_rhs,
        x_shape,
        kernel_shape,
        use_bias,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    fwd_x_contracting_dims, fwd_k_contracting_dims = map(
        tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
    )

    casted_grad, dbias = tex.quantize_dbias(
        grad,
        is_dbias=use_bias,
        flatten_axis=flatten_axis_k,
        quantizer=quantizer_set.dgrad,
        amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP,
    )

    # GEMM NT
    # grad's trailing dims correspond to the kernel's non-contracting dims; the offset
    # accounts for the rank difference between grad and the kernel.
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dim = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
    )
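    # Example (illustrative): for a forward pass with x:(M, K) and kernel:(K, N),
    # grad has shape (M, N), so g_contracting_dim == (1,) and k_contracting_dim == (1,),
    # and dgrad = grad @ kernel^T has shape (M, K).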

    dgrad = tex.gemm(
        casted_grad.get_tensor(usage=TensorUsage.LHS),
        casted_kernel_rhs,
        contracting_dims=(g_contracting_dim, k_contracting_dim),
    )
    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)

    # GEMM TN
    # x_non_contracting_dims
    g_contracting_dim = x_contracting_dim = tuple(
        range(0, len(x_shape) - len(fwd_x_contracting_dims))
    )
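    # Example (illustrative): continuing the x:(M, K), kernel:(K, N) case,
    # g_contracting_dim == x_contracting_dim == (0,), so wgrad = x^T @ grad has shape (K, N).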

    wgrad = tex.gemm(
        casted_x_lhs,
        casted_grad.get_tensor(usage=TensorUsage.RHS),
        contracting_dims=(x_contracting_dim, g_contracting_dim),
    )
    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

    return dgrad, wgrad, dbias, quantizer_set


_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)


def grouped_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    group_sizes: jnp.ndarray,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
    bias: jnp.ndarray = None,
    kernel_amax: jnp.ndarray = None,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    preferred_element_type: jnp.dtype = None,
    group_offset: jnp.array = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
    kernel_fsdp_info: Tuple[str, int] = (None, -1),
):
    """
    Perform grouped dense (linear) layer transformation with optional quantization.

    Args:
        x: Input tensor of shape (M, K)
        kernel: Weight matrix of shape (G, K, N)
        group_sizes: 1D array of shape (G,) specifying the size of each group
        contracting_dims: Tuple of sequences specifying which dimensions to contract
                          (currently only supports ((1,), (1,)))
        bias: Bias tensor of shape (G, N)
        kernel_amax: The amax values of the weight matrix, of shape (G,)
        precision: JAX precision for the GEMM operation
        preferred_element_type: Preferred data type for the output tensor
        group_offset: 1D array containing offsets for each group (not yet implemented)
        quantizer_set: Set of quantizers for FP8 quantization of the input and output
        kernel_fsdp_info: A tuple containing FSDP-related information for a weight matrix
                          represented in the format (str, int). The first element is the
                          FSDP mesh axis, and the second element is the dimension along
                          which the weight is sharded.

    Returns:
        A jnp.ndarray containing the result of the grouped linear operation
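
    Example:
        A minimal usage sketch with illustrative shapes (no quantization, so the default
        noop_quantizer_set is used)::

            G, M, K, N = 4, 64, 128, 256
            x = jnp.ones((M, K), dtype=jnp.bfloat16)
            kernel = jnp.ones((G, K, N), dtype=jnp.bfloat16)
            # Rows of x are assigned to groups contiguously; group sizes must sum to M.
            group_sizes = jnp.full((G,), M // G, dtype=jnp.int32)

            out = grouped_dense(x, kernel, group_sizes)  # shape (M, N)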
    """
    output = _grouped_dense(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
    return output


@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10))
def _grouped_dense(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
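    """Internal implementation of grouped dense layer transformation with custom VJP."""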
    output, _ = _grouped_dense_fwd_rule(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
    return output


def _grouped_dense_fwd_rule(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
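    """Forward pass rule for grouped dense layer transformation.

    Returns:
        Tuple of (output, context) for backward pass
    """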
    use_bias = bias is not None
    is_noop_quantizer_set = quantizer_set == noop_quantizer_set

    kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
    kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None

    if is_noop_quantizer_set:
        grouped_gemm_x = x
        grouped_gemm_kernel = kernel
        ctx_x = x
        ctx_kernel = kernel
        flatten_axis_k = None

        if kernel_fsdp_enabled:
            kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx)
    else:
        original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout

        x_contracting_dims, k_contracting_dims = contracting_dims
        flatten_axis_x = -len(x_contracting_dims)
        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for G axis

        assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
        assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
        # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
            "grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
            "and k_contracting_dims=(1,) for now, "
            f"got {x_contracting_dims=} and {k_contracting_dims=}"
        )

        casted_x = tex.grouped_quantize(
            x,
            quantizer_set.x,
            group_sizes,
            flatten_axis=flatten_axis_x,
        )

        ctx_kernel_usage = TensorUsage.RHS_TRANS
        if kernel_fsdp_enabled:
            assert quantizer_set.kernel.scaling_mode in [
                ScalingMode.CURRENT_TENSOR_SCALING,
                ScalingMode.DELAYED_TENSOR_SCALING,
            ]
            # Perform `cast` only
            ctx_kernel_usage = TensorUsage.LHS
            quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE

        casted_kernel = tex.grouped_quantize(
            kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k
        )
        contracting_dims = (x_contracting_dims, k_contracting_dims)

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
        # rowwise_casted_x.original_shape == (M, K)
        # colwise_casted_kernel.original_shape == (G, N, K)
        grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
        ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
        ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage)

        if kernel_fsdp_enabled:
            ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape)
            global_ctx_kernel_data = _all_gather_kernel(
                ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
            )
            kernel_shape = global_ctx_kernel_data.shape

            ctx_kernel = ScaledTensorFactory.create_1x(
                global_ctx_kernel_data.reshape(-1),
                ctx_kernel.scale_inv,
                scaling_mode=ctx_kernel.scaling_mode,
                dq_dtype=ctx_kernel.dq_dtype,
                is_colwise=False,
                data_layout="N",
                flatten_axis=ctx_kernel.flatten_axis,
                group_sizes=ctx_kernel.group_sizes,
                original_shape=kernel_shape,
                group_axis=ctx_kernel.group_axis,
            )

            if is_fp8_gemm_with_all_layouts_supported():
                grouped_gemm_kernel = ctx_kernel
            else:
                grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1)
                grouped_gemm_kernel = ScaledTensorFactory.create_1x(
                    grouped_gemm_kernel_data.reshape(-1),
                    ctx_kernel.scale_inv,
                    scaling_mode=ctx_kernel.scaling_mode,
                    dq_dtype=ctx_kernel.dq_dtype,
                    is_colwise=True,
                    data_layout="T",
                    flatten_axis=ctx_kernel.flatten_axis,
                    group_sizes=ctx_kernel.group_sizes,
                    original_shape=kernel_shape,
                    group_axis=ctx_kernel.group_axis,
                )
        else:
            grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)

        # Restore quantizer_set.kernel.q_layout so the quantizer_set PyTree matches the
        # one that was passed in. This matters especially when kernel_fsdp_enabled is True
        # and FP8 is enabled.
        quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout

    output = tex.grouped_gemm(
        grouped_gemm_x,
        grouped_gemm_kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
    )

    ctx = (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x.shape,
        kernel.shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx


def _grouped_dense_bwd_rule(
    contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad
):
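    """Backward pass rule for grouped dense layer transformation.

    Returns:
        Tuple of gradients with respect to the inputs
    """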
    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims

    (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x_shape,
        kernel_shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    if is_noop_quantizer_set:
        # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?)
        # g_contracting_dim = (1, )
        # k_contracting_dim = (2, )
        g_contracting_dim = tuple(
            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
        )
        k_contracting_dim = tuple(
            dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
        )
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = grad
        dgrad_kernel_T = ctx_kernel

        # g_contracting_dim = (0, )
        # x_contracting_dim = (0, )
        g_contracting_dim = x_contracting_dim = tuple(
            range(0, len(x_shape) - len(fwd_x_contracting_dims))
        )
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = grad
    else:
        casted_grad = tex.grouped_quantize(
            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
        )

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
        # extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (1,)
        k_contracting_dim = (2,)
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS)
        dgrad_kernel_T = ctx_kernel

        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
        # after the extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (0,)
        x_contracting_dim = (0,)
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS)

    dgrad = tex.grouped_gemm(
        dgrad_grad,
        dgrad_kernel_T,
        group_sizes,
        dgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    wgrad = tex.grouped_gemm(
        wgrad_x_T,
        wgrad_grad,
        group_sizes,
        wgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )
    kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
    if kernel_fsdp_mesh_axis is not None:
        wgrad = _psum_scatter_kernel(
            wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
        )

    group_sizes_grad = None
    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None
    dkernel_amax = None

    return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set


_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)