# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Metadata classes for quantization in JAX.

This module provides classes for managing quantization metadata, including
scale factors and amax history for different tensor types.
"""
from dataclasses import dataclass


__all__ = ["QuantizeMeta", "QuantizeMetaSet"]


class QuantizeMeta:
    """Metadata for quantization parameters.

    The metadata is kept as free-form keyword arguments. The entries listed
    below are the ones documented per recipe; this class does not validate
    which keys are supplied.

    For Delayed Scaling recipe:
        scale: The scaling factor for quantization
        amax_history: History of maximum absolute values

    For NVFP4 recipe with Stochastic Rounding:
        sr_rng_state: The state of the stochastic rounding RNG
    """

    def __init__(self, **kwargs):
        """Store the given keyword arguments as the metadata.

        Args:
            **kwargs: Arbitrary metadata entries, kept as-is.
        """
        self._kwargs = kwargs

    def __repr__(self):
        """Return a constructor-like representation listing the stored entries."""
        entries = ", ".join(f"{key}={value!r}" for key, value in self._kwargs.items())
        return f"{type(self).__name__}({entries})"

    def get_kwargs_dictionary(self):
        """Get the metadata as a dictionary.

        Returns:
            The dictionary of keyword arguments captured at construction.
            This is the stored object itself, not a copy, so mutating it
            also changes the metadata held by this instance.
        """
        return self._kwargs


@dataclass
class QuantizeMetaSet:
    """Bundle of quantization metadata, one entry per tensor type.

    Attributes:
        x: Quantization metadata for input tensors
        kernel: Quantization metadata for kernel tensors
        grad: Quantization metadata for gradient tensors
    """

    # Declaration order fixes the positional order of the generated __init__.
    x: QuantizeMeta  # input tensors
    kernel: QuantizeMeta  # kernel tensors
    grad: QuantizeMeta  # gradient tensors