Unverified commit dcaca2a6, authored by jberchtold-nvidia and committed by GitHub
Browse files

[JAX] Try to use pre-downloaded dataset artifacts first (#2345)



* Try to use pre-downloaded dataset artifacts first
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Set HF_HUB_OFFLINE to disable any network calls to HF when the
pre-downloaded dataset is available
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent b6020e3b
......@@ -3,6 +3,9 @@
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache
import os
import pathlib
import zipfile
import jax
import jax.numpy
......@@ -120,12 +123,48 @@ def get_quantization_recipe_from_name_string(name: str):
raise ValueError(f"Invalid quantization_recipe, got {name}")
def hf_login_if_available():
"""Login to HF hub if available"""
try:
from huggingface_hub import login
@lru_cache(maxsize=None)
def _get_example_artifacts_dir() -> pathlib.Path:
"""Path to directory with pre-downloaded datasets"""
login()
except Exception as e:
print(e)
pass
# Check environment variable
path = os.getenv("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH")
if path:
return pathlib.Path(path).resolve()
# Fallback to path in root dir
root_dir = pathlib.Path(__file__).resolve().parent.parent.parent
return root_dir / "artifacts" / "examples" / "jax"
def _unpack_cached_dataset(artifacts_dir: pathlib.Path, folder_name: str) -> None:
"""Unpack a cached dataset if available"""
dataset_dir = artifacts_dir / folder_name
if not dataset_dir.exists():
print(f"Cached dataset {folder_name} not found at {dataset_dir}, skipping unpack")
return
# Disable any HF network calls since the dataset is cached locally
os.environ["HF_HUB_OFFLINE"] = "1"
for filename in os.listdir(dataset_dir):
filepath = dataset_dir / filename
if not filename.endswith(".zip"):
continue
print(f"Unpacking cached dataset {folder_name} from {filepath}")
with zipfile.ZipFile(filepath, "r") as zip_ref:
zip_ref.extractall(pathlib.Path.home() / ".cache" / "huggingface")
print(
f"Unpacked cached dataset {folder_name} to"
f" {pathlib.Path.home() / '.cache' / 'huggingface'}"
)
# Cached so repeated calls from multiple test modules unpack only once.
@lru_cache(maxsize=None)
def unpack_cached_datasets_if_available() -> None:
    """Unpack cached datasets if available"""
    base_dir = _get_example_artifacts_dir()
    for dataset_name in ("mnist", "encoder"):
        _unpack_cached_dataset(base_dir, dataset_name)
......@@ -23,14 +23,14 @@ from common import (
is_bf16_supported,
get_quantization_recipe_from_name_string,
assert_params_sufficiently_sharded,
hf_login_if_available,
unpack_cached_datasets_if_available,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
hf_login_if_available()
unpack_cached_datasets_if_available()
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
......
......@@ -22,14 +22,14 @@ from jax.sharding import PartitionSpec, NamedSharding
from common import (
is_bf16_supported,
get_quantization_recipe_from_name_string,
hf_login_if_available,
unpack_cached_datasets_if_available,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
hf_login_if_available()
unpack_cached_datasets_if_available()
DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
......
......@@ -27,13 +27,13 @@ from common import (
is_mxfp8_supported,
is_nvfp4_supported,
get_quantization_recipe_from_name_string,
hf_login_if_available,
unpack_cached_datasets_if_available,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
hf_login_if_available()
unpack_cached_datasets_if_available()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data"
......
......@@ -19,13 +19,13 @@ from flax.training import train_state
from common import (
is_bf16_supported,
get_quantization_recipe_from_name_string,
hf_login_if_available,
unpack_cached_datasets_if_available,
)
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
hf_login_if_available()
unpack_cached_datasets_if_available()
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
......
......@@ -25,10 +25,10 @@ sys.path.append(str(DIR))
from encoder.common import (
is_bf16_supported,
get_quantization_recipe_from_name_string,
hf_login_if_available,
unpack_cached_datasets_if_available,
)
hf_login_if_available()
unpack_cached_datasets_if_available()
IMAGE_H = 28
IMAGE_W = 28
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment