Commit 43407f36 authored by baberabb's avatar baberabb
Browse files

set fp32 if device=mps

parent ce214979
......@@ -114,7 +114,7 @@ class LM(abc.ABC):
additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
if "device" in args and args["device"] == "mps":
if args2.get("device") == "mps" or args.get("device") == "mps":
args["dtype"] = "float32"
return cls(**args, **args2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment