Unverified Commit 57c3b1a2 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

place device onto `mps` (#1133)

parent 04707a2d
......@@ -207,7 +207,7 @@ class HFLM(LM):
self.model.eval()
self.model.tie_weights()
if gpus >= 1 and isinstance(pretrained, str):
if (gpus >= 1 or self.device.type == "mps") and isinstance(pretrained, str):
if not (parallelize or autogptq or ("device_map" in kwargs)):
# place model onto device requested manually,
# if not using HF Accelerate or device_map
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment