# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode


if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


__all__ = ["quantize", "quantize_dbias"]


class DBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
        9,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        del scale_shapes
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(DBiasQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert DBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            scale_shapes=scale_shapes,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DBiasQuantizePrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            DBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer)  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DBiasQuantizePrimitive.colwise_scale_inv",
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = DBiasQuantizePrimitive.impl(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DBiasQuantizePrimitive)


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    assert flatten_axis < 0
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(dx.ndim + flatten_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        return x, None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
521
522
523
524
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    dq_dtype = dq_dtype or x.dtype

    if not DBiasQuantizePrimitive.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common doesn't support colwise only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    scale = jnp.empty((), jnp.float32)

    # TE/common dbias_quantize does not support 1x on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    if quantizer is None:
        if is_dbias:
            return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, None

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = DBiasQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=quantizer.q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)


def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor]:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The axis along which the input is flattened to 2D for quantization.
            Defaults to -1.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
    )
    return out


def quantize_dbias(
    dz: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The axis along which the input is flattened to 2D for quantization.
            Defaults to -1.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
    )
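

# ---------------------------------------------------------------------------
# Illustrative usage sketch of the public entry points above. This is only a
# sketch: it assumes a quantizer instance (e.g. a DelayedScaleQuantizer from
# transformer_engine.jax.quantize) has already been constructed elsewhere, and
# `my_quantizer` and the tensor shape below are placeholders, not part of this
# module.
#
#     dz = jnp.zeros((32, 128, 1024), jnp.bfloat16)      # (..., K) input
#     scaled = quantize(dz, quantizer=my_quantizer)       # ScaledTensor
#     scaled_2x, dbias = quantize_dbias(                  # ScaledTensor2x, (K,) bias grad
#         dz, quantizer=my_quantizer, flatten_axis=-1
#     )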