# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache
import os
import pathlib
import zipfile

import jax
import jax.numpy
import transformer_engine
from transformer_engine_jax import get_device_compute_capability
from transformer_engine.common import recipe
import numpy as np


@lru_cache
def is_bf16_supported():
    """Return if BF16 has hardware supported"""
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 80


@lru_cache
def is_fp8_supported():
    """Return if FP8 has hardware supported"""
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 90


@lru_cache
def is_mxfp8_supported():
    """Return if FP8 has hardware supported"""
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 100


@lru_cache
def is_nvfp4_supported():
    """Return if FP8 has hardware supported"""
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 100


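# Illustrative usage (a sketch, not consumed by this module): the capability
# guards above are typically used as pytest skip conditions in the encoder
# tests, e.g.
#
#   import pytest
#
#   @pytest.mark.skipif(not is_fp8_supported(), reason="FP8 needs compute capability 9.0+")
#   def test_encoder_fp8():
#       ...
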
def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=False):
    """Checks whether most params are sharded across sharding axis.

    (Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/315e551e5942b24656a4250dcfca986fb4135b72/MaxText/maxtext_utils.py#L348)

    This function determines whether the majority of parameters are distributed
    across a specified sharding axes with an acceptable tolerance. It compares the
    current distribution to a scenario where all parameters are fully sharded
    across the axes on which the params are sharded e.g. 'tensor' axis.

    Args:
        params: params of the model state
        mesh: mesh constructed from config
        tolerance: float between 0.0 and 1.0 representing the allowed percentage of
        non-sharded parameters.
    """

    def get_product_num_devices_for_weight_sharding(weight_sharding_axes):
        product_num_devices_for_weight_sharding = 1
        for axis in weight_sharding_axes:
            product_num_devices_for_weight_sharding *= mesh.shape.get(axis, 1)
        return product_num_devices_for_weight_sharding

    def assert_leaf_sharding(path, arr):
        # Is the weight sharded? Get the axes on which it is sharded.
        partition_spec = arr.sharding.spec
        weight_sharding_axes = set(partition_spec) - set([None])  # None is not a sharding axis

        # Total number of devices on the axes on which the weight is sharded.
        product_num_devices_for_weight_sharding = get_product_num_devices_for_weight_sharding(
            weight_sharding_axes
        )

        # Params present in one shard (on one device).
        shard = arr.addressable_shards[0]
        params_per_chip = np.prod(shard.data.shape)

        # Total number of params (across all devices).
        total_params = jax.numpy.size(arr)

        # Percentage of params that are unsharded, i.e. how much larger one
        # shard is than it would be if the weight were fully sharded across
        # those axes.
        unsharded_perc = (
            (params_per_chip / (total_params / product_num_devices_for_weight_sharding) - 1) * 100
            if params_per_chip < total_params
            else 100
        )

        if print_info:
            print(
                f"{path}: {unsharded_perc:.2f}% unsharded, unsharded param shape={arr.shape},"
                f" partition spec={partition_spec}"
            )

        # If the weight is sharded on any axis, then the percentage of
        # unsharded params should be less than the tolerance.
        assert (
            product_num_devices_for_weight_sharding == 1 or unsharded_perc < tolerance
        ), f"{path}: {unsharded_perc:.2f}% unsharded"

    jax.tree_util.tree_map_with_path(
        lambda path, arr: assert_leaf_sharding("/".join(str(k) for k in path), arr), params
    )


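# Illustrative usage (a sketch; `init_and_shard_model` is a hypothetical
# helper, and the 2x4 device layout is an assumption):
#
#   from jax.sharding import Mesh
#
#   devices = np.array(jax.devices()).reshape(2, 4)
#   mesh = Mesh(devices, axis_names=("data", "tensor"))
#   with mesh:
#       params = init_and_shard_model()
#   assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=True)
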
def get_quantization_recipe_from_name_string(name: str):
    """Query recipe from a given name string"""
    match name:
        case "DelayedScaling":
            return recipe.DelayedScaling()
        case "MXFP8BlockScaling":
            return recipe.MXFP8BlockScaling()
        case "Float8CurrentScaling":
            return recipe.Float8CurrentScaling()
        case "NVFP4BlockScaling":
            return recipe.NVFP4BlockScaling()
        case _:
            raise ValueError(f"Invalid quantization_recipe, got {name}")


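# Illustrative usage (a sketch): resolve a recipe from a test parameter and
# run under autocast. Treat the exact `fp8_autocast` arguments here as an
# assumption about the transformer_engine.jax API.
#
#   import transformer_engine.jax as te
#
#   quant_recipe = get_quantization_recipe_from_name_string("DelayedScaling")
#   with te.fp8_autocast(enabled=True, fp8_recipe=quant_recipe):
#       ...
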
@lru_cache(maxsize=None)
def _get_example_artifacts_dir() -> pathlib.Path:
    """Path to directory with pre-downloaded datasets"""

    # Check environment variable
    path = os.getenv("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH")
    if path:
        return pathlib.Path(path).resolve()

    # Fallback to path in root dir
    root_dir = pathlib.Path(__file__).resolve().parent.parent.parent
    return root_dir / "artifacts" / "examples" / "jax"


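# Illustrative: tests can be pointed at a pre-downloaded artifact location via
# the environment variable checked above (the path is an example):
#
#   NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=/path/to/artifacts pytest .
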
def _unpack_cached_dataset(artifacts_dir: pathlib.Path, folder_name: str) -> None:
    """Unpack a cached dataset if available"""
    dataset_dir = artifacts_dir / folder_name
    if not dataset_dir.exists():
        print(f"Cached dataset {folder_name} not found at {dataset_dir}, skipping unpack")
        return

    # Disable any HF network calls since the dataset is cached locally
    os.environ["HF_HUB_OFFLINE"] = "1"

    for filename in os.listdir(dataset_dir):
        filepath = dataset_dir / filename
        if not filename.endswith(".zip"):
            continue
        print(f"Unpacking cached dataset {folder_name} from {filepath}")

        with zipfile.ZipFile(filepath, "r") as zip_ref:
            zip_ref.extractall(pathlib.Path.home() / ".cache" / "huggingface")
        print(
            f"Unpacked cached dataset {folder_name} to"
            f" {pathlib.Path.home() / '.cache' / 'huggingface'}"
        )


# This is cached so we don't have to unpack datasets multiple times
@lru_cache(maxsize=None)
def unpack_cached_datasets_if_available() -> None:
    """Unpack cached datasets if available"""
    artifacts_dir = _get_example_artifacts_dir()
    _unpack_cached_dataset(artifacts_dir, "mnist")
    _unpack_cached_dataset(artifacts_dir, "encoder")
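
# Illustrative usage (a sketch): call once from a test module's setup hook so
# the cached datasets are unpacked before any Hugging Face dataset load; the
# lru_cache above makes repeated calls from multiple test modules no-ops.
#
#   def setup_module():
#       unpack_cached_datasets_if_available()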