# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache

import jax
import jax.numpy
import transformer_engine
from transformer_engine_jax import get_device_compute_capability
from transformer_engine.common import recipe
import numpy as np


@lru_cache
def is_bf16_supported():
    """Return whether the hardware supports BF16."""
    # BF16 requires compute capability 8.0 (Ampere) or newer.
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 80


@lru_cache
def is_fp8_supported():
    """Return whether the hardware supports FP8."""
    # FP8, as used by these tests, requires compute capability 9.0 (Hopper) or newer.
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 90


@lru_cache
def is_mxfp8_supported():
    """Return whether the hardware supports MXFP8."""
    # MXFP8 requires compute capability 10.0 (Blackwell) or newer.
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 100


@lru_cache
def is_nvfp4_supported():
    """Return whether the hardware supports NVFP4."""
    # NVFP4 requires compute capability 10.0 (Blackwell) or newer.
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 100
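

# Minimal usage sketch (an assumption, not part of the original helpers): the
# capability checks above are typically consumed as skip guards so hardware-
# dependent encoder tests only run on GPUs that support the given format. The
# demo below is hypothetical and is never called by the tests; it just reports
# which formats the local GPU supports.
def _example_capability_report():
    """Hypothetical demo: report which numeric formats this GPU supports."""
    for fmt, check in (
        ("BF16", is_bf16_supported),
        ("FP8", is_fp8_supported),
        ("MXFP8", is_mxfp8_supported),
        ("NVFP4", is_nvfp4_supported),
    ):
        print(f"{fmt} supported: {check()}")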


def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=False):
    """Checks whether most params are sharded across sharding axis.

    (Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/315e551e5942b24656a4250dcfca986fb4135b72/MaxText/maxtext_utils.py#L348)

    This function determines whether the majority of parameters are distributed
    across a specified sharding axes with an acceptable tolerance. It compares the
    current distribution to a scenario where all parameters are fully sharded
    across the axes on which the params are sharded e.g. 'tensor' axis.

    Args:
        params: params of the model state
        mesh: mesh constructed from config
        tolerance: float between 0.0 and 1.0 representing the allowed percentage of
        non-sharded parameters.
    """

    def get_product_num_devices_for_weight_sharding(weight_sharding_axes):
        product_num_devices_for_weight_sharding = 1
        for axis in weight_sharding_axes:
            product_num_devices_for_weight_sharding *= mesh.shape.get(axis, 1)
        return product_num_devices_for_weight_sharding

    def assert_leaf_sharding(path, arr):
        # Is the weight sharded? Get the axes on which it is sharded.
        partition_spec = arr.sharding.spec
        weight_sharding_axes = set(partition_spec) - set([None])  # None is not a sharding axis

        # Total number of devices on the axes on which the weight is sharded.
        product_num_devices_for_weight_sharding = get_product_num_devices_for_weight_sharding(
            weight_sharding_axes
        )

        # Params present in one shard (on one device).
        shard = arr.addressable_shards[0]
        params_per_chip = np.prod(shard.data.shape)

        # Total number of params (across all devices).
        total_params = jax.numpy.size(arr)

        # Percentage by which the per-device shard exceeds the ideal size under
        # full sharding across those axes (0% means fully sharded).
        unsharded_perc = (
            (params_per_chip / (total_params / product_num_devices_for_weight_sharding) - 1) * 100
            if params_per_chip < total_params
            else 100
        )

        if print_info:
            print(
                f"{path}: {unsharded_perc:.2f}% unsharded, unsharded param shape={arr.shape},"
                f" partition spec={partition_spec}"
            )

        # If the weight is sharded on any axis, then the percentage of
        # unsharded params should be less than the tolerance.
        assert (
            product_num_devices_for_weight_sharding == 1 or unsharded_perc < tolerance
        ), f"{path}: {unsharded_perc:.2f}% unsharded"

    jax.tree_util.tree_map_with_path(
        lambda p, x: assert_leaf_sharding("/".join(str(x) for x in p), x), params
    )
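

# A minimal sketch (an assumption, not part of the original helpers) of how the
# check above might be exercised: shard a toy weight over the 'tensor' mesh axis
# and assert that it is sufficiently sharded. The axis names, the (8, 128) shape,
# and the one-entry params pytree are illustrative only; the sharded dimension
# must divide evenly by the number of devices on the 'tensor' axis. This demo is
# never called by the tests.
def _example_sharding_check():
    """Hypothetical demo of assert_params_sufficiently_sharded."""
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    devices = np.array(jax.devices()).reshape(1, -1)  # (data=1, tensor=n)
    mesh = Mesh(devices, axis_names=("data", "tensor"))
    weight = jax.device_put(
        jax.numpy.zeros((8, 128)),
        NamedSharding(mesh, PartitionSpec(None, "tensor")),
    )
    assert_params_sufficiently_sharded({"kernel": weight}, mesh, print_info=True)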


def get_quantization_recipe_from_name_string(name: str):
    """Query recipe from a given name string"""
    match name:
        case "DelayedScaling":
            return recipe.DelayedScaling()
        case "MXFP8BlockScaling":
            return recipe.MXFP8BlockScaling()
        case "Float8CurrentScaling":
            return recipe.Float8CurrentScaling()
        case "NVFP4BlockScaling":
            return recipe.NVFP4BlockScaling()
        case _:
            raise ValueError(f"Invalid quantization_recipe, got {name}")
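

# Minimal usage sketch (an assumption, not part of the original helpers): look
# up a recipe by name and check its type. In the encoder tests such a recipe
# would typically be passed to transformer_engine's autocast context; this demo
# only exercises the plain lookup and is never called by the tests.
def _example_recipe_lookup():
    """Hypothetical demo of get_quantization_recipe_from_name_string."""
    ds = get_quantization_recipe_from_name_string("DelayedScaling")
    assert isinstance(ds, recipe.DelayedScaling)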