Unverified Commit 7aa13c47 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #680 from baberabb/big-refactor_fixup

[Refactor] minor edits
parents 283cad70 43407f36
......@@ -114,6 +114,8 @@ class LM(abc.ABC):
additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
if args2.get("device") == "mps" or args.get("device") == "mps":
args["dtype"] = "float32"
return cls(**args, **args2)
@property
......
......@@ -99,7 +99,7 @@ class HFLM(LM):
if not (parallelize or accelerator.num_processes > 1):
# use user-passed device
device_list = set(
["cuda", "cpu"]
["cuda", "cpu", "mps"]
+ [f"cuda:{i}" for i in range(torch.cuda.device_count())]
)
if device:
......@@ -107,6 +107,10 @@ class HFLM(LM):
device = int(device)
self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'")
if device == "mps":
eval_logger.info(
"MPS is still in beta and only supports float32; setting dtype to float32."
)
else:
eval_logger.info("Device not specified")
eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
......
......@@ -58,7 +58,6 @@ def main():
ctx = task.fewshot_context(
doc=doc,
num_fewshot=args.num_fewshot,
rnd=rnd,
)
f.write(ctx + "\n")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment