chenpangpang / transformers

Commit a75c64d8
Authored Aug 26, 2020 by Lysandre
Parent: e78c1103

Black 20 release

Showing 11 changed files (of 191 changed in this commit) with 129 additions and 36 deletions.
tests/test_modeling_transfo_xl.py    +2   -1
tests/test_modeling_xlm.py           +2   -1
tests/test_modeling_xlnet.py         +22  -4
tests/test_pipelines.py              +36  -7
tests/test_tokenization_common.py    +25  -8
tests/test_tokenization_fast.py      +24  -7
tests/test_tokenization_mbart.py     +3   -1
tests/test_tokenization_reformer.py  +4   -2
tests/test_tokenization_t5.py        +5   -1
tests/test_trainer.py                +3   -1
utils/link_tester.py                 +3   -3
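All of the hunks below are mechanical reformatting from the new Black release rather than functional changes: calls and signatures that already end in a trailing comma are exploded to one argument per line (Black's "magic trailing comma"), and docstrings lose the stray space after their opening quotes. As a minimal, hypothetical sketch of the trailing-comma behaviour (the function and arguments are made up, and the exact version, presumably Black 20.8, is an assumption not stated on this page):

# Hypothetical illustration of Black 20.8's "magic trailing comma"; not part of the commit.
def example_call(first_argument, second_argument, third_argument):
    return (first_argument, second_argument, third_argument)

# Before: older Black kept a call like this on one line as long as it fit the line length.
result = example_call("a", "b", "c",)

# After: with Black >= 20.8, the trailing comma forces one argument per line,
# which is the pattern repeated throughout the diffs below.
result = example_call(
    "a",
    "b",
    "c",
)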
tests/test_modeling_transfo_xl.py

@@ -32,7 +32,8 @@ if is_torch_available():
 class TransfoXLModelTester:
     def __init__(
-        self, parent,
+        self,
+        parent,
     ):
         self.parent = parent
         self.batch_size = 14
tests/test_modeling_xlm.py

@@ -41,7 +41,8 @@ if is_torch_available():
 class XLMModelTester:
     def __init__(
-        self, parent,
+        self,
+        parent,
     ):
         self.parent = parent
         self.batch_size = 13
tests/test_modeling_xlnet.py

@@ -104,10 +104,20 @@ class XLNetModelTester:
         input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
         perm_mask = torch.zeros(
-            self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
+            self.batch_size,
+            self.seq_length + 1,
+            self.seq_length + 1,
+            dtype=torch.float,
+            device=torch_device,
         )
         perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-        target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,)
+        target_mapping = torch.zeros(
+            self.batch_size,
+            1,
+            self.seq_length + 1,
+            dtype=torch.float,
+            device=torch_device,
+        )
         target_mapping[:, 0, -1] = 1.0  # predict last token

         sequence_labels = None

@@ -217,7 +227,11 @@ class XLNetModelTester:
         # first forward pass
         causal_mask = torch.ones(
-            input_ids_1.shape[0], input_ids_1.shape[1], input_ids_1.shape[1], dtype=torch.float, device=torch_device,
+            input_ids_1.shape[0],
+            input_ids_1.shape[1],
+            input_ids_1.shape[1],
+            dtype=torch.float,
+            device=torch_device,
         )
         causal_mask = torch.triu(causal_mask, diagonal=0)
         outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask)

@@ -363,7 +377,11 @@ class XLNetModelTester:
         total_loss, mems = result_with_labels.to_tuple()

-        result_with_labels = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
+        result_with_labels = model(
+            input_ids_1,
+            start_positions=sequence_labels,
+            end_positions=sequence_labels,
+        )

         total_loss, mems = result_with_labels.to_tuple()
tests/test_pipelines.py

@@ -164,7 +164,8 @@ class MonoColumnInputTestCase(unittest.TestCase):
         for result, expect in zip(multi_result, expected_multi_result):
             for key in expected_check_keys or []:
                 self.assertEqual(
-                    set([o[key] for o in result]), set([o[key] for o in expect]),
+                    set([o[key] for o in result]),
+                    set([o[key] for o in expect]),
                 )

         if isinstance(multi_result[0], list):

@@ -214,7 +215,13 @@ class MonoColumnInputTestCase(unittest.TestCase):
             "This is"  # No mask_token is not supported
         ]
         for model_name in FILL_MASK_FINETUNED_MODELS:
-            nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt", topk=2,)
+            nlp = pipeline(
+                task="fill-mask",
+                model=model_name,
+                tokenizer=model_name,
+                framework="pt",
+                topk=2,
+            )
             self._test_mono_column_pipeline(
                 nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
             )

@@ -231,7 +238,13 @@ class MonoColumnInputTestCase(unittest.TestCase):
             "This is"  # No mask_token is not supported
         ]
         for model_name in FILL_MASK_FINETUNED_MODELS:
-            nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2,)
+            nlp = pipeline(
+                task="fill-mask",
+                model=model_name,
+                tokenizer=model_name,
+                framework="tf",
+                topk=2,
+            )
             self._test_mono_column_pipeline(
                 nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
             )

@@ -274,7 +287,13 @@ class MonoColumnInputTestCase(unittest.TestCase):
         ]
         valid_targets = [" Patrick", " Clara"]
         for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
-            nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt", topk=2,)
+            nlp = pipeline(
+                task="fill-mask",
+                model=model_name,
+                tokenizer=model_name,
+                framework="pt",
+                topk=2,
+            )
             self._test_mono_column_pipeline(
                 nlp,
                 valid_inputs,

@@ -343,7 +362,12 @@ class MonoColumnInputTestCase(unittest.TestCase):
         invalid_inputs = [4, "<mask>"]
         mandatory_keys = ["summary_text"]
         for model_name in TF_SUMMARIZATION_FINETUNED_MODELS:
-            nlp = pipeline(task="summarization", model=model_name, tokenizer=model_name, framework="tf",)
+            nlp = pipeline(
+                task="summarization",
+                model=model_name,
+                tokenizer=model_name,
+                framework="tf",
+            )
             self._test_mono_column_pipeline(
                 nlp, VALID_INPUTS, mandatory_keys, invalid_inputs=invalid_inputs, **SUMMARIZATION_KWARGS
             )

@@ -355,7 +379,10 @@ class MonoColumnInputTestCase(unittest.TestCase):
         for model_name, task in TRANSLATION_FINETUNED_MODELS:
             nlp = pipeline(task=task, model=model_name, tokenizer=model_name)
             self._test_mono_column_pipeline(
-                nlp, VALID_INPUTS, mandatory_keys, invalid_inputs,
+                nlp,
+                VALID_INPUTS,
+                mandatory_keys,
+                invalid_inputs,
             )

     @require_tf

@@ -655,7 +682,9 @@ class QAPipelineTests(unittest.TestCase):
 class NerPipelineTests(unittest.TestCase):
     def _test_ner_pipeline(
-        self, nlp: Pipeline, output_keys: Iterable[str],
+        self,
+        nlp: Pipeline,
+        output_keys: Iterable[str],
     ):
         ungrouped_ner_inputs = [
tests/test_tokenization_common.py

@@ -882,8 +882,7 @@ class TokenizerTesterMixin:
                 assert encoded_sequence == padded_sequence_left

     def test_padding_to_max_length(self):
-        """ We keep this test for backward compatibility but it should be remove when `pad_to_max_length` will e deprecated
-        """
+        """We keep this test for backward compatibility but it should be remove when `pad_to_max_length` will e deprecated"""
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
             with self.subTest(f"{tokenizer.__class__.__name__}"):

@@ -972,7 +971,11 @@ class TokenizerTesterMixin:
                 # Test 'longest' and 'no_padding' don't do anything
                 tokenizer.padding_side = "right"
-                not_padded_sequence = tokenizer.encode_plus(sequence, padding=True, return_special_tokens_mask=True,)
+                not_padded_sequence = tokenizer.encode_plus(
+                    sequence,
+                    padding=True,
+                    return_special_tokens_mask=True,
+                )
                 not_padded_input_ids = not_padded_sequence["input_ids"]
                 not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"]

@@ -982,7 +985,11 @@ class TokenizerTesterMixin:
                 assert input_ids == not_padded_input_ids
                 assert special_tokens_mask == not_padded_special_tokens_mask

-                not_padded_sequence = tokenizer.encode_plus(sequence, padding=False, return_special_tokens_mask=True,)
+                not_padded_sequence = tokenizer.encode_plus(
+                    sequence,
+                    padding=False,
+                    return_special_tokens_mask=True,
+                )
                 not_padded_input_ids = not_padded_sequence["input_ids"]
                 not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"]

@@ -1148,7 +1155,8 @@ class TokenizerTesterMixin:
                 )
                 for key in encoded_sequences_batch_padded_1.keys():
                     self.assertListEqual(
-                        encoded_sequences_batch_padded_1[key], encoded_sequences_batch_padded_2[key],
+                        encoded_sequences_batch_padded_1[key],
+                        encoded_sequences_batch_padded_2[key],
                     )

                 # check 'no_padding' is unsensitive to a max length

@@ -1158,7 +1166,8 @@ class TokenizerTesterMixin:
                 )
                 for key in encoded_sequences_batch_padded_1.keys():
                     self.assertListEqual(
-                        encoded_sequences_batch_padded_1[key], encoded_sequences_batch_padded_2[key],
+                        encoded_sequences_batch_padded_1[key],
+                        encoded_sequences_batch_padded_2[key],
                     )

     def test_added_token_serializable(self):

@@ -1361,10 +1370,18 @@ class TokenizerTesterMixin:
                 if tokenizer.pad_token_id is None:
                     self.assertRaises(
-                        ValueError, tokenizer.batch_encode_plus, sequences, padding=True, return_tensors="pt",
+                        ValueError,
+                        tokenizer.batch_encode_plus,
+                        sequences,
+                        padding=True,
+                        return_tensors="pt",
                     )
                     self.assertRaises(
-                        ValueError, tokenizer.batch_encode_plus, sequences, padding="longest", return_tensors="tf",
+                        ValueError,
+                        tokenizer.batch_encode_plus,
+                        sequences,
+                        padding="longest",
+                        return_tensors="tf",
                     )
                 else:
                     pytorch_tensor = tokenizer.batch_encode_plus(sequences, padding=True, return_tensors="pt")
tests/test_tokenization_fast.py

@@ -228,7 +228,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
     def assert_special_tokens_map_equal(self, tokenizer_r, tokenizer_p):
         # Assert the set of special tokens match.
         self.assertSequenceEqual(
-            tokenizer_p.special_tokens_map.items(), tokenizer_r.special_tokens_map.items(),
+            tokenizer_p.special_tokens_map.items(),
+            tokenizer_r.special_tokens_map.items(),
         )

     def assert_add_tokens(self, tokenizer_r):

@@ -544,18 +545,26 @@ class CommonFastTokenizerTest(unittest.TestCase):
         assert_batch_padded_input_match(input_r, input_p, max_length)

         input_r = tokenizer_r.batch_encode_plus(
-            ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="max_length",
+            ["This is a simple input 1", "This is a simple input 2"],
+            max_length=max_length,
+            padding="max_length",
         )
         input_p = tokenizer_p.batch_encode_plus(
-            ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="max_length",
+            ["This is a simple input 1", "This is a simple input 2"],
+            max_length=max_length,
+            padding="max_length",
         )
         assert_batch_padded_input_match(input_r, input_p, max_length)

         input_r = tokenizer_r.batch_encode_plus(
-            ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="longest",
+            ["This is a simple input 1", "This is a simple input 2"],
+            max_length=max_length,
+            padding="longest",
         )
         input_p = tokenizer_p.batch_encode_plus(
-            ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding=True,
+            ["This is a simple input 1", "This is a simple input 2"],
+            max_length=max_length,
+            padding=True,
         )
         assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))

@@ -865,7 +874,11 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
         # Simple input
         self.assertRaises(
-            ValueError, tokenizer_r.batch_encode_plus, s2, max_length=max_length, padding="max_length",
+            ValueError,
+            tokenizer_r.batch_encode_plus,
+            s2,
+            max_length=max_length,
+            padding="max_length",
         )

         # Pair input

@@ -876,7 +889,11 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
         # Pair input
         self.assertRaises(
-            ValueError, tokenizer_r.batch_encode_plus, p2, max_length=max_length, padding="max_length",
+            ValueError,
+            tokenizer_r.batch_encode_plus,
+            p2,
+            max_length=max_length,
+            padding="max_length",
         )
tests/test_tokenization_mbart.py

@@ -125,7 +125,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
     def test_enro_tokenizer_prepare_seq2seq_batch(self):
         batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
+            self.src_text,
+            tgt_texts=self.tgt_text,
+            max_length=len(self.expected_src_tokens),
         )
         self.assertIsInstance(batch, BatchEncoding)
tests/test_tokenization_reformer.py

@@ -44,7 +44,8 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

         self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382],
+            tokenizer.convert_tokens_to_ids(tokens),
+            [285, 46, 10, 170, 382],
         )

         tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")

@@ -76,7 +77,8 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         )
         ids = tokenizer.convert_tokens_to_ids(tokens)
         self.assertListEqual(
-            ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
+            ids,
+            [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
         )
         back_tokens = tokenizer.convert_ids_to_tokens(ids)
tests/test_tokenization_t5.py

@@ -126,7 +126,11 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             "Another summary.",
         ]
         expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, tokenizer.eos_token_id]
-        batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK,)
+        batch = tokenizer.prepare_seq2seq_batch(
+            src_text,
+            tgt_texts=tgt_text,
+            return_tensors=FRAMEWORK,
+        )
         self.assertIsInstance(batch, BatchEncoding)

         result = list(batch.input_ids.numpy()[0])
         self.assertListEqual(expected_src_tokens, result)
tests/test_trainer.py

@@ -275,7 +275,9 @@ class TrainerIntegrationTest(unittest.TestCase):
         MODEL_ID = "distilroberta-base"
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
         dataset = LineByLineTextDataset(
-            tokenizer=tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=tokenizer.max_len_single_sentence,
+            tokenizer=tokenizer,
+            file_path=PATH_SAMPLE_TEXT,
+            block_size=tokenizer.max_len_single_sentence,
         )
         self.assertEqual(len(dataset), 31)
utils/link_tester.py

@@ -18,7 +18,7 @@ S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
 def list_python_files_in_repository():
-    """ List all python files in the repository.
+    """List all python files in the repository.

     This function assumes that the script is executed in the root folder.
     """

@@ -43,7 +43,7 @@ def find_all_links(file_paths):
 def scan_code_for_links(source):
-    """ Scans the file to find links using a regular expression.
+    """Scans the file to find links using a regular expression.

     Returns a list of links.
     """
     with open(source, "r") as content:

@@ -55,7 +55,7 @@ def scan_code_for_links(source):
 def check_all_links(links):
-    """ Check that the provided links are valid.
+    """Check that the provided links are valid.

     Links are considered valid if a HEAD request to the server
     returns a 200 status code.
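The utils/link_tester.py hunks above isolate the second change: the new Black strips the leading space after a docstring's opening quotes, and (as in tests/test_tokenization_common.py) pulls a short docstring's closing quotes onto the same line. A small self-contained sketch, with hypothetical function names and assuming Black 20.8 semantics:

# Hypothetical before/after for the docstring normalization; not part of the commit.
def scan_before():
    """ Scans something, written the pre-release way, with a space after the opening quotes.

    Returns a list of results.
    """
    return []


def scan_after():
    """Scans something, rewritten by the new Black: no space after the opening quotes.

    Returns a list of results.
    """
    return []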