helper.py 33.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Config module for quantization metadata management

This module provides configuration and helper functions for managing quantization metadata
in JAX, including support for different scaling modes and datatypes.
"""
10

11
from abc import ABC, abstractmethod
12
from contextlib import contextmanager
13
from dataclasses import dataclass
14
from enum import Enum
15
import hashlib
16
17
from typing import Optional, Tuple, Dict, Union, Sequence, Type, List
from functools import reduce, lru_cache
Alp Dener's avatar
Alp Dener committed
18
import operator
19
20
21
from importlib.metadata import version as get_pkg_version
import warnings
from packaging.version import Version as PkgVersion
22
23
24
25
26

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

27
from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version
28
29
30
31
32
33
34
35
from transformer_engine.common.recipe import (
    Recipe,
    DelayedScaling,
    Format,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    NVFP4BlockScaling,
)
36
37
38
from transformer_engine.jax.sharding import (
    global_shard_guard,
    MeshResource,
39
    get_num_devices_in_mesh,
40
41
42
43
44
    get_all_mesh_axes,
    with_sharding_constraint,
)

from .metadata import QuantizeMeta
45
from .scaling_modes import ScalingMode
46
from .device_utils import get_device_compute_capability
47

48
__all__ = [
49
    "get_quantize_config",
50
    "get_quantize_config_with_recipe",
51
    "autocast",
52
53
    "fp8_autocast",
    "is_fp8_available",
54
55
56
    "is_scaling_mode_supported",
    "get_supported_scaling_modes",
    "get_supported_quantization_recipes",
57
    "update_collections",
Alp Dener's avatar
Alp Dener committed
58
59
    "apply_padding_to_scale_inv",
    "remove_padding_from_scale_inv",
60
    "NVTE_FP8_COLLECTION_NAME",
61
    "TensorSource",
62
]
63

64
65
_is_scaling_mode_supported = None
_reason_for_no_scaling_mode = ""
66
67
Collection = Union[Dict, FrozenDict]

68
69
NVTE_FP8_COLLECTION_NAME = "fp8_metas"

70

71
72
73
74
75
76
77
78
79
80
@lru_cache(maxsize=None)
def _jax_version_meet_requirement(version: str):
    """
    Helper function checking if required JAX version is available
    """
    jax_version = PkgVersion(get_pkg_version("jax"))
    jax_version_required = PkgVersion(version)
    return jax_version >= jax_version_required


81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
    """Check if delayed scaling FP8 is supported on the given GPU architecture.

    Args:
        gpu_arch: The GPU architecture version

    Returns:
        A tuple of (bool, str) indicating support and any error message
    """
    if gpu_arch < 89:  # pre-ada
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    if get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    if get_cuda_version() < 12010:
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""


def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]:
    """Check if block scaling FP8 is supported on the given GPU architecture.

    Args:
        gpu_arch: The GPU architecture version

    Returns:
        A tuple of (bool, str) indicating support and any error message
    """
    if gpu_arch < 99:  # pre-blackwell
        return False, "Device compute capability 9.9 or higher required for MXFP8 execution."
    if get_cublasLt_version() < 120800:
        return False, "CublasLt version 12.8.0 or higher required for MXFP8 execution."
112
    if get_cuda_version() < 12080:
113
        return False, "Cuda version 12.8 or higher required for MXFP8 execution."
114
    if not _jax_version_meet_requirement("0.5.3"):
115
116
117
118
        return False, "Jax version 0.5.3 or higher required for MXFP8 execution."
    return True, ""


119
120
121
122
123
124
125
126
127
128
129
130
131
132
def _check_fp4_support(gpu_arch) -> Tuple[bool, str]:
    """Check if FP4 is supported for the given GPU architecture."""
    if gpu_arch < 100:  # pre-blackwell
        return False, "Device compute capability 10.0 or higher required for NVFP4 execution."
    if get_cublasLt_version() < 120800:
        return False, "CublasLt version 12.8.0 or higher required for NVFP4 execution."
    if get_cuda_version() < 12080:
        return False, "Cuda version 12.8 or higher required for NVFP4 execution."
    if not _jax_version_meet_requirement("0.5.3"):
        return False, "Jax version 0.5.3 or higher required for NVFP4 execution."
    return True, ""


def _check_scaling_support(scaling_mode: ScalingMode, gpu_id: int) -> Tuple[bool, str]:
133
134
135
136
137
138
139
140
141
142
    """Check if FP8 is supported for the given scaling mode and GPU.

    Args:
        scaling_mode: The scaling mode to check support for
        gpu_id: The ID of the GPU to check

    Returns:
        A tuple of (bool, str) indicating support and any error message
    """
    gpu_arch = get_device_compute_capability(gpu_id)
143
    if scaling_mode.is_tensor_scaling():
144
        return _check_delayed_scaling_fp8_support(gpu_arch)
145
    if scaling_mode.is_mxfp8_scaling:
146
        return _check_block_scaling_fp8_support(gpu_arch)
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
    if scaling_mode.is_nvfp4_scaling:
        return _check_fp4_support(gpu_arch)
    return (True, "")  # NO_SCALING is always supported


def is_scaling_mode_supported(
    scaling_mode=ScalingMode.NO_SCALING,
    gpu_id=None,
) -> Tuple[bool, str]:
    """Check if the given scaling mode is available for the given GPU."""
    if gpu_id is not None:
        return _check_scaling_support(scaling_mode, gpu_id)

    global _is_scaling_mode_supported, _reason_for_no_scaling_mode
    if _is_scaling_mode_supported is None:
        _is_scaling_mode_supported = {}
        _reason_for_no_scaling_mode = {}
    if scaling_mode not in _is_scaling_mode_supported:
        _is_scaling_mode_supported[scaling_mode] = True
        _reason_for_no_scaling_mode[scaling_mode] = ""
        for local_gpu_id in range(len(jax.local_devices())):
            ret, msg = _check_scaling_support(scaling_mode, local_gpu_id)
            if ret is False:
                _is_scaling_mode_supported[scaling_mode] = ret
                _reason_for_no_scaling_mode[scaling_mode] = msg
                return ret, msg
    return _is_scaling_mode_supported[scaling_mode], _reason_for_no_scaling_mode[scaling_mode]
174
175
176


def is_fp8_available(
177
    scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
178
179
180
181
182
183
184
185
186
187
188
    gpu_id=None,
) -> Tuple[bool, str]:
    """Check if FP8 is available for the given scaling mode and GPU.

    Args:
        scaling_mode: The scaling mode to check availability for (default: DELAYED_TENSOR_SCALING)
        gpu_id: Optional GPU ID to check specific device (default: None)

    Returns:
        A tuple of (bool, str) indicating availability and any error message
    """
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
    warnings.warn(
        "is_fp8_available is deprecated. Use is_scaling_mode_supported instead.", DeprecationWarning
    )
    return is_scaling_mode_supported(scaling_mode=scaling_mode, gpu_id=gpu_id)


# TODO(Phuong): make the infrastruture to support NO_SCALING
def get_supported_scaling_modes() -> List[ScalingMode]:
    """Get all supported quantization scaling modes."""
    return [
        scaling_mode
        for scaling_mode in ScalingMode
        if is_scaling_mode_supported(scaling_mode=scaling_mode)[0]
        and scaling_mode != ScalingMode.NO_SCALING
    ]


206
def get_supported_quantization_recipes() -> List[Recipe]:
207
208
209
210
    """Get all supported quantization recipes."""
    # We don't support all the recipes TE/Common supports yet
    # return [get_quantize_config_class(recipe)() for recipe in recipe.Recipe.__subclasses__()]
    all_recipes = [
211
212
213
214
        DelayedScaling(),
        Float8CurrentScaling(),
        MXFP8BlockScaling(),
        NVFP4BlockScaling(),
215
216
217
218
    ]
    return [
        recipe for recipe in all_recipes if get_quantize_config_class(recipe)().is_supported()[0]
    ]
219
220


221
def _format2dtypes(format_: Format):
222
223
224
225
226
227
228
229
    """Convert recipe.Format.dtype to corresponding JAX dtypes.

    Args:
        format_: The FP8 format to convert

    Returns:
        A tuple of (forward_dtype, backward_dtype) for the given format
    """
230
    if format_ == Format.E4M3:
231
        return jnp.float8_e4m3fn, jnp.float8_e4m3fn
232
    if format_ == Format.E5M2:
233
        return jnp.float8_e5m2, jnp.float8_e5m2
234
    if format_ == Format.HYBRID:
235
        return jnp.float8_e4m3fn, jnp.float8_e5m2
236
    if format_ == Format.E2M1:
237
        return jnp.float4_e2m1fn, jnp.float4_e2m1fn
238
239
240
    return jnp.bfloat16, jnp.bfloat16


241
242
243
244
245
246
247
248
249
250
251
class TensorSource(Enum):
    """Enumeration for where a tensor's data comes from."""

    # Input data
    X = 0
    # Model parameters
    KERNEL = 1
    # Gradients in the backward pass
    DGRAD = 2


252
253
254
255
256
257
258
259
260
261
262
263
class AmaxComputeAlgo(Enum):
    """Enumeration for AMAX computation algorithms.

    Attributes:
        MAX: Use maximum value for AMAX computation
        MOST_RECENT: Use most recent value for AMAX computation
    """

    MAX = "max"
    MOST_RECENT = "most_recent"


264
265
@dataclass
class BaseQuantizeConfig(ABC):
266
267
268
269
270
271
272
273
274
275
276
277
278
279
    """Configuration class for quantization settings.

    This class manages global quantization settings including FP8 formats,
    scaling modes, and accumulation settings.

    Attributes:
        INITIALIZED: Whether the config has been initialized
        MARGIN: Margin value for quantization
        COLLECTION_NAME: Name of the collection for quantization metadata
        FWD_DTYPE: Forward pass data type
        BWD_DTYPE: Backward pass data type
        FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass
        FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients
        FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients
280
        INFERENCE_MODE: Whether to enable optimization for inference
281
282
283
284
285
286
        AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling
        AMAX_COMPUTE_ALGO: Algorithm for AMAX computation
    """

    INITIALIZED = False
    MARGIN: float = 0.0
287
    COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME
288
289
    FWD_DTYPE: DType = None
    BWD_DTYPE: DType = None
290
291
292
    FP8_2X_ACC_FPROP: bool = False
    FP8_2X_ACC_DGRAD: bool = False
    FP8_2X_ACC_WGRAD: bool = False
293
    INFERENCE_MODE: bool = False
294
295

    # DelayedScaling
296
    # TODO(Phuong): move these two into DelayedScalingQuantizeConfig
297
298
299
    AMAX_HISTORY_LEN: int = 1024
    AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX

300
    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
301
        """Initialize the quantization configuration from a given recipe.
302
303
304
305
306

        Args:
            fp8_recipe: The FP8 recipe to use for initialization
        """
        self.INITIALIZED = True
307
        self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp8_format)
308
309

    def is_fp8_enabled(self) -> bool:
310
311
312
313
314
        """Check if FP8 quantization is enabled.

        Returns:
            bool: True if quantization is enabled, False otherwise
        """
315
        return self.INITIALIZED
316

317
318
319
    @abstractmethod
    def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
        """Gets the scaling mode for a specific tensor's usage type.
320
321

        Args:
322
323
324
325
326
327
            tensor_source: The usage type for which to get the scaling mode.

        Returns:
            The scaling mode for the specified usage type.
        """

328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
    @abstractmethod
    def get_quantize_flax_meta(
        self,
        module,
        collection_name: str,
        postfix: str,
        tensor_source: TensorSource,
        quantizer_name: str,
    ) -> QuantizeMeta:
        """Get the quantization metadata for a given Flax module.

        Args:
            module: The Flax module to get metadata for
            collection_name: The name of the collection to store metadata in
            postfix: Postfix to append to metadata names
            tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
            quantizer_name: The name of the quantizer within the module
        Returns:
            The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
        """

349
350
351
352
353
354
    def is_supported(self) -> tuple[bool, str]:
        """Check if this QuantizeConfig class is supported on the available devices.

        Returns:
            bool: True if the class is supported, False otherwise
            str: Reason for being unsupported, if applicable.
355
        """
356
357
358
359
360

        x_scaling_mode = self.get_scaling_mode(TensorSource.X)
        kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL)
        grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD)
        for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]:
361
            is_supported, reason = is_scaling_mode_supported(scaling_mode=scaling_mode)
362
363
364
365
366
367
368
369
            if not is_supported:
                return is_supported, reason
        return True, None


class NoOpQuantizeConfig(BaseQuantizeConfig):
    """Configuration class higher-precision non-quantized operation."""

370
    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
371
372
373
374
375
376
377
378
379
380
        """Initialize no-op configuration."""
        raise NotImplementedError(
            "NoOpQuantizeConfig cannot be initialize from a recipe as it represents"
            " higher-precision when no quantized recipe is set."
        )

    def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
        """Gets the scaling mode for a specific tensor's usage type."""
        return ScalingMode.NO_SCALING

381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
    def get_quantize_flax_meta(
        self,
        module,
        collection_name: str,
        postfix: str,
        tensor_source: TensorSource,
        quantizer_name: str,
    ) -> QuantizeMeta:
        """Get the quantization metadata for a given Flax module.

        Args:
            module: The Flax module to get metadata for
            collection_name: The name of the collection to store metadata in
            postfix: Postfix to append to metadata names
            tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
            quantizer_name: The name of the quantizer within the module
        Returns:
            The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
        """
        return QuantizeMeta()

402
403

class DelayedScalingQuantizeConfig(BaseQuantizeConfig):
404
405
406
407
408
409
    """Configuration class for delayed scaling FP8 recipe.

    This class provides specific initialization and finalization for delayed scaling
    FP8 quantization mode.
    """

410
    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
411
412
413
414
415
416
417
418
        """Initialize delayed scaling FP8 configuration.

        Args:
            fp8_recipe: The FP8 recipe to use for initialization

        Raises:
            AssertionError: If recipe parameters are not supported
        """
419
        super().initialize_from_recipe(fp8_recipe)
420
        self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
421

422
423
424
425
426
427
428
429
430
        assert fp8_recipe.amax_compute_algo in [
            "max",
            "most_recent",
        ], "DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX."
        assert (
            fp8_recipe.scaling_factor_compute_algo is None
        ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
        assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."

431
        self.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len
432
433
434
435
        string_to_amax_compute_algo = {
            "max": AmaxComputeAlgo.MAX,
            "most_recent": AmaxComputeAlgo.MOST_RECENT,
        }
436
        self.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo]
437

438
439
        self.FP8_2X_ACC_DGRAD = True
        self.FP8_2X_ACC_WGRAD = True
440

441
442
443
    def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
        """Gets the scaling mode for a specific tensor's usage type."""
        return ScalingMode.DELAYED_TENSOR_SCALING
444

445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
    def get_quantize_flax_meta(
        self,
        module,
        collection_name: str,
        postfix: str,
        tensor_source: TensorSource,
        quantizer_name: str,
    ) -> QuantizeMeta:
        """Get the quantization metadata for a given Flax module.

        Args:
            module: The Flax module to get metadata for
            collection_name: The name of the collection to store metadata in
            postfix: Postfix to append to metadata names
            tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
            quantizer_name: The name of the quantizer within the module
        Returns:
            The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
        """
        scale = module.variable(
            collection_name,
            f"{quantizer_name}{postfix}_scale",
            jnp.ones,
            (1,),
            jnp.float32,
        ).value
        amax_history = module.variable(
            collection_name,
            f"{quantizer_name}{postfix}_amax_history",
            jnp.zeros,
            (self.AMAX_HISTORY_LEN,),
            jnp.float32,
        ).value
        return QuantizeMeta(scale=scale, amax_history=amax_history)

480

481
class CurrentScalingQuantizeConfig(BaseQuantizeConfig):
482
483
484
485
486
487
    """Configuration class for current scaling FP8 recipe.

    This class provides specific initialization and finalization for current scaling
    FP8 quantization mode.
    """

488
    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
489
490
491
492
493
        """Initialize current scaling FP8 configuration.

        Args:
            fp8_recipe: The FP8 recipe to use for initialization
        """
494
495
        super().initialize_from_recipe(fp8_recipe)
        self.AMAX_HISTORY_LEN = 0
496

497
498
499
    def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
        """Gets the scaling mode for a specific tensor's usage type."""
        return ScalingMode.CURRENT_TENSOR_SCALING
500

501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
    def get_quantize_flax_meta(
        self,
        module,
        collection_name: str,
        postfix: str,
        tensor_source: TensorSource,
        quantizer_name: str,
    ) -> QuantizeMeta:
        """Get the quantization metadata for a given Flax module.

        Args:
            module: The Flax module to get metadata for
            collection_name: The name of the collection to store metadata in
            postfix: Postfix to append to metadata names
            tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
            quantizer_name: The name of the quantizer within the module
        Returns:
            The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
        """
        return QuantizeMeta()

522

523
class BlockScalingQuantizeConfig(BaseQuantizeConfig):
524
525
526
527
528
529
    """Configuration class for block scaling FP8 recipe.

    This class provides specific initialization and finalization for block scaling
    FP8 quantization mode.
    """

530
    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
531
532
533
534
535
        """Initialize block scaling FP8 configuration.

        Args:
            fp8_recipe: The FP8 recipe to use for initialization
        """
536
537
538
539
540
541
542
        super().initialize_from_recipe(fp8_recipe)
        self.AMAX_HISTORY_LEN = 0

    def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
        """Gets the scaling mode for a specific tensor's usage type."""
        return ScalingMode.MXFP8_1D_SCALING

543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
    def get_quantize_flax_meta(
        self,
        module,
        collection_name: str,
        postfix: str,
        tensor_source: TensorSource,
        quantizer_name: str,
    ) -> QuantizeMeta:
        """Get the quantization metadata for a given Flax module.

        Args:
            module: The Flax module to get metadata for
            collection_name: The name of the collection to store metadata in
            postfix: Postfix to append to metadata names
            tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
            quantizer_name: The name of the quantizer within the module
        Returns:
            The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
        """
        return QuantizeMeta()


565
@dataclass
566
567
568
569
570
571
class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
    """Configuration class for NVFP4 scaling recipe.

    This class provides specific initialization and finalization for NVFP4 scaling quantization mode.
    """

572
573
574
575
    DISABLE_STOCHASTIC_ROUNDING: bool = False
    DISABLE_RHT: bool = False
    DISABLE_2D_QUANTIZATION: bool = False

576
    def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
577
        """Initialize block scaling NVFP4 configuration.
578
579

        Args:
580
            fp8_recipe: The quantization recipe to use for initialization
581
        """
582
583
        assert isinstance(fp8_recipe, NVFP4BlockScaling)

584
585
586
587
        self.INITIALIZED = True
        self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format)
        self.AMAX_HISTORY_LEN = 0

588
589
590
591
        self.DISABLE_STOCHASTIC_ROUNDING = fp8_recipe.disable_stochastic_rounding
        self.DISABLE_RHT = fp8_recipe.disable_rht
        self.DISABLE_2D_QUANTIZATION = fp8_recipe.disable_2d_quantization

592
593
    def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
        """Gets the scaling mode for a specific tensor's usage type."""
594
        if (not self.DISABLE_2D_QUANTIZATION) and tensor_source == TensorSource.KERNEL:
595
596
597
598
            return ScalingMode.NVFP4_2D_SCALING
        # for x and grad
        return ScalingMode.NVFP4_1D_SCALING

599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
    def _make_rht_quantize_meta(self, q_layout, tensor_source: TensorSource) -> QuantizeMeta:
        """Create the quantization metadata for RHT if applicable."""
        # Imported here to prevent circular import
        from transformer_engine.jax.quantize import QuantizeLayout

        use_rht = self.get_scaling_mode(
            tensor_source
        ) == ScalingMode.NVFP4_1D_SCALING and q_layout in {
            QuantizeLayout.ROWWISE_COLWISE,
            QuantizeLayout.COLWISE,
        }
        if self.DISABLE_RHT:
            use_rht = False
        return QuantizeMeta(use_rht=use_rht)

    def _make_stochastic_rounding_rng_state(
        self, module, tensor_source: TensorSource, quantizer_name: str
    ) -> jnp.ndarray:
        """Create the stochastic rounding rng state if applicable."""
        if self.DISABLE_STOCHASTIC_ROUNDING:
            return QuantizeMeta()

        if tensor_source != TensorSource.DGRAD:
            # Only DGRAD uses stochastic rounding
            return QuantizeMeta()

        sr_jax_rng = module.make_rng("sr_rng")
        # Get a unique key for this quantizer
        # Use hashlib to get a deterministic hash value for quantizer_name
        quantizer_hash = (
            int(hashlib.sha256(quantizer_name.encode("utf-8")).hexdigest(), 16)
            % jnp.iinfo(jnp.int32).max
        )
        sr_jax_rng = jax.jit(jax.random.fold_in)(sr_jax_rng, quantizer_hash)

        # Generate 4 random uint32 values from the JAX PRNG key
        shape = (4,)
        if get_num_devices_in_mesh() > 1:
            shape = (get_num_devices_in_mesh(), 4)
        sr_jax_rng_state = jax.random.randint(
            sr_jax_rng, shape, 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32
        ).view(jnp.uint32)
        sr_jax_rng_state = with_sharding_constraint(
            sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None)
        )
        return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state)

646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
    def get_quantize_flax_meta(
        self,
        module,
        collection_name: str,
        postfix: str,
        tensor_source: TensorSource,
        quantizer_name: str,
    ) -> QuantizeMeta:
        """Get the quantization metadata for a given Flax module.

        Args:
            module: The Flax module to get metadata for
            collection_name: The name of the collection to store metadata in
            postfix: Postfix to append to metadata names
            tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD)
            quantizer_name: The name of the quantizer within the module
        Returns:
            The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
        """
665
666
        # Imported here to prevent circular import
        from transformer_engine.jax.quantize import QuantizeLayout
667

668
669
670
        return QuantizeMeta.merge(
            self._make_rht_quantize_meta(QuantizeLayout.ROWWISE_COLWISE, tensor_source),
            self._make_stochastic_rounding_rng_state(module, tensor_source, quantizer_name),
671
672
        )

673
674
675

_QUANTIZE_CONFIG = NoOpQuantizeConfig()

676

677
def get_quantize_config():
678
    """Global instance of BaseQuantizeConfig set by autocast context."""
679
680
681
682
    return _QUANTIZE_CONFIG


def get_quantize_config_class(
683
    fp8_recipe: Recipe,
684
) -> Type[BaseQuantizeConfig]:
685
    """Get the quantization configuration class based on the FP8 recipe.
686
687
688
689
690
691

    Args:
        fp8_recipe: The FP8 recipe to use for initialization
    Returns:
        The quantization config class corresponding to the given recipe.
    """
692
    if isinstance(fp8_recipe, DelayedScaling):
693
        return DelayedScalingQuantizeConfig
694
    if isinstance(fp8_recipe, MXFP8BlockScaling):
695
        return BlockScalingQuantizeConfig
696
    if isinstance(fp8_recipe, Float8CurrentScaling):
697
        return CurrentScalingQuantizeConfig
698
    if isinstance(fp8_recipe, NVFP4BlockScaling):
699
        return NVFP4ScalingQuantizeConfig
700
    raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}")
701
702


703
def get_quantize_config_with_recipe(fp8_recipe: Recipe):
704
705
706
707
708
709
    """Get the quantization configuration object based on the FP8 recipe."""
    config = get_quantize_config_class(fp8_recipe)()
    config.initialize_from_recipe(fp8_recipe)
    return config


710
@contextmanager
711
def autocast(
712
    enabled: bool = False,
713
    recipe: Optional[Recipe] = None,
714
715
    mesh_resource: Optional[MeshResource] = None,
) -> None:
716
    r"""Context manager for FP8 or FP4 usage.
717

718
    This context manager enables quantization for the duration of its context.
719
720
        .. code-block:: python

721
722
723
724
            mesh_shape = (4, 2)
            dp_mesh_axis_name = 'data_parallel'
            tp_mesh_axis_name = 'tensor_parallel'
            devices = np.asarray(jax.devices()).reshape(*mesh_shape)
725

726
727
            with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
                mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)
728

729
                with autocast(enabled=True, mesh_resource=mesh_resource):
730
731
                    rules = extend_logical_axis_rules(tuple())
                    transformer = TransformerLayer()
732

733
734
                    with partitioning.axis_rules(rules):
                        pjit(transformer.init, ...)(...)
735
736
737
738
739
740
741
742
743
744
745

    .. note::
        We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
        and :attr:`amax_compute_algo` (with value 'max' and 'most_recent') in
        recipe.DelayedScaling currently. Other parameters in recipe.DelayedScaling
        will trigger an assertion.

    Parameters
    ----------
    enabled: bool, default = False
        Whether or not to enable fp8
746
747
    recipe: recipe.DelayedScaling, default = None
            recipe used for low precision quantization.
748
749
750
751
752
    mesh_resource: MeshResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, then no data or tensor parallelism will be used.

    """
753
754
    if recipe is None:
        recipe = DelayedScaling()
755

756
757
758
759
760
    global _QUANTIZE_CONFIG

    old_quantize_config = _QUANTIZE_CONFIG

    _QUANTIZE_CONFIG = NoOpQuantizeConfig()
761
762
763
764

    try:
        with global_shard_guard(mesh_resource):
            if enabled:
765
                _QUANTIZE_CONFIG = get_quantize_config_class(recipe)()
766
767
                is_supported, reason = _QUANTIZE_CONFIG.is_supported()
                assert is_supported, reason
768
                _QUANTIZE_CONFIG.initialize_from_recipe(recipe)
769
770
            yield
    finally:
771
        _QUANTIZE_CONFIG = old_quantize_config
772
773


774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
@contextmanager
def fp8_autocast(
    enabled: bool = False,
    fp8_recipe: Optional[Recipe] = None,
    mesh_resource: Optional[MeshResource] = None,
) -> None:
    """
    .. warning::

       fp8_autocast is deprecated and will be removed in a future release.
       Use autocast(enabled=..., recipe=..., mesh_resource=...) instead.

    """

    warnings.warn(
        "fp8_autocast is deprecated and will be removed in a future release. "
        "Use autocast(enabled=..., recipe=..., mesh_resource=...) instead.",
        category=DeprecationWarning,
        stacklevel=2,
    )

    # Call new implementation.
    with autocast(
        enabled=enabled,
        recipe=fp8_recipe,
        mesh_resource=mesh_resource,
    ):
        yield


804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
def update_collections(new: Collection, original: Collection) -> Collection:
    r"""Update collections with new values while preserving original structure.

    Args:
        new: New collection of values to add/update
        original: Original collection to update

    Returns:
        Updated collection with new values merged with original

    Raises:
        AssertionError: If either collection is not a dict or FrozenDict
    """
    assert isinstance(original, (dict, FrozenDict))
    assert isinstance(new, (dict, FrozenDict))
    frozen_original = FrozenDict(original) if not isinstance(original, FrozenDict) else original
    for key in new:
        if key in frozen_original:
            frozen_original, _ = frozen_original.pop(key)
    new_coll = FrozenDict({**new, **frozen_original})
    if not isinstance(original, FrozenDict):
        new_coll = new_coll.unfreeze()
    return new_coll


Alp Dener's avatar
Alp Dener committed
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
def remove_padding_from_scale_inv(
    scale_inv: jax.Array,
    scaling_mode: ScalingMode,
    data_shape: Sequence[int],
    is_colwise: bool = False,
    flatten_axis: int = -1,
):
    """
    Slice padding out of padded inverse scale factors.

    Args:
        scale_inv: Inverse scale factor.
        data_shape: Shape of the quantized data the inverse scale belongs to.
        scaling_mode: ScalingMode representing the quantization method.
        is_colwise: Whether the data was quantized column-wise.
        flatten_axis: The axis along with the data could be flattened to 2D.

    Returns:
        Inverse scale factor without padding.
    """
    # Get expected unpadded scale shape and check if inverse scale already matches
    unpadded_scale_shape = scaling_mode.get_scale_shape(
        data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
    )
    if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == unpadded_scale_shape:
        return scale_inv

    # Get the padded scale shape and make sure inverse scale matches
    padded_scale_shape = scaling_mode.get_scale_shape(
        data_shape,
        is_colwise=is_colwise,
        is_padded=True,
        flatten_axis=flatten_axis,
    )
    assert scale_inv.shape == padded_scale_shape, (
        f"Padded inverse scale factor has wrong shape, expected {padded_scale_shape} but got "
        f"{scale_inv.shape} instead."
    )

    # Reshape scale inverse to 2D in two stages to preserve the flatten axis
    padded_scale_shape_2d = (
        reduce(operator.mul, padded_scale_shape[:flatten_axis]),
        reduce(operator.mul, padded_scale_shape[flatten_axis:]),
    )
    scale_inv_2d = jnp.reshape(
        jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis:])),
        padded_scale_shape_2d,
    )

    # Slice reshaped 2D scale inverse using collapsed 2D unpadded_scale_shape
    unpadded_scale_shape_2d = (
        reduce(operator.mul, unpadded_scale_shape[:flatten_axis]),
        reduce(operator.mul, unpadded_scale_shape[flatten_axis:]),
    )
    scale_inv_2d_unpadded = jnp.asarray(
        scale_inv_2d[: unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]]
    )

    # Reshape 2D scale inverse back in two stages in order to preserve the flatten axis
    scale_inv_unpadded = jnp.reshape(
        jnp.reshape(
            scale_inv_2d_unpadded,
            (*unpadded_scale_shape[:flatten_axis], scale_inv_2d_unpadded.shape[1]),
        ),
        unpadded_scale_shape,
    )
    return scale_inv_unpadded


def apply_padding_to_scale_inv(
    scale_inv: jax.Array,
    scaling_mode: ScalingMode,
    data_shape: Sequence[int],
    is_colwise: bool = False,
    flatten_axis: int = -1,
):
    """
    Pad the scale inverse with zeros to match the necessary padded shape for this scaling
    mode.

    Args:
        scale_inv: Inverse scale factor.
        data_shape: Shape of the quantized data the inverse scale belongs to.
        scaling_mode: ScalingMode representing the quantization method.
        is_colwise: Whether the data was quantized column-wise.
        flatten_axis: The axis along with the data could be flattened to 2D.

    Returns:
        Padded inverse scale factor.
    """
    # Get the expected padded scale shape and check if inverse scale already matches
    padded_scale_shape = scaling_mode.get_scale_shape(
        data_shape, is_colwise=is_colwise, is_padded=True, flatten_axis=flatten_axis
    )
    if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == padded_scale_shape:
        return scale_inv

    # Get the expected unpadded scale shape and make sure inverse scales match
    unpadded_scale_shape = scaling_mode.get_scale_shape(
        data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
    )
    assert scale_inv.shape == unpadded_scale_shape, (
        f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got "
        f"{scale_inv.shape}."
    )

    # Pad the scales with the lowest representable value (2^-127) and return
    pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape))
    return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127)