Commit 2d27f9e1 authored by haileyschoelkopf

add error handling for moving model to device

parent 4c08d72a
@@ -65,9 +65,12 @@ class HFLM(BaseLM):
                 revision=revision,
                 torch_dtype=_get_dtype(dtype),
                 trust_remote_code=trust_remote_code,
-            ).to(self.device)
-            self.gpt2.eval()
+            ).eval()
+            if not load_in_8bit:
+                try:
+                    self.gpt2.to(self.device)
+                except:
+                    print("Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore.")
             self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                 pretrained if tokenizer is None else tokenizer,
                 revision=revision,
...
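For context, this hunk wraps the device move in a guard plus a try/except: models quantized via `bitsandbytes` are placed on device at load time, so calling `.to()` on them can raise. A minimal sketch of the same pattern outside the harness, assuming a standard `transformers` causal LM (the model name, device string, and `load_in_8bit` flag here are illustrative stand-ins for the harness's own arguments, and `except Exception` is used in place of the diff's bare `except`):

```python
import transformers

# Illustrative stand-ins; the harness derives these from its constructor args.
load_in_8bit = False
device = "cuda:0"

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2").eval()

if not load_in_8bit:
    try:
        model.to(device)
    except Exception:
        # bitsandbytes-quantized models manage their own device placement
        # and reject .to(); if the desired GPU is already in use, this
        # failure is benign.
        print(
            "Failed to place model onto specified device. "
            "This may be because the model is quantized via `bitsandbytes`."
        )
```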
@@ -235,8 +235,11 @@ class HuggingFaceAutoLM(BaseLM):
             # the user specified one so we force `self._device` to be the same as
             # `lm_head`'s.
             self._device = self.model.hf_device_map["lm_head"]
-        if not use_accelerate:
-            self.model.to(self._device)
+        if not use_accelerate and not (load_in_4bit or load_in_8bit):
+            try:
+                self.model.to(self._device)
+            except:
+                print("Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore.")

     def _create_auto_model(
         self,
...
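The second hunk additionally skips the manual move when the model was loaded in 4-bit or 8-bit, since quantized loads already dispatch weights to devices via a device map. A hypothetical load showing why no `.to()` call is needed in that case (`load_in_8bit=True` requires `bitsandbytes`, and `device_map="auto"` requires `accelerate`):

```python
import transformers

# 8-bit quantized load: weights are placed on devices at load time,
# so no subsequent model.to(device) call is made (or needed).
model = transformers.AutoModelForCausalLM.from_pretrained(
    "gpt2",
    load_in_8bit=True,   # bitsandbytes quantization
    device_map="auto",   # accelerate handles placement across devices
)
```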