Unverified Commit a69cbf4e authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Automatic safetensors conversion when lacking these files (#29390)

* Automatic safetensors conversion when lacking these files

* Remove debug

* Thread name

* Typo

* Ensure that raises do not affect the main thread
parent 9c5e5609
...@@ -29,6 +29,7 @@ import warnings ...@@ -29,6 +29,7 @@ import warnings
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial, wraps from functools import partial, wraps
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from zipfile import is_zipfile from zipfile import is_zipfile
...@@ -3207,9 +3208,39 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -3207,9 +3208,39 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
) )
if resolved_archive_file is not None: if resolved_archive_file is not None:
is_sharded = True is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error if resolved_archive_file is not None:
# message. if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]:
# If the PyTorch file was found, check if there is a safetensors file on the repository
# If there is no safetensors file on the repositories, start an auto conversion
safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
has_file_kwargs = {
"revision": revision,
"proxies": proxies,
"token": token,
}
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"resume_download": resume_download,
"local_files_only": local_files_only,
"user_agent": user_agent,
"subfolder": subfolder,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
**has_file_kwargs,
}
if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
Thread(
target=auto_conversion,
args=(pretrained_model_name_or_path,),
kwargs=cached_file_kwargs,
name="Thread-autoconversion",
).start()
else:
# Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
# We try those to give a helpful error message.
has_file_kwargs = { has_file_kwargs = {
"revision": revision, "revision": revision,
"proxies": proxies, "proxies": proxies,
......
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import os.path import os.path
import sys import sys
import tempfile import tempfile
import threading
import unittest import unittest
import unittest.mock as mock import unittest.mock as mock
import uuid import uuid
...@@ -1428,7 +1429,7 @@ class ModelOnTheFlyConversionTester(unittest.TestCase): ...@@ -1428,7 +1429,7 @@ class ModelOnTheFlyConversionTester(unittest.TestCase):
bot_opened_pr_title = None bot_opened_pr_title = None
for discussion in discussions: for discussion in discussions:
if discussion.author == "SFconvertBot": if discussion.author == "SFconvertbot":
bot_opened_pr = True bot_opened_pr = True
bot_opened_pr_title = discussion.title bot_opened_pr_title = discussion.title
...@@ -1451,6 +1452,51 @@ class ModelOnTheFlyConversionTester(unittest.TestCase): ...@@ -1451,6 +1452,51 @@ class ModelOnTheFlyConversionTester(unittest.TestCase):
with self.assertRaises(EnvironmentError): with self.assertRaises(EnvironmentError):
BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch") BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch")
def test_absence_of_safetensors_triggers_conversion(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
initial_model = BertModel(config)
# Push a model on `main`
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
# Download the model that doesn't have safetensors
BertModel.from_pretrained(self.repo_name, token=self.token)
for thread in threading.enumerate():
if thread.name == "Thread-autoconversion":
thread.join(timeout=10)
with self.subTest("PR was open with the safetensors account"):
discussions = self.api.get_repo_discussions(self.repo_name)
bot_opened_pr = None
bot_opened_pr_title = None
for discussion in discussions:
if discussion.author == "SFconvertbot":
bot_opened_pr = True
bot_opened_pr_title = discussion.title
self.assertTrue(bot_opened_pr)
self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model")
@mock.patch("transformers.safetensors_conversion.spawn_conversion")
def test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock):
spawn_conversion_mock.side_effect = HTTPError()
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
initial_model = BertModel(config)
# Push a model on `main`
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
# The auto conversion is mocked to always raise; ensure that it doesn't raise in the main thread
BertModel.from_pretrained(self.repo_name, token=self.token)
@require_torch @require_torch
@is_staging_test @is_staging_test
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment