Unverified Commit 78545d42 authored by Pasquale Minervini's avatar Pasquale Minervini Committed by GitHub
Browse files

self.device in huggingface.py line 210 treated as torch.device but might be a string (#1172)



* self.device in huggingface.py line 210

In huggingface.py line 210, self.device is str and does not have a "type" attribute

* Update huggingface.py

This handles both the case where `self.device` is a `torch.device` and a string

* Update huggingface.py

---------
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 3f0a3611
......@@ -170,7 +170,7 @@ class HFLM(LM):
f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
)
# TODO: include in warning that `load_in_8bit` etc. affect this too
self._device = device
self._device = torch.device(device)
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
......@@ -207,7 +207,7 @@ class HFLM(LM):
self.model.eval()
self.model.tie_weights()
if (gpus >= 1 or self.device.type == "mps") and isinstance(pretrained, str):
if (gpus >= 1 or str(self.device) == "mps") and isinstance(pretrained, str):
if not (parallelize or autogptq or ("device_map" in kwargs)):
# place model onto device requested manually,
# if not using HF Accelerate or device_map
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment