Unverified Commit dd9fb896 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #447 from nikhilpinnaparaju/issue-437

Updated handling for device in lm_eval/models/gpt2.py
parents 602abceb 7376b0fd
......@@ -23,9 +23,8 @@ class HFLM(BaseLM):
assert isinstance(pretrained, str)
assert isinstance(batch_size, int)
if device:
if device not in ["cuda", "cpu"]:
device = int(device)
device_list = set(["cuda", "cpu"] + [f'cuda:{i}' for i in range(torch.cuda.device_count())])
if device and device in device_list:
self._device = torch.device(device)
print(f"Using device '{device}'")
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment