Commit 0b896bd2 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

don't place quantized models .to(device)

parent 0d8e206c
...@@ -115,6 +115,7 @@ class HFLM(LM): ...@@ -115,6 +115,7 @@ class HFLM(LM):
else torch.device("cpu") else torch.device("cpu")
) )
else: else:
if device != "cuda":
eval_logger.info( eval_logger.info(
f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
) )
...@@ -204,7 +205,12 @@ class HFLM(LM): ...@@ -204,7 +205,12 @@ class HFLM(LM):
self.model.tie_weights() self.model.tie_weights()
if gpus <= 1 and not parallelize: if gpus <= 1 and not parallelize:
# place model onto device, if not using HF Accelerate in any form # place model onto device, if not using HF Accelerate in any form
try:
self.model.to(self.device) self.model.to(self.device)
except ValueError:
eval_logger.info(
"Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer, pretrained if tokenizer is None else tokenizer,
...@@ -246,7 +252,12 @@ class HFLM(LM): ...@@ -246,7 +252,12 @@ class HFLM(LM):
if torch.cuda.is_available() if torch.cuda.is_available()
else torch.device("cpu") else torch.device("cpu")
) )
try:
self.model.to(self.device) self.model.to(self.device)
except ValueError:
eval_logger.info(
"Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
)
else: else:
self._model = accelerator.prepare(self.model) self._model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}") self._device = torch.device(f"cuda:{accelerator.local_process_index}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment