Unverified Commit 7aa13c47 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #680 from baberabb/big-refactor_fixup

[Refactor] minor edits
parents 283cad70 43407f36
...@@ -114,6 +114,8 @@ class LM(abc.ABC): ...@@ -114,6 +114,8 @@ class LM(abc.ABC):
additional_config = {} if additional_config is None else additional_config additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string) args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None} args2 = {k: v for k, v in additional_config.items() if v is not None}
if args2.get("device") == "mps" or args.get("device") == "mps":
args["dtype"] = "float32"
return cls(**args, **args2) return cls(**args, **args2)
@property @property
......
...@@ -99,7 +99,7 @@ class HFLM(LM): ...@@ -99,7 +99,7 @@ class HFLM(LM):
if not (parallelize or accelerator.num_processes > 1): if not (parallelize or accelerator.num_processes > 1):
# use user-passed device # use user-passed device
device_list = set( device_list = set(
["cuda", "cpu"] ["cuda", "cpu", "mps"]
+ [f"cuda:{i}" for i in range(torch.cuda.device_count())] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
) )
if device: if device:
...@@ -107,6 +107,10 @@ class HFLM(LM): ...@@ -107,6 +107,10 @@ class HFLM(LM):
device = int(device) device = int(device)
self._device = torch.device(device) self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'") eval_logger.info(f"Using device '{device}'")
if device == "mps":
eval_logger.info(
"MPS is still in beta and only supports float32; setting dtype to float32."
)
else: else:
eval_logger.info("Device not specified") eval_logger.info("Device not specified")
eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}") eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
......
...@@ -58,7 +58,6 @@ def main(): ...@@ -58,7 +58,6 @@ def main():
ctx = task.fewshot_context( ctx = task.fewshot_context(
doc=doc, doc=doc,
num_fewshot=args.num_fewshot, num_fewshot=args.num_fewshot,
rnd=rnd,
) )
f.write(ctx + "\n") f.write(ctx + "\n")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment