Commit ce214979 authored by baberabb's avatar baberabb
Browse files

set fp32 if device=mps

parent 2bb7ce3b
......@@ -114,6 +114,8 @@ class LM(abc.ABC):
additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
if "device" in args and args["device"] == "mps":
args["dtype"] = "float32"
return cls(**args, **args2)
@property
......
......@@ -109,7 +109,7 @@ class HFLM(LM):
eval_logger.info(f"Using device '{device}'")
if device == "mps":
eval_logger.info(
"MPS is still in beta; add ,dtype=float32 to model_args ."
"MPS is still in beta and only supports float32; setting dtype to float32."
)
else:
eval_logger.info("Device not specified")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment