Unverified Commit 9c9fe89f authored by Wang, Yi, committed by GitHub

[run_clm example] add torch_dtype option for model load. (#20971)



* [run_clm example] add torch_dtype option for model load.
  For the BLOOM 175B model, peak inference memory drops by roughly 350 GB, since the BLOOM weights on the model hub are stored in bfloat16.

  Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add other dtypes to the option

* fix style

  Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
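
With this patch, the option is exposed on the example script as `--torch_dtype {auto,bfloat16,float16,float32}`. As a minimal illustration of why loading in the checkpoint's native dtype saves memory, here is a hedged sketch using the public `bigscience/bloom-560m` checkpoint as a small stand-in for the 175B model (the exact saving depends on the checkpoint's stored dtype and size):

```python
import torch
from transformers import AutoModelForCausalLM

# Default behavior: weights are upcast to float32 on load, so a checkpoint
# stored in a 16-bit dtype transiently needs about twice its on-disk size.
model_fp32 = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
print(model_fp32.dtype)  # torch.float32

# With torch_dtype, the weights are loaded directly in the requested dtype
# (or, with "auto", in whatever dtype the checkpoint was saved in). This is
# where the ~350 GB peak-memory saving for BLOOM 175B comes from.
model_bf16 = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m", torch_dtype=torch.bfloat16
)
print(model_bf16.dtype)  # torch.bfloat16
```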
parent e697c912
@@ -30,6 +30,7 @@ from itertools import chain
 from typing import Optional
 
 import datasets
+import torch
 from datasets import load_dataset
 
 import evaluate
@@ -119,6 +120,16 @@ class ModelArguments:
             )
         },
     )
+    torch_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
+                "dtype will be automatically derived from the model's weights."
+            ),
+            "choices": ["auto", "bfloat16", "float16", "float32"],
+        },
+    )
 
     def __post_init__(self):
         if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -374,6 +385,11 @@ def main():
         )
 
     if model_args.model_name_or_path:
+        torch_dtype = (
+            model_args.torch_dtype
+            if model_args.torch_dtype in ["auto", None]
+            else getattr(torch, model_args.torch_dtype)
+        )
         model = AutoModelForCausalLM.from_pretrained(
             model_args.model_name_or_path,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -381,6 +397,7 @@ def main():
             cache_dir=model_args.cache_dir,
             revision=model_args.model_revision,
             use_auth_token=True if model_args.use_auth_token else None,
+            torch_dtype=torch_dtype,
         )
     else:
         model = AutoModelForCausalLM.from_config(config)
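
The resolution step in the diff above maps the string argument onto an actual `torch.dtype` only when a concrete dtype name is given; `"auto"` and `None` are passed through unchanged so that `from_pretrained` can interpret them itself. A standalone sketch of that mapping (the helper name `resolve_torch_dtype` is illustrative, not part of the patch):

```python
from typing import Optional, Union

import torch


def resolve_torch_dtype(name: Optional[str]) -> Union[str, torch.dtype, None]:
    """Mirror the patch's resolution logic: pass "auto"/None through
    unchanged, otherwise look the name up on the torch module."""
    if name in ["auto", None]:
        return name  # from_pretrained handles these values itself
    return getattr(torch, name)  # e.g. "bfloat16" -> torch.bfloat16


assert resolve_torch_dtype(None) is None
assert resolve_torch_dtype("auto") == "auto"
assert resolve_torch_dtype("bfloat16") is torch.bfloat16
assert resolve_torch_dtype("float16") is torch.float16
```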