Unverified Commit b382a09e authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Experimental loading of MLX files (#29511)

* Experimental loading of MLX files

* Update exception message

* Add test

* Style

* Use model from hf-internal-testing
parent 73a27345
......@@ -3297,9 +3297,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif metadata.get("format") == "flax":
from_flax = True
logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
elif metadata.get("format") == "mlx":
# This is a mlx file, we assume weights are compatible with pt
pass
else:
raise ValueError(
f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}"
f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}"
)
from_pt = not (from_tf | from_flax)
......
......@@ -1256,6 +1256,26 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(len(logs.output), 1)
self.assertIn("Your generation config was originally created from the model config", logs.output[0])
@require_safetensors
def test_model_from_pretrained_from_mlx(self):
from safetensors import safe_open
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-mistral-mlx")
self.assertIsNotNone(model)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
with safe_open(os.path.join(tmp_dir, "model.safetensors"), framework="pt") as f:
metadata = f.metadata()
self.assertEqual(metadata.get("format"), "pt")
new_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
input_ids = torch.randint(100, 1000, (1, 10))
with torch.no_grad():
outputs = model(input_ids)
outputs_from_saved = new_model(input_ids)
self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))
@slow
@require_torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment