# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["quantize", "quantize_dbias"]


class DBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
        9,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        del scale_shapes
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(DBiasQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert DBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            scale_shapes=scale_shapes,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DBiasQuantizePrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            DBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer)  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DBiasQuantizePrimitive.colwise_scale_inv",
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = DBiasQuantizePrimitive.impl(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DBiasQuantizePrimitive)
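
# Note: register_primitive wires the abstract/lowering/impl/batcher/partition
# rules above into the class's inner and outer JAX primitives (following the
# BasePrimitive convention; the exact mechanics live in .base). The outer
# primitive is what `_quantize_dbias_impl` below binds, and its custom partition
# rule routes each shard through `sharded_impl`, adding the cross-device
# amax max-reduction and dbias sum-reduction.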


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    assert flatten_axis < 0
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(dx.ndim + flatten_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)
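
# Illustrative note on the reduction above: with flatten_axis=-1 and dx of shape
# (batch, seqlen, hidden), axis=tuple(range(dx.ndim - 1)) == (0, 1), so the bias
# gradient sums over all leading axes and has shape (hidden,); with
# flatten_axis=-2 and dx of shape (batch, seqlen, h1, h2) it has shape (h1, h2).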


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        return x, None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    dq_dtype = dq_dtype or x.dtype

    if not DBiasQuantizePrimitive.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common doesn't support colwise only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )
    scale = jnp.empty((), jnp.float32)

    # TE/common dbias_quantize does not support 1x on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    if quantizer is None:
        if is_dbias:
            return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, None

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = DBiasQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=quantizer.q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)


def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    flatten_axis: int = -1,
) -> ScaledTensor:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The axis along which the input data can be flattened to 2D for
            quantization. Defaults to -1.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
    )
    return out
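
# A minimal usage sketch (illustrative only): `QuantizerFactory` and the argument
# names below are assumptions about transformer_engine.jax.quantize, not part of
# this module.
#
#     from transformer_engine.jax.quantize import QuantizerFactory
#     quantizer = QuantizerFactory.create(
#         scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
#         q_dtype=jnp.float8_e4m3fn,
#         q_layout=QuantizeLayout.ROWWISE,
#     )
#     scaled = quantize(x, quantizer=quantizer)  # x: (..., K) in fp32/fp16/bf16
#     # `scaled` bundles the quantized data with the inverse scale(s) needed to
#     # dequantize it later.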


def quantize_dbias(
    dz: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The axis along which the input data can be flattened to 2D for
            quantization. Defaults to -1.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
    )
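
# A minimal usage sketch for the fused path (illustrative only; `quantizer` as in
# the `quantize` example above):
#
#     scaled_dz, dbias = quantize_dbias(dz, quantizer=quantizer)
#     # `scaled_dz` is the quantized gradient for the subsequent FP8 GEMM, and
#     # `dbias` (shape (K,)) is dz summed over all leading axes in fp32, then
#     # cast to the dequantization dtype (dz.dtype by default).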