Commit 2bb7ce3b authored by baberabb's avatar baberabb
Browse files

add mps

parent 73b2149f
......@@ -99,7 +99,7 @@ class HFLM(LM):
if not (parallelize or accelerator.num_processes > 1):
# use user-passed device
device_list = set(
["cuda", "cpu"]
["cuda", "cpu", "mps"]
+ [f"cuda:{i}" for i in range(torch.cuda.device_count())]
)
if device:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment