# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


__all__ = ["quantize", "quantize_dbias"]


class DBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
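    # The underlying FFI kernel returns seven buffers: rowwise data, colwise data,
    # rowwise scale_inv, colwise scale_inv, updated amax, dbias, and a workspace.
    # The workspace is dropped again by outer_abstract()/impl() below.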
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
        9,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        del scale_shapes
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode in (
                ScalingMode.DELAYED_TENSOR_SCALING.value,
                ScalingMode.CURRENT_TENSOR_SCALING.value,
            ):
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(DBiasQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert DBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            scale_shapes=scale_shapes,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DBiasQuantizePrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            DBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer)  # Unused.

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Current tensor scaling is not yet supported for multi-GPU partitioning."

        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DBiasQuantizePrimitive.colwise_scale_inv",
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Current tensor scaling is not yet supported for multi-GPU partitioning."

        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )
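
        # sharded_impl below runs the local quantize kernel on each shard, then
        # all-reduces amax (max over all non-PP mesh axes) and dbias (sum over
        # DP/FSDP axes) so every device holds the global values.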

        def sharded_impl(x, scale):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = DBiasQuantizePrimitive.impl(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, scale_shapes, is_outer, mesh, result_types

        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape), unique_var="i", flatten_axis=flatten_axis
        )

        x_axes = scale_rules.input_spec
        colwise_scale_inv = scale_rules.colwise_rule

        out = x_axes
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
            else:
                colwise_out = x_axes
        else:
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)

        dbias = x_axes[flatten_axis:] if is_dbias else ("l",)
        amax = ("m",)

        return SdyShardingRule(
            (x_axes, ("…1",)),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            **scale_rules.factor_sizes,
        )


register_primitive(DBiasQuantizePrimitive)


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    assert flatten_axis < 0
    dtype = dtype or dx.dtype
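    # E.g. for dx of shape (B, S, H) and flatten_axis=-1, the sum below reduces
    # over axes (0, 1) and dbias has shape (H,).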
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(dx.ndim + flatten_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        return x, None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    dq_dtype = dq_dtype or x.dtype
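
    # Fall back to the pure-JAX reference path when the TE custom-call primitive
    # is disabled.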

    if not DBiasQuantizePrimitive.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common doesn't support colwise only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )
    scale = jnp.empty((), jnp.float32)

    # TE/common dbias_quantize does not support 1x on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    if quantizer is None:
        if is_dbias:
            return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, None

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = DBiasQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=quantizer.q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)


def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    flatten_axis: int = -1,
) -> ScaledTensor:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The axis at which the input is flattened to 2D for quantization.
            Defaults to -1.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
    )
    return out
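
# Usage sketch (illustrative only): `my_quantizer` below is assumed to be a
# configured Quantizer instance, e.g. a DelayedScaleQuantizer built elsewhere.
#
#     x = jnp.zeros((16, 128, 1024), dtype=jnp.bfloat16)
#     scaled = quantize(x, quantizer=my_quantizer)
#     # `scaled` is a ScaledTensor holding the FP8 data plus scale-inverse factors.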


def quantize_dbias(
    dz: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The axis at which the input is flattened to 2D for quantization.
            Defaults to -1.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
    )
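

# Usage sketch (illustrative only): `my_quantizer` is assumed to be a configured
# Quantizer instance and `dz` an activation-gradient array.
#
#     scaled_dz, dbias = quantize_dbias(dz, quantizer=my_quantizer)
#     # With the default flatten_axis=-1, `dbias` has shape (K,), K = dz.shape[-1].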