Commit dd2b2de9 authored by haileyschoelkopf

re-add dtype flag

parent 1fb16673
@@ -12,7 +12,7 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from accelerate import Accelerator
 from itertools import islice
+from typing import Optional, Union
 @register_model("hf-causal")
@@ -23,6 +23,7 @@ class HFLM(LM):
         pretrained="gpt2",
         revision="main",
         low_cpu_mem_usage=None,
+        dtype: Optional[Union[str, torch.dtype]] = "auto",
         subfolder=None,
         tokenizer=None,
         batch_size=1,
@@ -58,10 +59,15 @@ class HFLM(LM):
         revision = revision + ("/" + subfolder if subfolder is not None else "")
         self.model = transformers.AutoModelForCausalLM.from_pretrained(
-            pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
+            pretrained,
+            revision=revision,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            torch_dtype=utils.get_dtype(dtype),
         ).to(self.device)
         self.model.eval()
+        print(self.model.dtype)
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
             pretrained if tokenizer is None else tokenizer,
             revision=revision,
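In effect, the re-added flag lets a caller pick the load precision when constructing the model wrapper, passing either a string or a `torch.dtype`. A hypothetical usage sketch follows; the parameter names mirror the `__init__` signature in the diff above, while the module path is an assumption, not confirmed by this commit:

```python
# Hypothetical usage sketch for the re-added dtype flag.
# `lm_eval.models.huggingface` is an assumed module path; adjust as needed.
import torch
from lm_eval.models.huggingface import HFLM

lm_fp16 = HFLM(pretrained="gpt2", dtype="float16")       # string form
lm_bf16 = HFLM(pretrained="gpt2", dtype=torch.bfloat16)  # torch.dtype form
lm_auto = HFLM(pretrained="gpt2")                         # default "auto"
```

With the default `"auto"`, the value is forwarded to `from_pretrained(torch_dtype=...)` unchanged, so transformers infers the dtype from the checkpoint config rather than the harness picking one.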
@@ -419,3 +419,15 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
 def clear_torch_cache():
     gc.collect()
     torch.cuda.empty_cache()
+
+
+def get_dtype(
+    dtype: Union[str, torch.dtype]
+) -> torch.dtype:
+    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig."""
+    if isinstance(dtype, str) and dtype != "auto":
+        # Convert `str` args to torch dtype: `float16` -> `torch.float16`
+        _torch_dtype = getattr(torch, dtype)
+    else:
+        _torch_dtype = dtype
+    return _torch_dtype
\ No newline at end of file
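A minimal sketch of the helper's behavior, assuming it lands in `lm_eval.utils` (the model code above calls `utils.get_dtype`):

```python
# Sketch of get_dtype semantics; the import path is an assumption based on
# the `utils.get_dtype` call site in the diff above.
import torch
from lm_eval.utils import get_dtype

print(get_dtype("float16"))       # torch.float16, resolved via getattr(torch, "float16")
print(get_dtype(torch.bfloat16))  # torch.bfloat16, returned unchanged
print(get_dtype("auto"))          # "auto", passed through as a plain string
```

Note that `"auto"` is deliberately passed through as a string, slightly loosening the `-> torch.dtype` annotation: `from_pretrained` accepts `torch_dtype="auto"` and resolves the dtype from the checkpoint itself.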