quantization.py 22.8 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
5
6
import operator
from functools import reduce
7
from typing import Tuple, Optional
8
from packaging import version
9

10
import jax
11
12
import jax.numpy as jnp
from jax import dtypes
13
from jax.experimental.custom_partitioning import SdyShardingRule
14
from jax.sharding import PartitionSpec
15

16
import transformer_engine_jax
17
18
19
20
21

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
22
    te_dtype_to_jax_dtype,
23
    jax_dtype_to_te_dtype,
24
25
26
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
27
)
28
29
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
30
31
32
33
34
35
36
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
    compute_scale_from_amax,
)
37

38
39
40
41
42
if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

43

44
__all__ = ["quantize", "quantize_dbias"]
45
46


47
class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias.

    Outputs (in order): rowwise quantized data, colwise quantized data,
    rowwise scale-inverse, colwise scale-inverse, updated amax, dbias
    (the inner primitive additionally returns a workspace buffer).
    Outputs that are not requested by ``q_layout``/``is_dbias`` are emitted
    as shape-(1,) placeholders so the output arity stays fixed.
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    # Positions of the static (non-array) arguments of ``impl``.
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract: shape/dtype inference for all outputs.

        ``is_outer`` selects unpadded scale shapes for the outer primitive
        (the inner primitive works with padded scale buffers).
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        # Rowwise output only materializes when a rowwise layout is requested;
        # otherwise emit a (1,) placeholder.
        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        # Inner primitive (is_outer=False) uses padded scale shapes.
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                # Tensor scaling stores the colwise copy physically transposed.
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            # dbias is reduced over all axes before flatten_axis, keeping the
            # trailing (hidden) dimensions.
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            # Query TE/common for the fused dbias+quantize workspace size.
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract: same as ``abstract``
        but drops the trailing workspace output (internal-only).
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules: emit the FFI custom call.
        """
        del out_dtype, scale_dtype, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation: bind the inner primitive, then
        slice the padded scale-inverse buffers down to their logical shapes.
        """
        del is_outer
        assert BaseDBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
        )
        # The inner primitive produced padded scale buffers; trim them to the
        # unpadded shapes expected by callers.
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        # One batch dim per output: rowwise, colwise, scale_inv,
        # colwise_scale_inv, amax, dbias.
        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            # NOTE(review): ``is_outer`` is not forwarded to this bind while
            # every other static arg is — confirm the outer primitive supplies
            # a default for it.
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Derive output shardings from the input shardings (GSPMD rule)."""
        del (out_dtype, result_infos, scale_dtype, is_outer)  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        # Rowwise output is sharded exactly like the input.
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                # Colwise copy is transposed, so transpose its spec too.
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        # dbias keeps only the trailing (hidden) axes of the input spec.
        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        # Scale/amax sharding depends on the scaling mode: delayed tensor
        # scaling follows the scalar scale input; MXFP8 block scales follow x.
        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Per-shard implementation plus cross-device reductions.

        Sharding derivation mirrors ``infer_sharding_from_operands``; the
        returned ``sharded_impl`` runs the local quantize and then
        all-reduces amax (max) and dbias (sum) across the mesh.
        """
        del result_infos, is_outer

        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale):
            # Local (per-shard) quantize via the non-partitioned impl.
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = BaseDBiasQuantizePrimitive.impl(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
            )

            # Delayed scaling needs a globally consistent amax.
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            # dbias is a partial sum per shard; reduce over data-parallel axes.
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        """Sharding rule for the Shardy partitioner (SdyShardingRule)."""
        del out_dtype, scale_dtype, is_outer, mesh, result_types

        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape),
            unique_var="BaseDBiasQuantizePrimitive_i",
            flatten_axis=flatten_axis,
        )

        x_axes = scale_rules.input_spec
        colwise_scale_inv = scale_rules.colwise_rule

        out = x_axes
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
            else:
                colwise_out = x_axes
        else:
            # Placeholder (1,) outputs get fresh single-axis variables.
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)

        dbias = x_axes[flatten_axis:] if is_dbias else ("l",)
        amax = ("m",)

        return SdyShardingRule(
            (x_axes, ("…1",)),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            **scale_rules.factor_sizes,
        )

514

515
516
517
518
519
520
521
522
523
# Register the custom-call primitive with JAX.
register_primitive(BaseDBiasQuantizePrimitive)

class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
524
525


526
527
528
def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
529
530
    if quantizer is None:
        return x
531
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
532
533


534
535
536
def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    assert flatten_axis < 0
    dtype = dtype or dx.dtype
537
    dbias = jnp.sum(
538
539
        dx.astype(jnp.float32),
        axis=tuple(range(dx.ndim + flatten_axis)),
540
541
        keepdims=False,
    )
542
    return dbias.astype(dtype)
543
544
545
546
547
548


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
549
    flatten_axis: int = -1,
550
551
552
):
    if quantizer is None:
        return x, None
553
554
555
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
556
557
558
    )


559
def _quantize_dbias_impl(
    x: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor

    Dispatches between the TE custom-call primitive and pure-JAX fallbacks,
    then packages the primitive outputs into a ScaledTensor. ``dbias`` is
    only meaningful when ``is_dbias`` is True.
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    dq_dtype = dq_dtype or x.dtype

    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    # Fallback 1: the primitive is disabled (e.g. via NVTE_JAX_CUSTOM_CALLS_RE).
    if not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common doesn't support colwise only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # Placeholder scale; overwritten below for current/delayed scaling modes.
    scale = jnp.empty((), jnp.float32)

    # TE/common dbias_quantize does not support 1x on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        # Workaround: quantize without dbias via the primitive, then compute
        # dbias separately in JAX.
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    # No quantizer: pass the input through (optionally with a JAX dbias).
    if quantizer is None:
        if is_dbias:
            return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, None

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Globally reduce amax across all devices for current scaling so we have a single global scale.
        # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
        # until the tensor is dequantized (e.g. in the GEMM).
        amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=quantizer.q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    # Let the quantizer record the newly observed amax (delayed-scaling state).
    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)
668
669
670
671
672


def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor]:
    """Quantize a tensor with the given quantizer.

    Args:
        x: Input tensor of shape (..., K), where K is the hidden size.
        quantizer: Quantizer describing the FP8 quantization of the output.
        flatten_axis: Axis along which the input may be flattened to 2D for
            quantization. Defaults to -1.

    Returns:
        A ScaledTensor holding the quantized data and scaling factors.
    """
    scaled_tensor, _unused_dbias = _quantize_dbias_impl(
        x, quantizer=quantizer, flatten_axis=flatten_axis
    )
    return scaled_tensor


def quantize_dbias(
    dz: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize a gradient tensor and fuse the bias-gradient computation.

    Args:
        dz: Incoming gradient of shape (..., K), where K is the hidden size;
            it is both quantized and reduced into the bias gradient.
        quantizer: Quantizer describing the FP8 quantization of the output.
        is_dbias: Whether to compute the bias gradient. Defaults to True.
        flatten_axis: Axis along which the input may be flattened to 2D for
            quantization. Defaults to -1.

    Returns:
        A tuple of:
        - the ScaledTensor with the quantized data and scaling factors, and
        - the bias gradient of shape (K,) (empty when ``is_dbias`` is False).
    """
    result = _quantize_dbias_impl(
        dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
    )
    return result