"src/lib/vscode:/vscode.git/clone" did not exist on "c5692808066f0b57cc28d6f80e1e9333e156e753"
Unverified Commit ce2f9fa4 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] HuggingFace login in JAX examples if token is available (#2290)



HF login in JAX examples
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent e90582f2
......@@ -118,3 +118,14 @@ def get_quantization_recipe_from_name_string(name: str):
return recipe.NVFP4BlockScaling()
case _:
raise ValueError(f"Invalid quantization_recipe, got {name}")
def hf_login_if_available():
    """Best-effort login to the Hugging Face Hub.

    Uses whatever credentials ``huggingface_hub.login()`` can find locally
    (e.g. the ``HF_TOKEN`` environment variable or a token cached by
    ``huggingface-cli login``). Never raises: the examples must still run
    against public models/datasets when no token — or no ``huggingface_hub``
    package at all — is available.

    Returns:
        None.
    """
    # The import itself is the "if available" check: a missing optional
    # dependency is the expected quiet path, not an error worth printing.
    try:
        from huggingface_hub import login
    except ImportError:
        return
    # A real login failure (bad/expired token, network issue) is still
    # non-fatal, but worth surfacing so the user knows gated models may fail.
    try:
        login()
    except Exception as e:  # best-effort: never fail the example over auth
        print(f"Hugging Face Hub login skipped: {e}")
......@@ -23,12 +23,14 @@ from common import (
is_bf16_supported,
get_quantization_recipe_from_name_string,
assert_params_sufficiently_sharded,
hf_login_if_available,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
hf_login_if_available()
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
......
......@@ -19,12 +19,17 @@ from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, get_quantization_recipe_from_name_string
from common import (
is_bf16_supported,
get_quantization_recipe_from_name_string,
hf_login_if_available,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
hf_login_if_available()
DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
......
......@@ -27,11 +27,13 @@ from common import (
is_mxfp8_supported,
is_nvfp4_supported,
get_quantization_recipe_from_name_string,
hf_login_if_available,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
hf_login_if_available()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data"
......
......@@ -16,11 +16,16 @@ from datasets import load_dataset
from flax import linen as nn
from flax.training import train_state
from common import is_bf16_supported, get_quantization_recipe_from_name_string
from common import (
is_bf16_supported,
get_quantization_recipe_from_name_string,
hf_login_if_available,
)
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
hf_login_if_available()
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
......
......@@ -22,7 +22,13 @@ from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMo
DIR = str(Path(__file__).resolve().parents[1])
sys.path.append(str(DIR))
from encoder.common import is_bf16_supported, get_quantization_recipe_from_name_string
from encoder.common import (
is_bf16_supported,
get_quantization_recipe_from_name_string,
hf_login_if_available,
)
hf_login_if_available()
IMAGE_H = 28
IMAGE_W = 28
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment