Unverified Commit ce2f9fa4 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] HuggingFace login in JAX examples if token is available (#2290)



HF login in JAX examples
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent e90582f2
...@@ -118,3 +118,14 @@ def get_quantization_recipe_from_name_string(name: str): ...@@ -118,3 +118,14 @@ def get_quantization_recipe_from_name_string(name: str):
return recipe.NVFP4BlockScaling() return recipe.NVFP4BlockScaling()
case _: case _:
raise ValueError(f"Invalid quantization_recipe, got {name}") raise ValueError(f"Invalid quantization_recipe, got {name}")
def hf_login_if_available():
    """Best-effort login to the Hugging Face Hub.

    Attempts ``huggingface_hub.login()``, which uses any locally cached
    credentials or the ``HF_TOKEN`` environment variable when present.
    Failures (package not installed, no token, network issues) are
    reported but never raised, so example scripts still run against
    public models without credentials.

    Returns:
        None
    """
    try:
        # Imported lazily so huggingface_hub stays an optional dependency.
        from huggingface_hub import login

        login()
    except Exception as e:  # deliberate best-effort: never block the example
        print(e)
...@@ -23,12 +23,14 @@ from common import ( ...@@ -23,12 +23,14 @@ from common import (
is_bf16_supported, is_bf16_supported,
get_quantization_recipe_from_name_string, get_quantization_recipe_from_name_string,
assert_params_sufficiently_sharded, assert_params_sufficiently_sharded,
hf_login_if_available,
) )
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
hf_login_if_available()
DEVICE_DP_AXIS = "data" DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model" DEVICE_TP_AXIS = "model"
......
...@@ -19,12 +19,17 @@ from flax.training import train_state ...@@ -19,12 +19,17 @@ from flax.training import train_state
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, get_quantization_recipe_from_name_string from common import (
is_bf16_supported,
get_quantization_recipe_from_name_string,
hf_login_if_available,
)
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
hf_login_if_available()
DEVICE_DP_AXIS = "data" DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params" PARAMS_KEY = "params"
......
...@@ -27,11 +27,13 @@ from common import ( ...@@ -27,11 +27,13 @@ from common import (
is_mxfp8_supported, is_mxfp8_supported,
is_nvfp4_supported, is_nvfp4_supported,
get_quantization_recipe_from_name_string, get_quantization_recipe_from_name_string,
hf_login_if_available,
) )
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
hf_login_if_available()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data" DEVICE_DP_AXIS = "data"
......
...@@ -16,11 +16,16 @@ from datasets import load_dataset ...@@ -16,11 +16,16 @@ from datasets import load_dataset
from flax import linen as nn from flax import linen as nn
from flax.training import train_state from flax.training import train_state
from common import is_bf16_supported, get_quantization_recipe_from_name_string from common import (
is_bf16_supported,
get_quantization_recipe_from_name_string,
hf_login_if_available,
)
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
hf_login_if_available()
PARAMS_KEY = "params" PARAMS_KEY = "params"
DROPOUT_KEY = "dropout" DROPOUT_KEY = "dropout"
......
...@@ -22,7 +22,13 @@ from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMo ...@@ -22,7 +22,13 @@ from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMo
DIR = str(Path(__file__).resolve().parents[1]) DIR = str(Path(__file__).resolve().parents[1])
sys.path.append(str(DIR)) sys.path.append(str(DIR))
from encoder.common import is_bf16_supported, get_quantization_recipe_from_name_string from encoder.common import (
is_bf16_supported,
get_quantization_recipe_from_name_string,
hf_login_if_available,
)
hf_login_if_available()
IMAGE_H = 28 IMAGE_H = 28
IMAGE_W = 28 IMAGE_W = 28
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment