Unverified Commit 45c06510 authored by Alex Ishida's avatar Alex Ishida Committed by GitHub
Browse files

Add support for metadata format MLX (#29335)

Add support for loading safetensors files saved with metadata format mlx.
parent 923733c2
...@@ -504,7 +504,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): ...@@ -504,7 +504,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
# Check format of the archive # Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f: with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata() metadata = f.metadata()
if metadata.get("format") not in ["pt", "tf", "flax"]: if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError( raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method." "you save your model with the `save_pretrained` method."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment