Unverified Commit 195d7032 authored by Sudhakar Singh's avatar Sudhakar Singh Committed by GitHub
Browse files

Allow downloading of model weights automatically (#1172)



* allow tutorial to download the model weights automatically
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* allow users to provide weight cache directory
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0ee5ccda
...@@ -102,8 +102,11 @@ class TELlamaForCausalLM: ...@@ -102,8 +102,11 @@ class TELlamaForCausalLM:
Custom method adapted from `from_pretrained` method in HuggingFace Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
""" """
vanilla_model = cls(config).to(kwargs["torch_dtype"]) # Before loading the model, set the default dtype for torch
is_local = os.path.isdir(pretrained_model_name_or_path) torch.set_default_dtype(kwargs["torch_dtype"])
# Load the vanilla model weights
vanilla_model = cls(config)
subfolder = "" subfolder = ""
variant = None variant = None
if os.path.isfile( if os.path.isfile(
...@@ -133,7 +136,7 @@ class TELlamaForCausalLM: ...@@ -133,7 +136,7 @@ class TELlamaForCausalLM:
else: else:
raise AssertionError("Only sharded PyTorch ckpt format supported at the moment") raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( resolved_archive_file, _ = get_checkpoint_shard_files(
pretrained_model_name_or_path, pretrained_model_name_or_path,
archive_file, archive_file,
) )
......
...@@ -247,15 +247,24 @@ ...@@ -247,15 +247,24 @@
"restart_jupyter_notebook()\n", "restart_jupyter_notebook()\n",
"\n", "\n",
"\n", "\n",
"# Import necessary packages and methods\n", "# Import necessary packages, methods and variables\n",
"from utils import *\n", "from utils import *\n",
"\n", "\n",
"\n", "\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "# Provide Huggingface Access Token\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n", "hyperparams.hf_access_token = \"\"\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", "\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"bf16\"\n", "hyperparams.mixed_precision = \"bf16\"\n",
"\n", "\n",
"\n", "\n",
...@@ -554,7 +563,7 @@ ...@@ -554,7 +563,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 1,
"id": "bdb34b91", "id": "bdb34b91",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -573,15 +582,24 @@ ...@@ -573,15 +582,24 @@
"restart_jupyter_notebook()\n", "restart_jupyter_notebook()\n",
"\n", "\n",
"\n", "\n",
"# Import necessary packages and methods\n", "# Import necessary packages, methods and variables\n",
"from utils import *\n", "from utils import *\n",
"\n", "\n",
"\n", "\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "# Provide Huggingface Access Token\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n", "hyperparams.hf_access_token = \"\"\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", "\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"bf16\"\n", "hyperparams.mixed_precision = \"bf16\"\n",
"\n", "\n",
"\n", "\n",
...@@ -653,15 +671,24 @@ ...@@ -653,15 +671,24 @@
"restart_jupyter_notebook()\n", "restart_jupyter_notebook()\n",
"\n", "\n",
"\n", "\n",
"# Import necessary packages and methods\n", "# Import necessary packages, methods and variables\n",
"from utils import *\n", "from utils import *\n",
"\n", "\n",
"\n", "\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "# Provide Huggingface Access Token\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n", "hyperparams.hf_access_token = \"\"\n",
"# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n", "assert hyperparams.hf_access_token, \"Provide a HF API Access Token!\"\n",
"# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n", "\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"hyperparams.weights_cache_dir = \"\"\n",
"\n",
"# For Llama 2, uncomment this line (also set by default)\n",
"hyperparams.model_name = \"meta-llama/Llama-2-7b-hf\"\n",
"\n",
"# For Llama 3, uncomment this line\n",
"# hyperparams.model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
"\n",
"hyperparams.mixed_precision = \"fp8\"\n", "hyperparams.mixed_precision = \"fp8\"\n",
"\n", "\n",
"\n", "\n",
......
...@@ -25,7 +25,10 @@ from accelerate.utils.dataclasses import FP8RecipeKwargs ...@@ -25,7 +25,10 @@ from accelerate.utils.dataclasses import FP8RecipeKwargs
class HyperParameters: class HyperParameters:
def __init__(self): def __init__(self):
self.mixed_precision = "bf16" self.mixed_precision = "bf16"
# self.model_name = "" # <== Add model weight location here
# Set to Meta Llama 2 by default.
self.model_name = "meta-llama/Llama-2-7b-hf"
self.dataset_name = "timdettmers/openassistant-guanaco" self.dataset_name = "timdettmers/openassistant-guanaco"
self.dataset_text_field = "text" self.dataset_text_field = "text"
self.learning_rate = 1.41e-5 self.learning_rate = 1.41e-5
...@@ -35,6 +38,10 @@ class HyperParameters: ...@@ -35,6 +38,10 @@ class HyperParameters:
self.num_warmup_steps = 5 self.num_warmup_steps = 5
self.num_training_steps = 10 self.num_training_steps = 10
# This is either provided by the user or it will be set when the
# model weights are downloaded.
self.weights_cache_dir = ""
hyperparams = HyperParameters() hyperparams = HyperParameters()
...@@ -76,13 +83,49 @@ def get_dataloaders(accelerator: Accelerator, hyperparams): ...@@ -76,13 +83,49 @@ def get_dataloaders(accelerator: Accelerator, hyperparams):
return train_dataloader return train_dataloader
def ensure_model_is_downloaded(hyperparams):
assert hyperparams.model_name in [
"meta-llama/Meta-Llama-3-8B",
"meta-llama/Llama-2-7b-hf",
], "Only Meta Llama 2 7B and Meta Llama 3 8B models are supported!"
# Login using Huggingface Hub API
from huggingface_hub import login
try:
login(hyperparams.hf_access_token)
except Exception as e:
if "Invalid token passed!" in str(e):
print(
"Please pass a valid HF Access Token! More info at"
" https://huggingface.co/docs/hub/en/security-tokens."
)
else:
print(f"Exception is {e}")
# Download the model if it doesn't exist
from huggingface_hub import snapshot_download
supplied_cache_dir = (
hyperparams.weights_cache_dir if hyperparams.weights_cache_dir != "" else None
)
hyperparams.weights_cache_dir = snapshot_download(
repo_id=hyperparams.model_name, cache_dir=supplied_cache_dir
)
print(f"Model cache directory : {hyperparams.weights_cache_dir}")
def init_baseline_model(hyperparams): def init_baseline_model(hyperparams):
# Download and cache the weights
ensure_model_is_downloaded(hyperparams)
# Init the model # Init the model
config = AutoConfig.from_pretrained(hyperparams.model_name) config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
# make sure to use flash_attention to do iso comparison with TELlamaModel # make sure to use flash_attention to do iso comparison with TELlamaModel
config._attn_implementation = "flash_attention_2" config._attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
hyperparams.model_name, hyperparams.weights_cache_dir,
config=config, config=config,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
) )
...@@ -94,13 +137,16 @@ def init_baseline_model(hyperparams): ...@@ -94,13 +137,16 @@ def init_baseline_model(hyperparams):
def init_te_llama_model(hyperparams): def init_te_llama_model(hyperparams):
# Download and cache the weights
ensure_model_is_downloaded(hyperparams)
# Init the model # Init the model
from te_llama import TELlamaForCausalLM from te_llama import TELlamaForCausalLM
config = AutoConfig.from_pretrained(hyperparams.model_name) config = AutoConfig.from_pretrained(hyperparams.weights_cache_dir)
config._attn_implementation = "flash_attention_2" config._attn_implementation = "flash_attention_2"
model = TELlamaForCausalLM.from_pretrained_local( model = TELlamaForCausalLM.from_pretrained_local(
hyperparams.model_name, hyperparams.weights_cache_dir,
config=config, config=config,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment