Commit 23eaa2db authored by haileyschoelkopf

add device_map options for HFLM

parent 8762b07c
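For orientation, a hedged sketch of how the new constructor arguments added in this commit might be used. The checkpoint name, memory budgets, and the `lm_eval.models.huggingface` import path are illustrative assumptions, not part of the commit.

```python
# Hypothetical usage of the new accelerate-backed arguments (values illustrative).
from lm_eval.models.huggingface import HFLM  # module path assumed

lm = HFLM(
    pretrained="EleutherAI/pythia-1.4b",  # example checkpoint, not from the commit
    batch_size=4,
    dtype="float16",                      # forwarded to from_pretrained via utils.get_dtype
    use_accelerate=True,                  # triggers _get_accelerate_args(...)
    device_map_option="auto",             # becomes device_map in from_pretrained
    max_memory_per_gpu="20GiB",           # cap applied to every visible CUDA device
    max_cpu_memory="48GiB",               # optional CPU budget for offloading
    offload_folder="./offload",           # disk offload directory
)
```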
@@ -15,6 +15,32 @@ from lm_eval.api.registry import register_model
 from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 from accelerate import Accelerator
 from typing import Optional, Union
+
+def _get_accelerate_args(
+    device_map_option: Optional[str] = "auto",
+    max_memory_per_gpu: Optional[Union[int, str]] = None,
+    max_cpu_memory: Optional[Union[int, str]] = None,
+    offload_folder: Optional[str] = "./offload",
+) -> dict:
+    """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
+    max_memory = {}
+    if max_memory_per_gpu is not None:
+        max_memory_per_gpu_map = {
+            device_idx: max_memory_per_gpu
+            for device_idx in range(torch.cuda.device_count())
+        }
+        max_memory.update(max_memory_per_gpu_map)
+    if max_cpu_memory is not None:
+        max_memory["cpu"] = max_cpu_memory
+
+    args = {}
+    if max_memory:
+        args["max_memory"] = max_memory
+    args["device_map"] = device_map_option
+    args["offload_folder"] = offload_folder
+    return args
+
 @register_model("hf-auto", "hf", "huggingface")
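As a quick illustration of the helper above, assuming a host where `torch.cuda.device_count()` returns 2 (the import path is a guess; values are illustrative):

```python
from lm_eval.models.huggingface import _get_accelerate_args  # module path assumed

# With two visible GPUs this returns:
# {"max_memory": {0: "20GiB", 1: "20GiB", "cpu": "48GiB"},
#  "device_map": "auto",
#  "offload_folder": "./offload"}
kwargs = _get_accelerate_args(
    device_map_option="auto",
    max_memory_per_gpu="20GiB",
    max_cpu_memory="48GiB",
    offload_folder="./offload",
)
```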
@@ -39,6 +65,13 @@ class HFLM(LM):
         subfolder=None,
         tokenizer=None,
         batch_size=1,
+        dtype: Optional[Union[str, torch.dtype]] = "auto",
+        # arguments used for splitting a model across GPUs naively.
+        use_accelerate: Optional[bool] = False,
+        device_map_option: Optional[str] = "auto",
+        max_memory_per_gpu: Optional[Union[int, str]] = None,
+        max_cpu_memory: Optional[Union[int, str]] = None,
+        offload_folder: Optional[str] = "./offload",
     ):
         super().__init__()
@@ -65,8 +98,20 @@
             self._rank = 0
             self._world_size = 1
-        else:
+        elif not use_accelerate:
             self._device = "cpu"
+        else:
+            self._device = device
+
+        model_kwargs = {}
+        if use_accelerate:
+            model_kwargs = _get_accelerate_args(
+                device_map_option,
+                max_memory_per_gpu,
+                max_cpu_memory,
+                offload_folder,
+            )
+        print(model_kwargs)
 
         # TODO: update this to be less of a hack once subfolder is fixed in HF
         revision = revision + ("/" + subfolder if subfolder is not None else "")
@@ -88,10 +133,15 @@
         ]
 
         self._model = self.AUTO_MODEL_CLASS.from_pretrained(
-            pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
-        ).to(self.device)
+            pretrained,
+            revision=revision,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            **model_kwargs,
+            torch_dtype=utils.get_dtype(dtype),
+        ) # .to(self.device)
         # forever after, access self._model through self.model property
         self.model.eval()
+        # TODO: call self.model.tie_weights() here
 
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
             pretrained if tokenizer is None else tokenizer,
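For context, the kwargs built by `_get_accelerate_args` translate into a plain `transformers` call along these lines. A sketch assuming `accelerate` is installed and a two-GPU `max_memory` map; the checkpoint and memory values are examples, not taken from the commit.

```python
import torch
import transformers

# Roughly what HFLM now does internally when use_accelerate=True (illustrative values).
model = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-1.4b",                             # example checkpoint
    revision="main",
    low_cpu_mem_usage=True,
    device_map="auto",                                    # let accelerate shard the layers
    max_memory={0: "20GiB", 1: "20GiB", "cpu": "48GiB"},  # per-device budgets
    offload_folder="./offload",                           # spill weights to disk if needed
    torch_dtype=torch.float16,
)
model.eval()
```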
@@ -105,8 +155,15 @@
         # multithreading and batching
         self.batch_size_per_gpu = batch_size # todo: adaptive batch size
 
+        # if use_accelerate:
+        #     if "lm_head" in self.model.hf_device_map:
+        #         # `accelerate` can place `lm_head` weights on a different device than
+        #         # the user specified one so we force `self._device` to be the same as
+        #         # `lm_head`'s.
+        #         self._device = self.model.hf_device_map["lm_head"]
+        print(self._device, self.model.hf_device_map)
+
         # multigpu support with accelerate
-        if gpus > 1:
+        if gpus > 1 and not use_accelerate:
             accelerator = Accelerator()
             if gpus > accelerator.num_processes:
                 # TODO: make sure there's still never an edge case where we unintentionally default to CPU
@@ -450,7 +507,7 @@
             multi_logits = F.log_softmax(
                 self._model_call(batched_inps, **call_kwargs), dim=-1
-            ).cpu() # [batch, padding_length (inp or cont), vocab]
+            ) # [batch, padding_length (inp or cont), vocab]
 
             for (cache_key, _, _), logits, inplen, cont_toks in zip(
                 chunk, multi_logits, inplens, cont_toks_list
@@ -470,7 +527,9 @@
                 # Check if per-token argmax is exactly equal to continuation
                 greedy_tokens = logits.argmax(dim=-1)
-                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
+                cont_toks = torch.tensor(
+                    cont_toks, dtype=torch.long, device=self.device
+                ).unsqueeze(
                     0
                 ) # [1, seq]
                 max_equal = (greedy_tokens == cont_toks).all()
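The last two hunks keep the comparison tensors on one device: with `.cpu()` removed, the logits stay wherever `accelerate` placed them, so `cont_toks` is now created on `self.device`. A minimal illustration of why the devices must match (the CPU fallback is only so the snippet runs anywhere):

```python
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"  # fallback so this runs without a GPU
greedy_tokens = torch.tensor([[5, 7, 9]], device=device)
cont_toks = torch.tensor([5, 7, 9], dtype=torch.long, device=device).unsqueeze(0)  # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()  # fine; mixing cpu and cuda tensors here would raise
```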