Unverified Commit 8cff2bea authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Extend `dtype` command line flag to `HFLM` (#523)

* allow for hf-causal to take dtype arg

* document this change
parent 4e94af6f
...@@ -46,12 +46,12 @@ python main.py \ ...@@ -46,12 +46,12 @@ python main.py \
--device cuda:0 --device cuda:0
``` ```
Additional arguments can be provided to the model constructor using the `--model_args` flag. Most notably, this supports the common practice of using the `revisions` feature on the Hub to store partially trained checkpoints: Additional arguments can be provided to the model constructor using the `--model_args` flag. Most notably, this supports the common practice of using the `revisions` feature on the Hub to store partially trained checkpoints, and it can also be used to specify the datatype for running a model:
```bash ```bash
python main.py \ python main.py \
--model hf-causal \ --model hf-causal \
--model_args pretrained=EleutherAI/pythia-160m,revision=step100000 \ --model_args pretrained=EleutherAI/pythia-160m,revision=step100000,dtype="float" \
--tasks lambada_openai,hellaswag \ --tasks lambada_openai,hellaswag \
--device cuda:0 --device cuda:0
``` ```
......
import torch import torch
import transformers import transformers
from typing import Optional from typing import Optional, Union
from lm_eval.base import BaseLM from lm_eval.base import BaseLM
def _get_dtype(
dtype: Union[str, torch.dtype]
) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
if isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
_torch_dtype = getattr(torch, dtype)
else:
_torch_dtype = dtype
return _torch_dtype
class HFLM(BaseLM): class HFLM(BaseLM):
def __init__( def __init__(
self, self,
...@@ -16,6 +28,7 @@ class HFLM(BaseLM): ...@@ -16,6 +28,7 @@ class HFLM(BaseLM):
batch_size=1, batch_size=1,
load_in_8bit: Optional[bool] = False, load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False, trust_remote_code: Optional[bool] = False,
dtype: Optional[Union[str, torch.dtype]]="auto",
): ):
super().__init__() super().__init__()
...@@ -46,6 +59,7 @@ class HFLM(BaseLM): ...@@ -46,6 +59,7 @@ class HFLM(BaseLM):
load_in_8bit=load_in_8bit, load_in_8bit=load_in_8bit,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
revision=revision, revision=revision,
torch_dtype=_get_dtype(dtype),
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
).to(self.device) ).to(self.device)
self.gpt2.eval() self.gpt2.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment