chenpangpang/transformers · commit ed71c21d
Unverified commit ed71c21d, authored Sep 09, 2020 by Julien Chaumond; committed by GitHub on Sep 09, 2020.
[from_pretrained] Allow tokenizer_type ≠ model_type (#6995)
parent 03e363f9
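In short: a checkpoint's config.json can now name its tokenizer explicitly via a tokenizer_class entry, and AutoTokenizer will honor it even when it differs from what model_type alone would imply. A minimal usage sketch follows; the julien-c/dummy-diff-tokenizer checkpoint comes from this commit's tests, while the config.json contents shown in the comment are an assumption inferred from them:

    from transformers import AutoTokenizer

    # The checkpoint's config.json is assumed to contain something like:
    #   { "model_type": "roberta", "tokenizer_class": "BertTokenizer", ... }
    # model_type alone would pick a RoBERTa tokenizer; tokenizer_class overrides that.
    tokenizer = AutoTokenizer.from_pretrained("julien-c/dummy-diff-tokenizer")
    print(type(tokenizer).__name__)  # BertTokenizer, despite model_type == "roberta"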
Showing 4 changed files with 28 additions and 1 deletion (+28 / -1):

    src/transformers/configuration_utils.py    +1  -0
    src/transformers/testing_utils.py           +1  -0
    src/transformers/tokenization_auto.py       +11 -0
    tests/test_tokenization_auto.py             +15 -1
src/transformers/configuration_utils.py (view file @ ed71c21d)

@@ -190,6 +190,7 @@ class PretrainedConfig(object):
         self.num_labels = kwargs.pop("num_labels", 2)

         # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
+        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
         self.prefix = kwargs.pop("prefix", None)
         self.bos_token_id = kwargs.pop("bos_token_id", None)
         self.pad_token_id = kwargs.pop("pad_token_id", None)
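A minimal sketch of the new kwarg (illustrative values, not from the commit): any PretrainedConfig subclass now accepts tokenizer_class and carries it through serialization, since it is stored as a plain attribute.

    from transformers import RobertaConfig

    # tokenizer_class is popped from kwargs and stored on the config, so it
    # round-trips through to_dict() / save_pretrained() like any other field.
    config = RobertaConfig(vocab_size=12, tokenizer_class="BertTokenizer")
    assert config.tokenizer_class == "BertTokenizer"
    assert config.to_dict()["tokenizer_class"] == "BertTokenizer"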
src/transformers/testing_utils.py (view file @ ed71c21d)

@@ -15,6 +15,7 @@ from .file_utils import _tf_available, _torch_available, _torch_tpu_available
 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
 DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
+DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
 # Used to test Auto{Config, Model, Tokenizer} model_type detection.
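For reference, a sketch of what the new dummy checkpoint presumably declares in its config.json (inferred from the test added below, not copied from the Hub):

    # Hypothetical config.json content for julien-c/dummy-diff-tokenizer:
    # a RoBERTa model_type deliberately paired with a BERT tokenizer_class.
    # The test below also expects the tokenizer's vocabulary to have 12 entries.
    dummy_diff_tokenizer_config = {
        "model_type": "roberta",
        "tokenizer_class": "BertTokenizer",
    }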
src/transformers/tokenization_auto.py (view file @ ed71c21d)

@@ -222,6 +222,17 @@ class AutoTokenizer:
             return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

         use_fast = kwargs.pop("use_fast", False)
+        if config.tokenizer_class is not None:
+            if use_fast and not config.tokenizer_class.endswith("Fast"):
+                tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
+            else:
+                tokenizer_class_candidate = config.tokenizer_class
+            tokenizer_class = globals().get(tokenizer_class_candidate)
+            if tokenizer_class is None:
+                raise ValueError(
+                    "Tokenizer class {} does not exist or is not currently imported.".format(tokenizer_class_candidate)
+                )
+            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
         for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items():
             if isinstance(config, config_class):
                 if tokenizer_class_fast and use_fast:
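The candidate-name rule above is small but easy to misread: requesting a fast tokenizer appends "Fast" to the configured class name before the globals() lookup, so the named class (or its Fast variant) must already be imported in tokenization_auto's module namespace. A standalone sketch of that rule (the function name is illustrative, not part of the library):

    def resolve_tokenizer_class_candidate(tokenizer_class: str, use_fast: bool) -> str:
        # Mirrors the logic added above: append "Fast" when a fast tokenizer is
        # requested and the configured class is not already the fast variant.
        if use_fast and not tokenizer_class.endswith("Fast"):
            return f"{tokenizer_class}Fast"
        return tokenizer_class

    assert resolve_tokenizer_class_candidate("BertTokenizer", use_fast=True) == "BertTokenizerFast"
    assert resolve_tokenizer_class_candidate("BertTokenizerFast", use_fast=True) == "BertTokenizerFast"
    assert resolve_tokenizer_class_candidate("BertTokenizer", use_fast=False) == "BertTokenizer"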
tests/test_tokenization_auto.py (view file @ ed71c21d)

@@ -27,7 +27,13 @@ from transformers import (
     RobertaTokenizer,
     RobertaTokenizerFast,
 )
-from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER  # noqa: F401
+from transformers.configuration_auto import AutoConfig
+from transformers.configuration_roberta import RobertaConfig
+from transformers.testing_utils import (
+    DUMMY_DIFF_TOKENIZER_IDENTIFIER,
+    DUMMY_UNKWOWN_IDENTIFIER,
+    SMALL_MODEL_IDENTIFIER,
+)
 from transformers.tokenization_auto import TOKENIZER_MAPPING

@@ -56,6 +62,14 @@ class AutoTokenizerTest(unittest.TestCase):
         self.assertIsInstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
         self.assertEqual(tokenizer.vocab_size, 20)

+    def test_tokenizer_from_tokenizer_class(self):
+        config = AutoConfig.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER)
+        self.assertIsInstance(config, RobertaConfig)
+        # Check that tokenizer_type ≠ model_type
+        tokenizer = AutoTokenizer.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER, config=config)
+        self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
+        self.assertEqual(tokenizer.vocab_size, 12)
+
     def test_tokenizer_identifier_with_correct_config(self):
         for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
             tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
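A complementary sketch of the fast path the new test does not exercise, assuming the dummy checkpoint ships vocabulary files the fast tokenizer can load: with use_fast=True the same checkpoint should resolve to BertTokenizerFast.

    from transformers import AutoTokenizer, BertTokenizer, BertTokenizerFast

    # Default (use_fast=False at this point in the library's history): BertTokenizer.
    slow = AutoTokenizer.from_pretrained("julien-c/dummy-diff-tokenizer")
    assert isinstance(slow, BertTokenizer)

    # use_fast=True turns the candidate "BertTokenizer" into "BertTokenizerFast".
    fast = AutoTokenizer.from_pretrained("julien-c/dummy-diff-tokenizer", use_fast=True)
    assert isinstance(fast, BertTokenizerFast)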