# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    get_min_device_compute_capability,
    NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
    compute_scale_from_amax,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["quantize", "quantize_dbias"]


class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

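    # The primitive produces seven outputs (see `abstract` below): rowwise_out,
    # colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, and a
    # workspace buffer; slots unused for a given config are (1,)-shaped placeholders.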
    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert BaseDBiasQuantizePrimitive.inner_primitive is not None
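        # The inner primitive allocates scale buffers at padded sizes (abstract()
        # uses is_padded=True for the inner call), so slice them back down to the
        # unpadded shapes that callers expect.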
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        Batching rule for vmap.
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, is_outer)  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
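            # MXFP8 1D block scales are laid out along the data axes, so they
            # shard like the input tensor itself.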
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer

        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = BaseDBiasQuantizePrimitive.impl(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, is_outer, mesh, result_types

        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape),
            unique_var="BaseDBiasQuantizePrimitive_i",
            flatten_axis=flatten_axis,
        )

        x_axes = scale_rules.input_spec
        colwise_scale_inv = scale_rules.colwise_rule

        out = x_axes
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
            else:
                colwise_out = x_axes
        else:
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)

        dbias = x_axes[flatten_axis:] if is_dbias else ("l",)
        amax = ("m",)

        return SdyShardingRule(
            (x_axes, ("…1",)),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            **scale_rules.factor_sizes,
        )



register_primitive(BaseDBiasQuantizePrimitive)


class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    assert flatten_axis < 0
    dtype = dtype or dx.dtype
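    # Sum over all leading axes up to flatten_axis: with flatten_axis=-1 this
    # reduces every axis except the last, giving a bias gradient of shape (K,).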
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(dx.ndim + flatten_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        return x, None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper around the TE quantize/quantize_dbias primitives.
    Returns the quantized FP8 tensor and, when is_dbias is True, the bias gradient.
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    dq_dtype = dq_dtype or x.dtype

    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    if not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common doesn't support colwise only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )
    scale = jnp.empty((), jnp.float32)

    # TE/common dbias_quantize does not support 1x on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
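        # Workaround: run the quantize kernel without the fused dbias, then
        # compute dbias with the pure-JAX reduction below.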
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    if quantizer is None:
        if is_dbias:
            return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, None

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Globally reduce amax across all devices for current scaling so we have a single global scale.
        # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
        # until the tensor is dequantized (e.g. in the GEMM).
        amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
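        # Assumption (see ..quantize for the exact rule): compute_scale_from_amax
        # follows the usual FP8 recipe, roughly scale = q_dtype_max / amax, with
        # clamping for zero or non-finite amax.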

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    # It is faster to use 1x quantization for tensor scaling
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.is_2x2x()
        and is_1x_kernel_supported
    )

    q_layout = quantizer.q_layout
    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

        if q_layout == QuantizeLayout.ROWWISE:
            # Quantizer requires 2x quantization, but we are using 1x quantization
            # for performance reasons, so we need to generate the colwise data in JAX
            if flatten_axis < 0:
                flatten_axis += x.ndim
            colwise_casted_output = jnp.transpose(
                rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
            )

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)


def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor]:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
            Defaults to -1.

    Returns:
        A ScaledTensor containing the quantized input tensor.
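
    Example (a sketch; assumes a quantizer built via
    ``transformer_engine.jax.quantize.QuantizerFactory``, whose exact
    construction arguments are omitted here)::

        quantizer = QuantizerFactory.create(...)
        scaled = quantize(x, quantizer)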
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
    )
    return out


def quantize_dbias(
    dz: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
            Defaults to -1.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
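
    Example (a sketch; reuses the hypothetical quantizer from the ``quantize``
    example above)::

        scaled_dz, dbias = quantize_dbias(dz, quantizer)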
    """
    return _quantize_dbias_impl(
        dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
    )