Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ca485e56
Unverified
Commit
ca485e56
authored
Sep 16, 2022
by
Sylvain Gugger
Committed by
GitHub
Sep 16, 2022
Browse files
Add tests for legacy load by url and fix bugs (#19078)
parent
ae219532
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
62 additions
and
5 deletions
+62
-5
src/transformers/modeling_flax_utils.py
src/transformers/modeling_flax_utils.py
+1
-1
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+1
-1
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+1
-1
src/transformers/tokenization_utils_base.py
src/transformers/tokenization_utils_base.py
+1
-1
tests/test_configuration_common.py
tests/test_configuration_common.py
+6
-0
tests/test_feature_extraction_common.py
tests/test_feature_extraction_common.py
+6
-0
tests/test_modeling_common.py
tests/test_modeling_common.py
+21
-0
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+19
-0
tests/test_tokenization_common.py
tests/test_tokenization_common.py
+6
-1
No files found.
src/transformers/modeling_flax_utils.py
View file @
ca485e56
...
@@ -680,7 +680,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
...
@@ -680,7 +680,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
archive_file
=
pretrained_model_name_or_path
archive_file
=
pretrained_model_name_or_path
is_local
=
True
is_local
=
True
elif
is_remote_url
(
pretrained_model_name_or_path
):
elif
is_remote_url
(
pretrained_model_name_or_path
):
archive_fil
e
=
pretrained_model_name_or_path
filenam
e
=
pretrained_model_name_or_path
resolved_archive_file
=
download_url
(
pretrained_model_name_or_path
)
resolved_archive_file
=
download_url
(
pretrained_model_name_or_path
)
else
:
else
:
filename
=
WEIGHTS_NAME
if
from_pt
else
FLAX_WEIGHTS_NAME
filename
=
WEIGHTS_NAME
if
from_pt
else
FLAX_WEIGHTS_NAME
...
...
src/transformers/modeling_tf_utils.py
View file @
ca485e56
...
@@ -2418,7 +2418,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
...
@@ -2418,7 +2418,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
archive_file
=
pretrained_model_name_or_path
+
".index"
archive_file
=
pretrained_model_name_or_path
+
".index"
is_local
=
True
is_local
=
True
elif
is_remote_url
(
pretrained_model_name_or_path
):
elif
is_remote_url
(
pretrained_model_name_or_path
):
archive_fil
e
=
pretrained_model_name_or_path
filenam
e
=
pretrained_model_name_or_path
resolved_archive_file
=
download_url
(
pretrained_model_name_or_path
)
resolved_archive_file
=
download_url
(
pretrained_model_name_or_path
)
else
:
else
:
# set correct filename
# set correct filename
...
...
src/transformers/modeling_utils.py
View file @
ca485e56
...
@@ -2005,7 +2005,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -2005,7 +2005,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
archive_file
=
os
.
path
.
join
(
subfolder
,
pretrained_model_name_or_path
+
".index"
)
archive_file
=
os
.
path
.
join
(
subfolder
,
pretrained_model_name_or_path
+
".index"
)
is_local
=
True
is_local
=
True
elif
is_remote_url
(
pretrained_model_name_or_path
):
elif
is_remote_url
(
pretrained_model_name_or_path
):
archive_fil
e
=
pretrained_model_name_or_path
filenam
e
=
pretrained_model_name_or_path
resolved_archive_file
=
download_url
(
pretrained_model_name_or_path
)
resolved_archive_file
=
download_url
(
pretrained_model_name_or_path
)
else
:
else
:
# set correct filename
# set correct filename
...
...
src/transformers/tokenization_utils_base.py
View file @
ca485e56
...
@@ -1670,7 +1670,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
...
@@ -1670,7 +1670,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
init_configuration
=
{}
init_configuration
=
{}
is_local
=
os
.
path
.
isdir
(
pretrained_model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
pretrained_model_name_or_path
)
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
)
or
is_remote_url
(
pretrained_model_name_or_path
)
:
if
len
(
cls
.
vocab_files_names
)
>
1
:
if
len
(
cls
.
vocab_files_names
)
>
1
:
raise
ValueError
(
raise
ValueError
(
f
"Calling
{
cls
.
__name__
}
.from_pretrained() with the path to a single file or url is not "
f
"Calling
{
cls
.
__name__
}
.from_pretrained() with the path to a single file or url is not "
...
...
tests/test_configuration_common.py
View file @
ca485e56
...
@@ -360,6 +360,12 @@ class ConfigTestUtils(unittest.TestCase):
...
@@ -360,6 +360,12 @@ class ConfigTestUtils(unittest.TestCase):
# This check we did call the fake head request
# This check we did call the fake head request
mock_head
.
assert_called
()
mock_head
.
assert_called
()
def
test_legacy_load_from_url
(
self
):
# This test is for deprecated behavior and can be removed in v5
_
=
BertConfig
.
from_pretrained
(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/config.json"
)
class
ConfigurationVersioningTest
(
unittest
.
TestCase
):
class
ConfigurationVersioningTest
(
unittest
.
TestCase
):
def
test_local_versioning
(
self
):
def
test_local_versioning
(
self
):
...
...
tests/test_feature_extraction_common.py
View file @
ca485e56
...
@@ -182,6 +182,12 @@ class FeatureExtractorUtilTester(unittest.TestCase):
...
@@ -182,6 +182,12 @@ class FeatureExtractorUtilTester(unittest.TestCase):
# This check we did call the fake head request
# This check we did call the fake head request
mock_head
.
assert_called
()
mock_head
.
assert_called
()
def
test_legacy_load_from_url
(
self
):
# This test is for deprecated behavior and can be removed in v5
_
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
"https://huggingface.co/hf-internal-testing/tiny-random-wav2vec2/resolve/main/preprocessor_config.json"
)
@
is_staging_test
@
is_staging_test
class
FeatureExtractorPushToHubTester
(
unittest
.
TestCase
):
class
FeatureExtractorPushToHubTester
(
unittest
.
TestCase
):
...
...
tests/test_modeling_common.py
View file @
ca485e56
...
@@ -33,6 +33,7 @@ import numpy as np
...
@@ -33,6 +33,7 @@ import numpy as np
import
transformers
import
transformers
from
huggingface_hub
import
HfFolder
,
delete_repo
,
set_access_token
from
huggingface_hub
import
HfFolder
,
delete_repo
,
set_access_token
from
huggingface_hub.file_download
import
http_get
from
requests.exceptions
import
HTTPError
from
requests.exceptions
import
HTTPError
from
transformers
import
(
from
transformers
import
(
AutoConfig
,
AutoConfig
,
...
@@ -2949,6 +2950,26 @@ class ModelUtilsTest(TestCasePlus):
...
@@ -2949,6 +2950,26 @@ class ModelUtilsTest(TestCasePlus):
# This check we did call the fake head request
# This check we did call the fake head request
mock_head
.
assert_called
()
mock_head
.
assert_called
()
def
test_load_from_one_file
(
self
):
try
:
tmp_file
=
tempfile
.
mktemp
()
with
open
(
tmp_file
,
"wb"
)
as
f
:
http_get
(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin"
,
f
)
config
=
BertConfig
.
from_pretrained
(
"hf-internal-testing/tiny-random-bert"
)
_
=
BertModel
.
from_pretrained
(
tmp_file
,
config
=
config
)
finally
:
os
.
remove
(
tmp_file
)
def
test_legacy_load_from_url
(
self
):
# This test is for deprecated behavior and can be removed in v5
config
=
BertConfig
.
from_pretrained
(
"hf-internal-testing/tiny-random-bert"
)
_
=
BertModel
.
from_pretrained
(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin"
,
config
=
config
)
@
require_torch
@
require_torch
@
is_staging_test
@
is_staging_test
...
...
tests/test_modeling_tf_common.py
View file @
ca485e56
...
@@ -30,6 +30,7 @@ from typing import List, Tuple, get_type_hints
...
@@ -30,6 +30,7 @@ from typing import List, Tuple, get_type_hints
from
datasets
import
Dataset
from
datasets
import
Dataset
from
huggingface_hub
import
HfFolder
,
Repository
,
delete_repo
,
set_access_token
from
huggingface_hub
import
HfFolder
,
Repository
,
delete_repo
,
set_access_token
from
huggingface_hub.file_download
import
http_get
from
requests.exceptions
import
HTTPError
from
requests.exceptions
import
HTTPError
from
transformers
import
is_tf_available
,
is_torch_available
from
transformers
import
is_tf_available
,
is_torch_available
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.configuration_utils
import
PretrainedConfig
...
@@ -1927,6 +1928,24 @@ class UtilsFunctionsTest(unittest.TestCase):
...
@@ -1927,6 +1928,24 @@ class UtilsFunctionsTest(unittest.TestCase):
# This check we did call the fake head request
# This check we did call the fake head request
mock_head
.
assert_called
()
mock_head
.
assert_called
()
def
test_load_from_one_file
(
self
):
try
:
tmp_file
=
tempfile
.
mktemp
()
with
open
(
tmp_file
,
"wb"
)
as
f
:
http_get
(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5"
,
f
)
config
=
BertConfig
.
from_pretrained
(
"hf-internal-testing/tiny-random-bert"
)
_
=
TFBertModel
.
from_pretrained
(
tmp_file
,
config
=
config
)
finally
:
os
.
remove
(
tmp_file
)
def
test_legacy_load_from_url
(
self
):
# This test is for deprecated behavior and can be removed in v5
config
=
BertConfig
.
from_pretrained
(
"hf-internal-testing/tiny-random-bert"
)
_
=
TFBertModel
.
from_pretrained
(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5"
,
config
=
config
)
# tests whether the unpack_inputs function behaves as expected
# tests whether the unpack_inputs function behaves as expected
def
test_unpack_inputs
(
self
):
def
test_unpack_inputs
(
self
):
class
DummyModel
:
class
DummyModel
:
...
...
tests/test_tokenization_common.py
View file @
ca485e56
...
@@ -3891,15 +3891,20 @@ class TokenizerUtilTester(unittest.TestCase):
...
@@ -3891,15 +3891,20 @@ class TokenizerUtilTester(unittest.TestCase):
mock_head
.
assert_called
()
mock_head
.
assert_called
()
def
test_legacy_load_from_one_file
(
self
):
def
test_legacy_load_from_one_file
(
self
):
# This test is for deprecated behavior and can be removed in v5
try
:
try
:
tmp_file
=
tempfile
.
mktemp
()
tmp_file
=
tempfile
.
mktemp
()
with
open
(
tmp_file
,
"wb"
)
as
f
:
with
open
(
tmp_file
,
"wb"
)
as
f
:
http_get
(
"https://huggingface.co/albert-base-v1/resolve/main/spiece.model"
,
f
)
http_get
(
"https://huggingface.co/albert-base-v1/resolve/main/spiece.model"
,
f
)
AlbertTokenizer
.
from_pretrained
(
tmp_file
)
_
=
AlbertTokenizer
.
from_pretrained
(
tmp_file
)
finally
:
finally
:
os
.
remove
(
tmp_file
)
os
.
remove
(
tmp_file
)
def
test_legacy_load_from_url
(
self
):
# This test is for deprecated behavior and can be removed in v5
_
=
AlbertTokenizer
.
from_pretrained
(
"https://huggingface.co/albert-base-v1/resolve/main/spiece.model"
)
@
is_staging_test
@
is_staging_test
class
TokenizerPushToHubTester
(
unittest
.
TestCase
):
class
TokenizerPushToHubTester
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment