Commit ae7bae8f (Unverified)
Authored Jun 08, 2022 by SaulLu, committed by GitHub on Jun 08, 2022
Parent commit: 264128cb

fix `train_new_from_iterator` in the case of byte-level tokenizers (#17549)

Showing 13 changed files with 56 additions and 0 deletions (+56 -0)
src/transformers/tokenization_utils_fast.py           +3  -0
tests/models/bart/test_modeling_bart.py               +1  -0
tests/models/blenderbot/test_modeling_blenderbot.py   +1  -0
tests/models/deberta/test_modeling_deberta.py          +5  -0
tests/models/gpt2/test_modeling_gpt2.py                +5  -0
tests/models/gpt_neo/test_modeling_gpt_neo.py          +5  -0
tests/models/gptj/test_modeling_gptj.py                +5  -0
tests/models/ibert/test_modeling_ibert.py              +5  -0
tests/models/led/test_modeling_led.py                  +1  -0
tests/models/longformer/test_modeling_longformer.py   +5  -0
tests/models/roberta/test_modeling_roberta.py          +5  -0
tests/models/yoso/test_modeling_yoso.py                +5  -0
tests/tokenization/test_tokenization_fast.py           +10 -0
src/transformers/tokenization_utils_fast.py

@@ -21,6 +21,7 @@ import os
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+import tokenizers.pre_tokenizers as pre_tokenizers_fast
 from tokenizers import Encoding as EncodingFast
 from tokenizers import Tokenizer as TokenizerFast
 from tokenizers.decoders import Decoder as DecoderFast

@@ -699,6 +700,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
                 kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]
         if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None:
             kwargs["unk_token"] = unk_token
+        if tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel":
+            kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet()
 
         trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
         trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs)
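The core of the fix is above: when the tokenizer being retrained uses a ByteLevel pre-tokenizer, the trainer is now seeded with the full 256-symbol byte-level alphabet, so the retrained vocabulary can represent bytes that never occur in the training iterator. A minimal sketch of how this surfaces through the public API (assuming the `gpt2` checkpoint, which uses byte-level BPE, is available):

```python
from transformers import AutoTokenizer

# gpt2 ships a byte-level BPE fast tokenizer; any byte-level checkpoint would do.
old_tokenizer = AutoTokenizer.from_pretrained("gpt2")

# A toy corpus that deliberately covers very few byte values.
training_corpus = (["hello world"] for _ in range(100))

# With this commit, train_new_from_iterator passes
# initial_alphabet=ByteLevel.alphabet() to the underlying trainer, so the
# retrained tokenizer can still encode characters absent from the corpus.
new_tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, vocab_size=1000)
print(new_tokenizer.tokenize("hello 🤗"))
```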
tests/models/bart/test_modeling_bart.py

@@ -150,6 +150,7 @@ class BartModelTester:
     def get_pipeline_config(self):
         config = self.get_config()
         config.max_position_embeddings = 100
+        config.vocab_size = 300
         return config
 
     def prepare_config_and_inputs_for_common(self):
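The `config.vocab_size = 300` lines added to this and the following model testers are presumably there so the tiny pipeline-test configs can hold a full byte-level alphabet plus special tokens; the byte-level alphabet alone already has 256 entries, as this illustrative check (not part of the commit) shows:

```python
from tokenizers import pre_tokenizers

# The byte-level alphabet always contains one symbol per possible byte value.
alphabet = pre_tokenizers.ByteLevel.alphabet()
print(len(alphabet))  # 256
```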
tests/models/blenderbot/test_modeling_blenderbot.py

@@ -140,6 +140,7 @@ class BlenderbotModelTester:
     def get_pipeline_config(self):
         config = self.get_config()
         config.max_position_embeddings = 100
+        config.vocab_size = 300
         return config
 
     def prepare_config_and_inputs_for_common(self):
tests/models/deberta/test_modeling_deberta.py

@@ -130,6 +130,11 @@ class DebertaModelTester(object):
             pos_att_type=self.pos_att_type,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def check_loss_output(self, result):
         self.parent.assertListEqual(list(result.loss.size()), [])
tests/models/gpt2/test_modeling_gpt2.py

@@ -166,6 +166,11 @@ class GPT2ModelTester:
             reorder_and_upcast_attn=reorder_and_upcast_attn,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
tests/models/gpt_neo/test_modeling_gpt_neo.py

@@ -151,6 +151,11 @@ class GPTNeoModelTester:
             attention_types=self.attention_types,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
tests/models/gptj/test_modeling_gptj.py

@@ -155,6 +155,11 @@ class GPTJModelTester:
             rotary_dim=self.rotary_dim,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
tests/models/ibert/test_modeling_ibert.py

@@ -116,6 +116,11 @@ class IBertModelTester:
             quant_mode=True,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def create_and_check_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
tests/models/led/test_modeling_led.py

@@ -163,6 +163,7 @@ class LEDModelTester:
     def get_pipeline_config(self):
         config = self.get_config()
         config.max_position_embeddings = 100
+        config.vocab_size = 300
         return config
 
     def prepare_config_and_inputs_for_common(self):
tests/models/longformer/test_modeling_longformer.py

@@ -113,6 +113,11 @@ class LongformerModelTester:
             attention_window=self.attention_window,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def create_and_check_attention_mask_determinism(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
tests/models/roberta/test_modeling_roberta.py

@@ -112,6 +112,11 @@ class RobertaModelTester:
             initializer_range=self.initializer_range,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
tests/models/yoso/test_modeling_yoso.py

@@ -126,6 +126,11 @@ class YosoModelTester:
             initializer_range=self.initializer_range,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
tests/tokenization/test_tokenization_fast.py

@@ -39,6 +39,7 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
         self.test_rust_tokenizer = True
         model_paths = ["robot-test/dummy-tokenizer-fast", "robot-test/dummy-tokenizer-wordlevel"]
+        self.bytelevel_bpe_model_name = "SaulLu/dummy-tokenizer-bytelevel-bpe"
 
         # Inclusion of 2 tokenizers to test different types of models (Unigram and WordLevel for the moment)
         self.tokenizers_list = [(PreTrainedTokenizerFast, model_path, {}) for model_path in model_paths]

@@ -99,6 +100,15 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
         shutil.rmtree(self.tmpdirname)
         self.tmpdirname = tmpdirname_orig
 
+    def test_training_new_tokenizer_with_bytelevel(self):
+        tokenizer = self.rust_tokenizer_class.from_pretrained(self.bytelevel_bpe_model_name)
+
+        toy_text_iterator = ("a" for _ in range(1000))
+        new_tokenizer = tokenizer.train_new_from_iterator(text_iterator=toy_text_iterator, length=1000, vocab_size=50)
+
+        encoding_ids = new_tokenizer.encode("a🤗")
+        self.assertEqual(encoding_ids, [64, 172, 253, 97, 245])
+
 
 @require_tokenizers
 class TokenizerVersioningTest(unittest.TestCase):
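The new test asserts that a tokenizer retrained on a corpus containing only "a" can still encode "a🤗" into byte-level pieces. The same behaviour can be reproduced directly with the `tokenizers` library by passing the byte-level alphabet to the trainer yourself; a rough standalone sketch (illustrative, not part of the commit; the exact ids depend on the trained vocabulary):

```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

# Build a byte-level BPE tokenizer from scratch on a tiny corpus.
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

trainer = trainers.BpeTrainer(
    vocab_size=50,
    # Same idea as the fix above: seed the vocabulary with all 256
    # byte-level symbols so unseen characters remain encodable.
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)
tokenizer.train_from_iterator(("a" for _ in range(1000)), trainer=trainer)

# A character never seen during training still maps to byte-level tokens.
print(tokenizer.encode("a🤗").ids)
```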