Commit dd2b2de9 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

re-add dtype flag

parent 1fb16673
...@@ -12,7 +12,7 @@ from lm_eval.api.model import LM ...@@ -12,7 +12,7 @@ from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from accelerate import Accelerator from accelerate import Accelerator
from itertools import islice from typing import Optional, Union
@register_model("hf-causal") @register_model("hf-causal")
...@@ -23,6 +23,7 @@ class HFLM(LM): ...@@ -23,6 +23,7 @@ class HFLM(LM):
pretrained="gpt2", pretrained="gpt2",
revision="main", revision="main",
low_cpu_mem_usage=None, low_cpu_mem_usage=None,
dtype: Optional[Union[str, torch.dtype]]="auto",
subfolder=None, subfolder=None,
tokenizer=None, tokenizer=None,
batch_size=1, batch_size=1,
...@@ -58,10 +59,15 @@ class HFLM(LM): ...@@ -58,10 +59,15 @@ class HFLM(LM):
revision = revision + ("/" + subfolder if subfolder is not None else "") revision = revision + ("/" + subfolder if subfolder is not None else "")
self.model = transformers.AutoModelForCausalLM.from_pretrained( self.model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage pretrained,
revision=revision,
low_cpu_mem_usage=low_cpu_mem_usage,
torch_dtype=utils.get_dtype(dtype),
).to(self.device) ).to(self.device)
self.model.eval() self.model.eval()
print(self.model.dtype)
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer, pretrained if tokenizer is None else tokenizer,
revision=revision, revision=revision,
......
...@@ -419,3 +419,15 @@ def create_iterator(raw_iterator, rank, world_size, limit=None): ...@@ -419,3 +419,15 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
def clear_torch_cache(): def clear_torch_cache():
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def get_dtype(dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
    """Convert ``dtype`` from ``str`` to ``torch.dtype`` when possible.

    Does not use an instantiated HF AutoConfig. The sentinel string
    ``"auto"`` is passed through unchanged so downstream
    ``from_pretrained(..., torch_dtype="auto")`` keeps its special
    meaning (let HF pick the checkpoint's dtype).

    Args:
        dtype: A ``torch.dtype``, the literal string ``"auto"``, or a
            string naming a torch dtype attribute (e.g. ``"float16"``).

    Returns:
        The corresponding ``torch.dtype``, or ``"auto"`` verbatim.

    Raises:
        ValueError: If ``dtype`` is a string that does not name a
            ``torch.dtype`` (e.g. ``"float17"`` or ``"abs"``).
    """
    if isinstance(dtype, str) and dtype != "auto":
        # Convert str arg to a torch dtype: "float16" -> torch.float16.
        # Validate explicitly: getattr alone would silently return
        # non-dtype torch attributes (e.g. "abs" -> a function).
        _torch_dtype = getattr(torch, dtype, None)
        if not isinstance(_torch_dtype, torch.dtype):
            raise ValueError(f"Unknown dtype string: {dtype!r}")
        return _torch_dtype
    return dtype
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment