Unverified Commit 195d7032 authored by Sudhakar Singh's avatar Sudhakar Singh Committed by GitHub
Browse files

Allow downloading of model weights automatically (#1172)



* allow tutorial to download the model weights automatically
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* allow users to provide weight cache directory
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0ee5ccda
......@@ -102,8 +102,11 @@ class TELlamaForCausalLM:
Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
"""
vanilla_model = cls(config).to(kwargs["torch_dtype"])
is_local = os.path.isdir(pretrained_model_name_or_path)
# Before loading the model, set the default dtype for torch
torch.set_default_dtype(kwargs["torch_dtype"])
# Load the vanilla model weights
vanilla_model = cls(config)
subfolder = ""
variant = None
if os.path.isfile(
......@@ -133,7 +136,7 @@ class TELlamaForCausalLM:
else:
raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
resolved_archive_file, _ = get_checkpoint_shard_files(
pretrained_model_name_or_path,
archive_file,
)
......
......@@ -247,15 +247,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages and methods\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
"\n",
......@@ -554,7 +563,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "bdb34b91",
"metadata": {},
"outputs": [
......@@ -573,15 +582,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages and methods\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"bf16\"\n",
"\n",
"\n",
......@@ -653,15 +671,24 @@
"restart_jupyter_notebook()\n",
"\n",
"\n",
"# Import necessary packages and methods\n",
"# Import necessary packages, methods and variables\n",
"from utils import *\n",
"\n",
"\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"# Provide Huggingface Access Token\n",
"hyperparams.hf_access_token = \"\"\n",
"assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"fp8\"\n",
"\n",
"\n",
......
......@@ -25,7 +25,10 @@ from accelerate.utils.dataclasses import FP8RecipeKwargs
class HyperParameters:
def __init__(self):
self.mixed_precision = "bf16"
# self.model_name = "" # <== Add model weight location here
# Set to Meta Llama 2 by default.
self.model_name = "meta-llama/Llama-2-7b-hf"
self.dataset_name = "timdettmers/openassistant-guanaco"
self.dataset_text_field = "text"
self.learning_rate = 1.41e-5
......@@ -35,6 +38,10 @@ class HyperParameters:
self.num_warmup_steps = 5
self.num_training_steps = 10
# This is either provided by the user or it will be set when the
# model weights are downloaded.
self.weights_cache_dir = ""
hyperparams = HyperParameters()
......@@ -76,13 +83,49 @@ def get_dataloaders(accelerator: Accelerator, hyperparams):
return train_dataloader
def ensure_model_is_downloaded(hyperparams):
assert hyperparams.model_name in [
"meta-llama/Meta-Llama-3-8B",
"meta-llama/Llama-2-7b-hf",
], "Only Meta Llama 2 7B and Meta Llama 3 8B models are supported!"
# Login using Huggingface Hub API
from huggingface_hub import login
try:
login(hyperparams.hf_access_token)
except Exception as e:
if "Invalid token passed!" in str(e):
print(
"Please pass a valid HF Access Token! More info at"
" https://huggingface.co/docs/hub/en/security-tokens."
)
else:
print(f"Exception is {e}")
# Download the model if it doesn't exist
from huggingface_hub import snapshot_download
supplied_cache_dir = (
hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None
)
hyperparams.weights_cache_dir = snapshot_download(
repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir
)
print(f"Model cache directory : {hyperparams.weights_cache_dir}")
def init_baseline_model(hyperparams):
# Download and cache the weights
ensure_model_is_downloaded(hyperparams)
# Init the model
config = AutoConfig.from_pretrained(hyperparams.model_name)
config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
# make sure to use flash_attention to do iso comparison with TELlamaModel
config._attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
hyperparams.model_name,
hyperparams.weights_cache_dir,
config=config,
torch_dtype=torch.bfloat16,
)
......@@ -94,13 +137,16 @@ def init_baseline_model(hyperparams):
def init_te_llama_model(hyperparams):
# Download and cache the weights
ensure_model_is_downloaded(hyperparams)
# Init the model
from te_llama import TELlamaForCausalLM
config = AutoConfig.from_pretrained(hyperparams.model_name)
config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
config._attn_implementation = "flash_attention_2"
model = TELlamaForCausalLM.from_pretrained_local(
hyperparams.model_name,
hyperparams.weights_cache_dir,
config=config,
torch_dtype=torch.bfloat16,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment