Unverified Commit e3d4901b authored by Fengzhe Zhou, committed by GitHub

[Feat] Add _set_model_kwargs_torch_dtype for HF model (#507)

* add _set_model_kwargs_torch_dtype for hf models

* add logger
parent 6405cd2d
@@ -131,13 +131,28 @@ class HuggingFace(BaseModel):
             self.tokenizer.eos_token = '</s>'
             self.tokenizer.pad_token_id = 0
 
+    def _set_model_kwargs_torch_dtype(self, model_kwargs):
+        if 'torch_dtype' not in model_kwargs:
+            torch_dtype = torch.float16
+        else:
+            torch_dtype = {
+                'torch.float16': torch.float16,
+                'torch.bfloat16': torch.bfloat16,
+                'torch.float': torch.float,
+                'auto': 'auto',
+                'None': None
+            }.get(model_kwargs['torch_dtype'])
+        self.logger.debug(f'HF using torch_dtype: {torch_dtype}')
+        if torch_dtype is not None:
+            model_kwargs['torch_dtype'] = torch_dtype
+
     def _load_model(self,
                     path: str,
                     model_kwargs: dict,
                     peft_path: Optional[str] = None):
         from transformers import AutoModel, AutoModelForCausalLM
-        model_kwargs.setdefault('torch_dtype', torch.float16)
+        self._set_model_kwargs_torch_dtype(model_kwargs)
         try:
             self.model = AutoModelForCausalLM.from_pretrained(
                 path, **model_kwargs)
@@ -409,7 +424,7 @@ class HuggingFaceCausalLM(HuggingFace):
                     peft_path: Optional[str] = None):
         from transformers import AutoModelForCausalLM
-        model_kwargs.setdefault('torch_dtype', torch.float16)
+        self._set_model_kwargs_torch_dtype(model_kwargs)
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
         if peft_path is not None:
             from peft import PeftModel
...
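The new helper only recognizes the literal strings 'torch.float16', 'torch.bfloat16', 'torch.float', 'auto' and 'None': when model_kwargs has no 'torch_dtype' entry it falls back to torch.float16, and any value that maps to None (including an unrecognized string) leaves model_kwargs untouched. A minimal standalone sketch of that resolution behavior, for reference only (the function name resolve_torch_dtype and the example kwargs are illustrative, not part of the patch):

import torch

def resolve_torch_dtype(model_kwargs: dict) -> dict:
    """Illustrative re-implementation of the dtype resolution added in this commit."""
    if 'torch_dtype' not in model_kwargs:
        torch_dtype = torch.float16          # default when nothing is specified
    else:
        torch_dtype = {
            'torch.float16': torch.float16,
            'torch.bfloat16': torch.bfloat16,
            'torch.float': torch.float,
            'auto': 'auto',                  # let transformers choose the dtype
            'None': None,                    # resolves to None, so no override below
        }.get(model_kwargs['torch_dtype'])
    if torch_dtype is not None:
        model_kwargs['torch_dtype'] = torch_dtype
    return model_kwargs

print(resolve_torch_dtype({}))                                 # {'torch_dtype': torch.float16}
print(resolve_torch_dtype({'torch_dtype': 'torch.bfloat16'}))  # {'torch_dtype': torch.bfloat16}
print(resolve_torch_dtype({'torch_dtype': 'auto'}))            # {'torch_dtype': 'auto'}
print(resolve_torch_dtype({'torch_dtype': 'None'}))            # {'torch_dtype': 'None'} (left as-is)

The actual method additionally logs the chosen dtype via self.logger.debug, which is the "add logger" item in the commit message.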
@@ -175,8 +175,14 @@ def parse_hf_args(hf_parser):
     hf_parser.add_argument('--hf-path', type=str)
     hf_parser.add_argument('--peft-path', type=str)
     hf_parser.add_argument('--tokenizer-path', type=str)
-    hf_parser.add_argument('--model-kwargs', nargs='+', action=DictAction)
-    hf_parser.add_argument('--tokenizer-kwargs', nargs='+', action=DictAction)
+    hf_parser.add_argument('--model-kwargs',
+                           nargs='+',
+                           action=DictAction,
+                           default={})
+    hf_parser.add_argument('--tokenizer-kwargs',
+                           nargs='+',
+                           action=DictAction,
+                           default={})
     hf_parser.add_argument('--max-out-len', type=int)
     hf_parser.add_argument('--max-seq-len', type=int)
     hf_parser.add_argument('--no-batch-padding',
...
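The new default={} matters because the model wrappers now hand model_kwargs straight to _set_model_kwargs_torch_dtype and expand it with **model_kwargs; without a default, omitting the flag would leave argparse's None in its place. A minimal sketch of the difference, using a simplified key=value action as a stand-in for DictAction (the real DictAction is not reproduced here):

import argparse

class KVAction(argparse.Action):
    """Simplified stand-in for DictAction: parses 'key=value' pairs into a dict."""

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest,
                dict(pair.split('=', 1) for pair in values))

parser = argparse.ArgumentParser()
parser.add_argument('--model-kwargs', nargs='+', action=KVAction, default={})

print(parser.parse_args([]).model_kwargs)
# {}  <- thanks to default={}; without the default this would be None

print(parser.parse_args(['--model-kwargs', 'torch_dtype=torch.bfloat16']).model_kwargs)
# {'torch_dtype': 'torch.bfloat16'}

With this in place, a flag such as --model-kwargs torch_dtype=torch.bfloat16 arrives as the string 'torch.bfloat16', which _set_model_kwargs_torch_dtype then maps to torch.bfloat16 before the model is loaded.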