"vscode:/vscode.git/clone" did not exist on "c81c01b2ec34c189911c32631b980d8d065dda26"
Commit 7376b0fd authored by Nikhil Pinnaparaju's avatar Nikhil Pinnaparaju
Browse files

Updated handling for device in lm_eval/models/gpt2.py

parent 602abceb
...@@ -23,9 +23,8 @@ class HFLM(BaseLM): ...@@ -23,9 +23,8 @@ class HFLM(BaseLM):
assert isinstance(pretrained, str) assert isinstance(pretrained, str)
assert isinstance(batch_size, int) assert isinstance(batch_size, int)
if device: device_list = set(["cuda", "cpu"] + [f'cuda:{i}' for i in range(torch.cuda.device_count())])
if device not in ["cuda", "cpu"]: if device and device in device_list:
device = int(device)
self._device = torch.device(device) self._device = torch.device(device)
print(f"Using device '{device}'") print(f"Using device '{device}'")
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment