Unverified Commit 8cff2bea authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Extend `dtype` command line flag to `HFLM` (#523)

* allow for hf-causal to take dtype arg

* document this change
parent 4e94af6f
......@@ -46,12 +46,12 @@ python main.py \
--device cuda:0
```
Additional arguments can be provided to the model constructor using the `--model_args` flag. Most notably, this supports the common practice of using the `revisions` feature on the Hub to store partially trained checkpoints:
Additional arguments can be provided to the model constructor using the `--model_args` flag. Most notably, this supports the common practice of using the `revisions` feature on the Hub to store partially trained checkpoints, or to specify the datatype for running a model:
```bash
python main.py \
--model hf-causal \
--model_args pretrained=EleutherAI/pythia-160m,revision=step100000 \
--model_args pretrained=EleutherAI/pythia-160m,revision=step100000,dtype="float" \
--tasks lambada_openai,hellaswag \
--device cuda:0
```
......
import torch
import transformers
from typing import Optional
from typing import Optional, Union
from lm_eval.base import BaseLM
def _get_dtype(
dtype: Union[str, torch.dtype]
) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
if isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
_torch_dtype = getattr(torch, dtype)
else:
_torch_dtype = dtype
return _torch_dtype
class HFLM(BaseLM):
def __init__(
self,
......@@ -16,6 +28,7 @@ class HFLM(BaseLM):
batch_size=1,
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
dtype: Optional[Union[str, torch.dtype]]="auto",
):
super().__init__()
......@@ -46,6 +59,7 @@ class HFLM(BaseLM):
load_in_8bit=load_in_8bit,
low_cpu_mem_usage=low_cpu_mem_usage,
revision=revision,
torch_dtype=_get_dtype(dtype),
trust_remote_code=trust_remote_code,
).to(self.device)
self.gpt2.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment