Unverified Commit 6d023270 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Allow flax subfolder (#19902)

* add first generation tutorial

* [Flax] Add subfolder functionality

* [Flax] Add subfolder functionality

* up

* finish

* delete file and re-add test
parent 7a1c68a8
......@@ -636,6 +636,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
_commit_hash=commit_hash,
......@@ -659,22 +660,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
elif from_pt and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
):
# Load from a sharded pytorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
is_sharded = True
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)):
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
# Load from a sharded Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
......@@ -685,7 +688,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}."
)
elif os.path.isfile(pretrained_model_name_or_path):
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
archive_file = pretrained_model_name_or_path
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
......@@ -786,6 +789,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
......
......@@ -1221,3 +1221,68 @@ class FlaxModelPushToHubTester(unittest.TestCase):
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def check_models_equal(model1, model2):
models_are_equal = True
flat_params_1 = flatten_dict(model1.params)
flat_params_2 = flatten_dict(model2.params)
for key in flat_params_1.keys():
if np.sum(np.abs(flat_params_1[key] - flat_params_2[key])) > 1e-4:
models_are_equal = False
return models_are_equal
@require_flax
class FlaxModelUtilsTest(unittest.TestCase):
def test_model_from_pretrained_subfolder(self):
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
model = FlaxBertModel(config)
subfolder = "bert"
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(os.path.join(tmp_dir, subfolder))
with self.assertRaises(OSError):
_ = FlaxBertModel.from_pretrained(tmp_dir)
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
self.assertTrue(check_models_equal(model, model_loaded))
def test_model_from_pretrained_subfolder_sharded(self):
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
model = FlaxBertModel(config)
subfolder = "bert"
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")
with self.assertRaises(OSError):
_ = FlaxBertModel.from_pretrained(tmp_dir)
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
self.assertTrue(check_models_equal(model, model_loaded))
def test_model_from_pretrained_hub_subfolder(self):
subfolder = "bert"
model_id = "hf-internal-testing/tiny-random-bert-subfolder"
with self.assertRaises(OSError):
_ = FlaxBertModel.from_pretrained(model_id)
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
self.assertIsNotNone(model)
def test_model_from_pretrained_hub_subfolder_sharded(self):
subfolder = "bert"
model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
with self.assertRaises(OSError):
_ = FlaxBertModel.from_pretrained(model_id)
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
self.assertIsNotNone(model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment