# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


__all__ = ["quantize", "quantize_dbias"]


class DBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
        9,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        del scale_shapes
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(DBiasQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert DBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            scale_shapes=scale_shapes,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        Batching rule for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DBiasQuantizePrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            DBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer)  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DBiasQuantizePrimitive.colwise_scale_inv",
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = DBiasQuantizePrimitive.impl(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
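        """
        te_dbias_quantize_p sharding rule for the Shardy partitioner
        """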
        del out_dtype, scale_dtype, scale_shapes, is_outer, mesh, result_types

        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape), unique_var="i", flatten_axis=flatten_axis
        )

        x_axes = scale_rules.input_spec
        colwise_scale_inv = scale_rules.colwise_rule

        out = x_axes
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
            else:
                colwise_out = x_axes
        else:
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)

        dbias = x_axes[flatten_axis:] if is_dbias else ("l",)
        amax = ("m",)

        return SdyShardingRule(
            (x_axes, ("…1",)),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            **scale_rules.factor_sizes,
        )


register_primitive(DBiasQuantizePrimitive)


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    assert flatten_axis < 0
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(dx.ndim + flatten_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        return x, None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    dq_dtype = dq_dtype or x.dtype

    if not DBiasQuantizePrimitive.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common doesn't support colwise only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )
    scale = jnp.empty((), jnp.float32)

    # TE/common dbias_quantize does not support 1x on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    if quantizer is None:
        if is_dbias:
            return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, None

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = DBiasQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=quantizer.q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)


def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor]:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The axis at which the input is logically flattened to 2D for quantization.
            Defaults to -1.

    Returns:
        A ScaledTensor containing the quantized input tensor.
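
    Example (illustrative sketch; assumes a `quantizer` built elsewhere, e.g. via the
    factory helpers in the quantize module, and an input `x` of shape (..., K)):
        scaled = quantize(x, quantizer, flatten_axis=-1)
        # `scaled` is a ScaledTensor holding the quantized data and its scale-inverse factors.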
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
    )
    return out


def quantize_dbias(
    dz: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The axis at which the input is logically flattened to 2D for quantization.
            Defaults to -1.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
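
    Example (illustrative sketch; assumes a `quantizer` built elsewhere and a gradient
    `dz` of shape (..., K)):
        scaled_dz, dbias = quantize_dbias(dz, quantizer)
        # `scaled_dz` is the quantized gradient; `dbias` is `dz` reduced over its leading dimensions.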
    """
    return _quantize_dbias_impl(
        dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
    )