Commit 0b896bd2 authored by haileyschoelkopf

don't place quantized models .to(device)

parent 0d8e206c
@@ -115,9 +115,10 @@ class HFLM(LM):
                     else torch.device("cpu")
                 )
         else:
-            eval_logger.info(
-                f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
-            )
+            if device != "cuda":
+                eval_logger.info(
+                    f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
+                )
         # TODO: include in warning that `load_in_8bit` etc. affect this too
         self._device = device
@@ -204,7 +205,12 @@
         self.model.tie_weights()
         if gpus <= 1 and not parallelize:
             # place model onto device, if not using HF Accelerate in any form
-            self.model.to(self.device)
+            try:
+                self.model.to(self.device)
+            except ValueError:
+                eval_logger.info(
+                    "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
+                )
 
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
             pretrained if tokenizer is None else tokenizer,
@@ -246,7 +252,12 @@
                     if torch.cuda.is_available()
                     else torch.device("cpu")
                 )
-                self.model.to(self.device)
+                try:
+                    self.model.to(self.device)
+                except ValueError:
+                    eval_logger.info(
+                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
+                    )
             else:
                 self._model = accelerator.prepare(self.model)
                 self._device = torch.device(f"cuda:{accelerator.local_process_index}")
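For context: calling `.to(device)` on a model loaded with `load_in_8bit=True` (or other `bitsandbytes` quantization) raises a `ValueError` in `transformers`, because the quantized weights are already dispatched to the GPU at load time and cannot be moved afterwards. Below is a minimal, self-contained sketch of the guard pattern this commit introduces; `place_model` is a hypothetical helper for illustration, not part of the harness:

```python
import logging

import torch

eval_logger = logging.getLogger("lm-eval")


def place_model(model: torch.nn.Module, device: torch.device) -> None:
    """Move `model` to `device`, tolerating quantized models.

    `transformers` raises ValueError when `.to()` is called on a
    bitsandbytes-quantized model, since its weights were already
    dispatched to the GPU during loading. In that case we log and
    continue rather than crash.
    """
    try:
        model.to(device)
    except ValueError:
        eval_logger.info(
            "Failed to place model onto specified device. This may be "
            "because the model is quantized via `bitsandbytes`. If the "
            "desired GPU is being used, this message is safe to ignore."
        )
```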