Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3e490020
Unverified
Commit
3e490020
authored
Oct 14, 2022
by
Sylvain Gugger
Committed by
GitHub
Oct 14, 2022
Browse files
Tokenizer from_pretrained should not use local files named like tokenizer files (#19626)
parent
8fcf5626
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
4 deletions
+23
-4
src/transformers/tokenization_utils_base.py
src/transformers/tokenization_utils_base.py
+7
-4
tests/test_tokenization_common.py
tests/test_tokenization_common.py
+16
-0
No files found.
src/transformers/tokenization_utils_base.py
View file @
3e490020
...
@@ -1670,6 +1670,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
...
@@ -1670,6 +1670,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
init_configuration
=
{}
init_configuration
=
{}
is_local
=
os
.
path
.
isdir
(
pretrained_model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
pretrained_model_name_or_path
)
single_file_id
=
None
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
)
or
is_remote_url
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
)
or
is_remote_url
(
pretrained_model_name_or_path
):
if
len
(
cls
.
vocab_files_names
)
>
1
:
if
len
(
cls
.
vocab_files_names
)
>
1
:
raise
ValueError
(
raise
ValueError
(
...
@@ -1684,6 +1685,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
...
@@ -1684,6 +1685,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
file_id
=
list
(
cls
.
vocab_files_names
.
keys
())[
0
]
file_id
=
list
(
cls
.
vocab_files_names
.
keys
())[
0
]
vocab_files
[
file_id
]
=
pretrained_model_name_or_path
vocab_files
[
file_id
]
=
pretrained_model_name_or_path
single_file_id
=
file_id
else
:
else
:
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
additional_files_names
=
{
additional_files_names
=
{
...
@@ -1726,7 +1728,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
...
@@ -1726,7 +1728,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
for
file_id
,
file_path
in
vocab_files
.
items
():
for
file_id
,
file_path
in
vocab_files
.
items
():
if
file_path
is
None
:
if
file_path
is
None
:
resolved_vocab_files
[
file_id
]
=
None
resolved_vocab_files
[
file_id
]
=
None
elif
os
.
path
.
isfile
(
file_path
):
elif
single_file_id
==
file_id
:
if
os
.
path
.
isfile
(
file_path
):
resolved_vocab_files
[
file_id
]
=
file_path
resolved_vocab_files
[
file_id
]
=
file_path
elif
is_remote_url
(
file_path
):
elif
is_remote_url
(
file_path
):
resolved_vocab_files
[
file_id
]
=
download_url
(
file_path
,
proxies
=
proxies
)
resolved_vocab_files
[
file_id
]
=
download_url
(
file_path
,
proxies
=
proxies
)
...
...
tests/test_tokenization_common.py
View file @
3e490020
...
@@ -3920,6 +3920,22 @@ class TokenizerUtilTester(unittest.TestCase):
...
@@ -3920,6 +3920,22 @@ class TokenizerUtilTester(unittest.TestCase):
finally
:
finally
:
os
.
remove
(
tmp_file
)
os
.
remove
(
tmp_file
)
# Supporting this legacy load introduced a weird bug where the tokenizer would load local files if they are in
# the current folder and have the right name.
if
os
.
path
.
isfile
(
"tokenizer.json"
):
# We skip the test if the user has a `tokenizer.json` in this folder to avoid deleting it.
return
try
:
with
open
(
"tokenizer.json"
,
"wb"
)
as
f
:
http_get
(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/blob/main/tokenizer.json"
,
f
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
# The tiny random BERT has a vocab size of 1024, tiny gpt2 has a vocab size of 1000
self
.
assertEqual
(
tokenizer
.
vocab_size
,
1000
)
# Tokenizer should depend on the remote checkpoint, not the local tokenizer.json file.
finally
:
os
.
remove
(
"tokenizer.json"
)
def
test_legacy_load_from_url
(
self
):
def
test_legacy_load_from_url
(
self
):
# This test is for deprecated behavior and can be removed in v5
# This test is for deprecated behavior and can be removed in v5
_
=
AlbertTokenizer
.
from_pretrained
(
"https://huggingface.co/albert-base-v1/resolve/main/spiece.model"
)
_
=
AlbertTokenizer
.
from_pretrained
(
"https://huggingface.co/albert-base-v1/resolve/main/spiece.model"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment