Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
54abc67a
Unverified
Commit
54abc67a
authored
Dec 22, 2019
by
Thomas Wolf
Committed by
GitHub
Dec 22, 2019
Browse files
Merge pull request #2255 from aaugustin/implement-best-practices
Implement some Python best practices
parents
645713e2
c11b3e29
Changes
205
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1365 additions
and
938 deletions
+1365
-938
transformers/tests/modeling_tf_t5_test.py
transformers/tests/modeling_tf_t5_test.py
+52
-50
transformers/tests/modeling_tf_transfo_xl_test.py
transformers/tests/modeling_tf_transfo_xl_test.py
+53
-53
transformers/tests/modeling_tf_xlm_test.py
transformers/tests/modeling_tf_xlm_test.py
+156
-104
transformers/tests/modeling_tf_xlnet_test.py
transformers/tests/modeling_tf_xlnet_test.py
+187
-107
transformers/tests/modeling_transfo_xl_test.py
transformers/tests/modeling_transfo_xl_test.py
+50
-52
transformers/tests/modeling_xlm_test.py
transformers/tests/modeling_xlm_test.py
+196
-126
transformers/tests/modeling_xlnet_test.py
transformers/tests/modeling_xlnet_test.py
+238
-126
transformers/tests/optimization_test.py
transformers/tests/optimization_test.py
+26
-20
transformers/tests/optimization_tf_test.py
transformers/tests/optimization_tf_test.py
+7
-8
transformers/tests/pipelines_test.py
transformers/tests/pipelines_test.py
+62
-62
transformers/tests/tokenization_albert_test.py
transformers/tests/tokenization_albert_test.py
+21
-15
transformers/tests/tokenization_auto_test.py
transformers/tests/tokenization_auto_test.py
+11
-8
transformers/tests/tokenization_bert_japanese_test.py
transformers/tests/tokenization_bert_japanese_test.py
+66
-66
transformers/tests/tokenization_bert_test.py
transformers/tests/tokenization_bert_test.py
+56
-48
transformers/tests/tokenization_ctrl_test.py
transformers/tests/tokenization_ctrl_test.py
+12
-12
transformers/tests/tokenization_distilbert_test.py
transformers/tests/tokenization_distilbert_test.py
+6
-7
transformers/tests/tokenization_gpt2_test.py
transformers/tests/tokenization_gpt2_test.py
+32
-13
transformers/tests/tokenization_openai_test.py
transformers/tests/tokenization_openai_test.py
+31
-14
transformers/tests/tokenization_roberta_test.py
transformers/tests/tokenization_roberta_test.py
+37
-20
transformers/tests/tokenization_t5_test.py
transformers/tests/tokenization_t5_test.py
+66
-27
No files found.
transformers/tests/modeling_tf_t5_test.py
View file @
54abc67a
...
...
@@ -12,23 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
unittest
import
sys
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
transformers
import
T5Config
,
is_tf_available
from
.configuration_common_test
import
ConfigTester
from
.modeling_tf_common_test
import
TFCommonTestCases
,
ids_tensor
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
from
transformers
import
T5Config
,
is_tf_available
if
is_tf_available
():
import
tensorflow
as
tf
from
transformers.modeling_tf_t5
import
(
TFT5Model
,
TFT5WithLMHeadModel
,
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
transformers.modeling_tf_t5
import
TFT5Model
,
TFT5WithLMHeadModel
@
require_tf
...
...
@@ -38,8 +34,8 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes
=
(
TFT5Model
,
TFT5WithLMHeadModel
)
if
is_tf_available
()
else
()
class
TFT5ModelTester
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
...
...
@@ -95,53 +91,58 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
num_heads
=
self
.
num_attention_heads
,
relative_attention_num_buckets
=
self
.
relative_attention_num_buckets
,
dropout_rate
=
self
.
dropout_rate
,
initializer_factor
=
self
.
initializer_factor
)
initializer_factor
=
self
.
initializer_factor
,
)
return
(
config
,
input_ids
,
input_mask
,
token_labels
)
def
create_and_check_t5_model
(
self
,
config
,
input_ids
,
input_mask
,
token_labels
):
model
=
TFT5Model
(
config
=
config
)
inputs
=
{
'encoder_input_ids'
:
input_ids
,
'decoder_input_ids'
:
input_ids
,
'decoder_attention_mask'
:
input_mask
}
inputs
=
{
"encoder_input_ids"
:
input_ids
,
"decoder_input_ids"
:
input_ids
,
"decoder_attention_mask"
:
input_mask
,
}
encoder_output
,
decoder_output
=
model
(
inputs
)
encoder_output
,
decoder_output
=
model
(
input_ids
,
decoder_attention_mask
=
input_mask
,
encoder_input_ids
=
input_ids
)
encoder_output
,
decoder_output
=
model
(
input_ids
,
decoder_attention_mask
=
input_mask
,
encoder_input_ids
=
input_ids
)
result
=
{
"encoder_output"
:
encoder_output
.
numpy
(),
"decoder_output"
:
decoder_output
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"encoder_output"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
list
(
result
[
"encoder_output"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"decoder_output"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
list
(
result
[
"decoder_output"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
def
create_and_check_t5_with_lm_head
(
self
,
config
,
input_ids
,
input_mask
,
token_labels
):
model
=
TFT5WithLMHeadModel
(
config
=
config
)
inputs
=
{
'encoder_input_ids'
:
input_ids
,
'decoder_input_ids'
:
input_ids
,
'decoder_attention_mask'
:
input_mask
}
inputs
=
{
"encoder_input_ids"
:
input_ids
,
"decoder_input_ids"
:
input_ids
,
"decoder_attention_mask"
:
input_mask
,
}
prediction_scores
,
decoder_output
=
model
(
inputs
)
result
=
{
"prediction_scores"
:
prediction_scores
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
list
(
result
[
"prediction_scores"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids
,
input_mask
,
token_labels
)
=
config_and_inputs
inputs_dict
=
{
'encoder_input_ids'
:
input_ids
,
'decoder_input_ids'
:
input_ids
,
'decoder_attention_mask'
:
input_mask
}
inputs_dict
=
{
"encoder_input_ids"
:
input_ids
,
"decoder_input_ids"
:
input_ids
,
"decoder_attention_mask"
:
input_mask
,
}
return
config
,
inputs_dict
def
setUp
(
self
):
...
...
@@ -161,9 +162,10 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
@
slow
def
test_model_from_pretrained
(
self
):
for
model_name
in
[
'
t5-small
'
]:
for
model_name
in
[
"
t5-small
"
]:
model
=
TFT5Model
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
unittest
.
main
()
transformers/tests/modeling_tf_transfo_xl_test.py
View file @
54abc67a
...
...
@@ -12,24 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
unittest
import
random
import
unittest
from
transformers
import
TransfoXLConfig
,
is_tf_available
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.modeling_tf_common_test
import
TFCommonTestCases
,
ids_tensor
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
from
transformers
import
TransfoXLConfig
,
is_tf_available
if
is_tf_available
():
import
tensorflow
as
tf
from
transformers.modeling_tf_transfo_xl
import
(
TFTransfoXLModel
,
from
transformers.modeling_tf_transfo_xl
import
(
TFTransfoXLModel
,
TFTransfoXLLMHeadModel
,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
)
@
require_tf
...
...
@@ -41,8 +42,8 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
test_resize_embeddings
=
False
class
TFTransfoXLModelTester
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
...
...
@@ -101,7 +102,8 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
d_head
=
self
.
d_head
,
d_inner
=
self
.
d_inner
,
div_val
=
self
.
div_val
,
n_layer
=
self
.
num_hidden_layers
)
n_layer
=
self
.
num_hidden_layers
,
)
return
(
config
,
input_ids_1
,
input_ids_2
,
lm_labels
)
...
...
@@ -114,8 +116,7 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
hidden_states_1
,
mems_1
=
model
(
input_ids_1
)
inputs
=
{
'input_ids'
:
input_ids_2
,
'mems'
:
mems_1
}
inputs
=
{
"input_ids"
:
input_ids_2
,
"mems"
:
mems_1
}
hidden_states_2
,
mems_2
=
model
(
inputs
)
...
...
@@ -127,33 +128,31 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states_1"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
list
(
result
[
"hidden_states_1"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states_2"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
list
(
result
[
"hidden_states_2"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
create_and_check_transfo_xl_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
lm_labels
):
model
=
TFTransfoXLLMHeadModel
(
config
)
lm_logits_1
,
mems_1
=
model
(
input_ids_1
)
inputs
=
{
'input_ids'
:
input_ids_1
,
'labels'
:
lm_labels
}
inputs
=
{
"input_ids"
:
input_ids_1
,
"labels"
:
lm_labels
}
_
,
mems_1
=
model
(
inputs
)
lm_logits_2
,
mems_2
=
model
([
input_ids_2
,
mems_1
])
inputs
=
{
'input_ids'
:
input_ids_1
,
'mems'
:
mems_1
,
'labels'
:
lm_labels
}
inputs
=
{
"input_ids"
:
input_ids_1
,
"mems"
:
mems_1
,
"labels"
:
lm_labels
}
_
,
mems_2
=
model
(
inputs
)
...
...
@@ -165,26 +164,27 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits_1"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
list
(
result
[
"lm_logits_1"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits_2"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
list
(
result
[
"lm_logits_2"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
lm_labels
)
=
config_and_inputs
inputs_dict
=
{
'
input_ids
'
:
input_ids_1
}
inputs_dict
=
{
"
input_ids
"
:
input_ids_1
}
return
config
,
inputs_dict
def
setUp
(
self
):
self
.
model_tester
=
TFTransfoXLModelTest
.
TFTransfoXLModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
TransfoXLConfig
,
d_embed
=
37
)
...
...
transformers/tests/modeling_tf_xlm_test.py
View file @
54abc67a
...
...
@@ -12,38 +12,41 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
unittest
from
transformers
import
is_tf_available
from
.configuration_common_test
import
ConfigTester
from
.modeling_tf_common_test
import
TFCommonTestCases
,
ids_tensor
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
if
is_tf_available
():
import
tensorflow
as
tf
from
transformers
import
(
XLMConfig
,
TFXLMModel
,
from
transformers
import
(
XLMConfig
,
TFXLMModel
,
TFXLMWithLMHeadModel
,
TFXLMForSequenceClassification
,
TFXLMForQuestionAnsweringSimple
,
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
)
@
require_tf
class
TFXLMModelTest
(
TFCommonTestCases
.
TFCommonModelTester
):
all_model_classes
=
(
TFXLMModel
,
TFXLMWithLMHeadModel
,
TFXLMForSequenceClassification
,
TFXLMForQuestionAnsweringSimple
)
if
is_tf_available
()
else
()
all_model_classes
=
(
(
TFXLMModel
,
TFXLMWithLMHeadModel
,
TFXLMForSequenceClassification
,
TFXLMForQuestionAnsweringSimple
)
if
is_tf_available
()
else
()
)
class
TFXLMModelTester
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
...
...
@@ -109,7 +112,9 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
input_lengths
=
None
if
self
.
use_input_lengths
:
input_lengths
=
ids_tensor
([
self
.
batch_size
],
vocab_size
=
2
)
+
self
.
seq_length
-
2
# small variation of seq_length
input_lengths
=
(
ids_tensor
([
self
.
batch_size
],
vocab_size
=
2
)
+
self
.
seq_length
-
2
)
# small variation of seq_length
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
@@ -139,15 +144,33 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
max_position_embeddings
=
self
.
max_position_embeddings
,
initializer_range
=
self
.
initializer_range
,
summary_type
=
self
.
summary_type
,
use_proj
=
self
.
use_proj
)
return
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
def
create_and_check_xlm_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
use_proj
=
self
.
use_proj
,
)
return
(
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
)
def
create_and_check_xlm_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
):
model
=
TFXLMModel
(
config
=
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'lengths'
:
input_lengths
,
'langs'
:
token_type_ids
}
inputs
=
{
"input_ids"
:
input_ids
,
"lengths"
:
input_lengths
,
"langs"
:
token_type_ids
}
outputs
=
model
(
inputs
)
inputs
=
[
input_ids
,
input_mask
]
...
...
@@ -157,16 +180,23 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
"sequence_output"
:
sequence_output
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"sequence_output"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
def
create_and_check_xlm_lm_head
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
list
(
result
[
"sequence_output"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
def
create_and_check_xlm_lm_head
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
):
model
=
TFXLMWithLMHeadModel
(
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'lengths'
:
input_lengths
,
'langs'
:
token_type_ids
}
inputs
=
{
"input_ids"
:
input_ids
,
"lengths"
:
input_lengths
,
"langs"
:
token_type_ids
}
outputs
=
model
(
inputs
)
logits
=
outputs
[
0
]
...
...
@@ -176,15 +206,23 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
def
create_and_check_xlm_qa
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
list
(
result
[
"logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
def
create_and_check_xlm_qa
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
):
model
=
TFXLMForQuestionAnsweringSimple
(
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'lengths'
:
input_lengths
}
inputs
=
{
"input_ids"
:
input_ids
,
"lengths"
:
input_lengths
}
outputs
=
model
(
inputs
)
start_logits
,
end_logits
=
model
(
inputs
)
...
...
@@ -194,19 +232,23 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
"end_logits"
:
end_logits
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
])
def
create_and_check_xlm_sequence_classif
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
])
def
create_and_check_xlm_sequence_classif
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
):
model
=
TFXLMForSequenceClassification
(
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'lengths'
:
input_lengths
}
inputs
=
{
"input_ids"
:
input_ids
,
"lengths"
:
input_lengths
}
(
logits
,)
=
model
(
inputs
)
...
...
@@ -214,16 +256,26 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
"logits"
:
logits
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
shape
),
[
self
.
batch_size
,
self
.
type_sequence_label_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
shape
),
[
self
.
batch_size
,
self
.
type_sequence_label_size
])
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids
,
'token_type_ids'
:
token_type_ids
,
'langs'
:
token_type_ids
,
'lengths'
:
input_lengths
}
(
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
)
=
config_and_inputs
inputs_dict
=
{
"input_ids"
:
input_ids
,
"token_type_ids"
:
token_type_ids
,
"langs"
:
token_type_ids
,
"lengths"
:
input_lengths
,
}
return
config
,
inputs_dict
def
setUp
(
self
):
...
...
transformers/tests/modeling_tf_xlnet_test.py
View file @
54abc67a
...
...
@@ -12,43 +12,50 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
os
import
unittest
import
json
import
random
import
unittest
from
transformers
import
XLNetConfig
,
is_tf_available
from
.configuration_common_test
import
ConfigTester
from
.modeling_tf_common_test
import
TFCommonTestCases
,
ids_tensor
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
if
is_tf_available
():
import
tensorflow
as
tf
from
transformers.modeling_tf_xlnet
import
(
TFXLNetModel
,
TFXLNetLMHeadModel
,
from
transformers.modeling_tf_xlnet
import
(
TFXLNetModel
,
TFXLNetLMHeadModel
,
TFXLNetForSequenceClassification
,
TFXLNetForTokenClassification
,
TFXLNetForQuestionAnsweringSimple
,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
)
@
require_tf
class
TFXLNetModelTest
(
TFCommonTestCases
.
TFCommonModelTester
):
all_model_classes
=
(
TFXLNetModel
,
TFXLNetLMHeadModel
,
all_model_classes
=
(
(
TFXLNetModel
,
TFXLNetLMHeadModel
,
TFXLNetForSequenceClassification
,
TFXLNetForTokenClassification
,
TFXLNetForQuestionAnsweringSimple
)
if
is_tf_available
()
else
()
TFXLNetForQuestionAnsweringSimple
,
)
if
is_tf_available
()
else
()
)
test_pruning
=
False
class
TFXLNetModelTester
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
...
...
@@ -131,22 +138,44 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
reuse_len
=
self
.
reuse_len
,
bi_data
=
self
.
bi_data
,
initializer_range
=
self
.
initializer_range
,
num_labels
=
self
.
type_sequence_label_size
)
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
num_labels
=
self
.
type_sequence_label_size
,
)
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
)
def
set_seed
(
self
):
random
.
seed
(
self
.
seed
)
tf
.
random
.
set_seed
(
self
.
seed
)
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
):
model
=
TFXLNetModel
(
config
)
inputs
=
{
'input_ids'
:
input_ids_1
,
'input_mask'
:
input_mask
,
'token_type_ids'
:
segment_ids
}
inputs
=
{
"input_ids"
:
input_ids_1
,
"input_mask"
:
input_mask
,
"token_type_ids"
:
segment_ids
}
_
,
_
=
model
(
inputs
)
...
...
@@ -165,30 +194,38 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
self
.
parent
.
assertEqual
(
len
(
no_mems_outputs
),
1
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"outputs"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
list
(
result
[
"outputs"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
):
model
=
TFXLNetLMHeadModel
(
config
)
inputs_1
=
{
'input_ids'
:
input_ids_1
,
'token_type_ids'
:
segment_ids
}
inputs_1
=
{
"input_ids"
:
input_ids_1
,
"token_type_ids"
:
segment_ids
}
all_logits_1
,
mems_1
=
model
(
inputs_1
)
inputs_2
=
{
'input_ids'
:
input_ids_2
,
'mems'
:
mems_1
,
'token_type_ids'
:
segment_ids
}
inputs_2
=
{
"input_ids"
:
input_ids_2
,
"mems"
:
mems_1
,
"token_type_ids"
:
segment_ids
}
all_logits_2
,
mems_2
=
model
(
inputs_2
)
inputs_3
=
{
'input_ids'
:
input_ids_q
,
'perm_mask'
:
perm_mask
,
'target_mapping'
:
target_mapping
}
inputs_3
=
{
"input_ids"
:
input_ids_q
,
"perm_mask"
:
perm_mask
,
"target_mapping"
:
target_mapping
}
logits
,
_
=
model
(
inputs_3
)
...
...
@@ -200,26 +237,38 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"all_logits_1"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
list
(
result
[
"all_logits_1"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"all_logits_2"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
list
(
result
[
"all_logits_2"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_qa
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
create_and_check_xlnet_qa
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
):
model
=
TFXLNetForQuestionAnsweringSimple
(
config
)
inputs
=
{
'input_ids'
:
input_ids_1
,
'attention_mask'
:
input_mask
,
'token_type_ids'
:
segment_ids
}
inputs
=
{
"input_ids"
:
input_ids_1
,
"attention_mask"
:
input_mask
,
"token_type_ids"
:
segment_ids
}
start_logits
,
end_logits
,
mems
=
model
(
inputs
)
result
=
{
...
...
@@ -228,18 +277,27 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
"mems"
:
[
m
.
numpy
()
for
m
in
mems
],
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_sequence_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
create_and_check_xlnet_sequence_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
):
model
=
TFXLNetForSequenceClassification
(
config
)
logits
,
mems_1
=
model
(
input_ids_1
)
...
...
@@ -249,19 +307,31 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
"logits"
:
logits
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
shape
),
[
self
.
batch_size
,
self
.
type_sequence_label_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
shape
),
[
self
.
batch_size
,
self
.
type_sequence_label_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_for_token_classification
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
create_and_check_xlnet_for_token_classification
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
):
config
.
num_labels
=
input_ids_1
.
shape
[
1
]
model
=
TFXLNetForTokenClassification
(
config
)
inputs
=
{
'input_ids'
:
input_ids_1
,
'attention_mask'
:
input_mask
,
inputs
=
{
"input_ids"
:
input_ids_1
,
"attention_mask"
:
input_mask
,
# 'token_type_ids': token_type_ids
}
logits
,
mems_1
=
model
(
inputs
)
...
...
@@ -270,21 +340,31 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
"logits"
:
logits
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
config
.
num_labels
]
)
list
(
result
[
"logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
config
.
num_labels
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
)
=
config_and_inputs
inputs_dict
=
{
"input_ids"
:
input_ids_1
}
return
config
,
inputs_dict
def
setUp
(
self
):
self
.
model_tester
=
TFXLNetModelTest
.
TFXLNetModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
XLNetConfig
,
d_inner
=
37
)
...
...
transformers/tests/modeling_transfo_xl_test.py
View file @
54abc67a
...
...
@@ -12,24 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
unittest
import
random
import
unittest
from
transformers
import
is_torch_available
from
.configuration_common_test
import
ConfigTester
from
.modeling_common_test
import
CommonTestCases
,
ids_tensor
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
if
is_torch_available
():
import
torch
from
transformers
import
(
TransfoXLConfig
,
TransfoXLModel
,
TransfoXLLMHeadModel
)
from
transformers
import
TransfoXLConfig
,
TransfoXLModel
,
TransfoXLLMHeadModel
from
transformers.modeling_transfo_xl
import
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
@
require_torch
class
TransfoXLModelTest
(
CommonTestCases
.
CommonModelTester
):
...
...
@@ -40,8 +39,8 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
test_resize_embeddings
=
False
class
TransfoXLModelTester
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
...
...
@@ -100,7 +99,8 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
d_head
=
self
.
d_head
,
d_inner
=
self
.
d_inner
,
div_val
=
self
.
div_val
,
n_layer
=
self
.
num_hidden_layers
)
n_layer
=
self
.
num_hidden_layers
,
)
return
(
config
,
input_ids_1
,
input_ids_2
,
lm_labels
)
...
...
@@ -125,18 +125,19 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
def
check_transfo_xl_model_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
list
(
result
[
"hidden_states_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
list
(
result
[
"hidden_states_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
create_transfo_xl_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
lm_labels
):
model
=
TransfoXLLMHeadModel
(
config
)
...
...
@@ -159,33 +160,30 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
return
outputs
def
check_transfo_xl_lm_head_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
list
(
result
[
"lm_logits_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
list
(
result
[
"lm_logits_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
lm_labels
)
=
config_and_inputs
inputs_dict
=
{
'
input_ids
'
:
input_ids_1
}
inputs_dict
=
{
"
input_ids
"
:
input_ids_1
}
return
config
,
inputs_dict
def
setUp
(
self
):
self
.
model_tester
=
TransfoXLModelTest
.
TransfoXLModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
TransfoXLConfig
,
d_embed
=
37
)
...
...
transformers/tests/modeling_xlm_test.py
View file @
54abc67a
...
...
@@ -12,34 +12,47 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
unittest
from
transformers
import
is_torch_available
if
is_torch_available
():
from
transformers
import
(
XLMConfig
,
XLMModel
,
XLMWithLMHeadModel
,
XLMForQuestionAnswering
,
XLMForSequenceClassification
,
XLMForQuestionAnsweringSimple
)
from
transformers.modeling_xlm
import
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.modeling_common_test
import
CommonTestCases
,
ids_tensor
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
if
is_torch_available
():
from
transformers
import
(
XLMConfig
,
XLMModel
,
XLMWithLMHeadModel
,
XLMForQuestionAnswering
,
XLMForSequenceClassification
,
XLMForQuestionAnsweringSimple
,
)
from
transformers.modeling_xlm
import
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
@
require_torch
class
XLMModelTest
(
CommonTestCases
.
CommonModelTester
):
all_model_classes
=
(
XLMModel
,
XLMWithLMHeadModel
,
XLMForQuestionAnswering
,
XLMForSequenceClassification
,
XLMForQuestionAnsweringSimple
)
if
is_torch_available
()
else
()
all_model_classes
=
(
(
XLMModel
,
XLMWithLMHeadModel
,
XLMForQuestionAnswering
,
XLMForSequenceClassification
,
XLMForQuestionAnsweringSimple
,
)
if
is_torch_available
()
else
()
)
class
XLMModelTester
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
...
...
@@ -105,7 +118,9 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
input_lengths
=
None
if
self
.
use_input_lengths
:
input_lengths
=
ids_tensor
([
self
.
batch_size
],
vocab_size
=
2
)
+
self
.
seq_length
-
2
# small variation of seq_length
input_lengths
=
(
ids_tensor
([
self
.
batch_size
],
vocab_size
=
2
)
+
self
.
seq_length
-
2
)
# small variation of seq_length
token_type_ids
=
None
if
self
.
use_token_type_ids
:
...
...
@@ -135,16 +150,34 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
max_position_embeddings
=
self
.
max_position_embeddings
,
initializer_range
=
self
.
initializer_range
,
summary_type
=
self
.
summary_type
,
use_proj
=
self
.
use_proj
)
return
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
use_proj
=
self
.
use_proj
,
)
return
(
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
)
def
check_loss_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
def
create_and_check_xlm_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
def
create_and_check_xlm_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
):
model
=
XLMModel
(
config
=
config
)
model
.
to
(
torch_device
)
model
.
eval
()
...
...
@@ -156,11 +189,20 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
"sequence_output"
:
sequence_output
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"sequence_output"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
def
create_and_check_xlm_lm_head
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
list
(
result
[
"sequence_output"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
def
create_and_check_xlm_lm_head
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
):
model
=
XLMWithLMHeadModel
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
...
...
@@ -172,23 +214,29 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
def
create_and_check_xlm_simple_qa
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
def
create_and_check_xlm_simple_qa
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
):
model
=
XLMForQuestionAnsweringSimple
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
outputs
=
model
(
input_ids
)
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
loss
,
start_logits
,
end_logits
=
outputs
result
=
{
...
...
@@ -196,16 +244,21 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
check_loss_output
(
result
)
def
create_and_check_xlm_qa
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
def
create_and_check_xlm_qa
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
):
model
=
XLMForQuestionAnswering
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
...
...
@@ -213,21 +266,26 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
outputs
=
model
(
input_ids
)
start_top_log_probs
,
start_top_index
,
end_top_log_probs
,
end_top_index
,
cls_logits
=
outputs
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
cls_index
=
sequence_labels
,
is_impossible
=
is_impossible_labels
,
p_mask
=
input_mask
)
p_mask
=
input_mask
,
)
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
cls_index
=
sequence_labels
,
is_impossible
=
is_impossible_labels
)
is_impossible
=
is_impossible_labels
,
)
(
total_loss
,)
=
outputs
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
(
total_loss
,)
=
outputs
...
...
@@ -240,27 +298,34 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
"cls_logits"
:
cls_logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
])
list
(
result
[
"start_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_top_index"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
)
list
(
result
[
"start_top_index"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
])
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
],
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_top_index"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"cls_logits"
].
size
()),
[
self
.
batch_size
])
def
create_and_check_xlm_sequence_classif
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
],
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"cls_logits"
].
size
()),
[
self
.
batch_size
])
def
create_and_check_xlm_sequence_classif
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
):
model
=
XLMForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
...
...
@@ -273,19 +338,24 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
type_sequence_label_size
])
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
type_sequence_label_size
]
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids
,
'token_type_ids'
:
token_type_ids
,
'lengths'
:
input_lengths
}
(
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
,
)
=
config_and_inputs
inputs_dict
=
{
"input_ids"
:
input_ids
,
"token_type_ids"
:
token_type_ids
,
"lengths"
:
input_lengths
}
return
config
,
inputs_dict
def
setUp
(
self
):
...
...
transformers/tests/modeling_xlnet_test.py
View file @
54abc67a
...
...
@@ -12,39 +12,51 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
os
import
unittest
import
json
import
random
import
unittest
from
transformers
import
is_torch_available
from
.configuration_common_test
import
ConfigTester
from
.modeling_common_test
import
CommonTestCases
,
ids_tensor
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
if
is_torch_available
():
import
torch
from
transformers
import
(
XLNetConfig
,
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForTokenClassification
,
XLNetForQuestionAnswering
)
from
transformers
import
(
XLNetConfig
,
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForTokenClassification
,
XLNetForQuestionAnswering
,
)
from
transformers.modeling_xlnet
import
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
@
require_torch
class
XLNetModelTest
(
CommonTestCases
.
CommonModelTester
):
all_model_classes
=
(
XLNetModel
,
XLNetLMHeadModel
,
XLNetForTokenClassification
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
)
if
is_torch_available
()
else
()
all_model_classes
=
(
(
XLNetModel
,
XLNetLMHeadModel
,
XLNetForTokenClassification
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
,
)
if
is_torch_available
()
else
()
)
test_pruning
=
False
class
XLNetModelTester
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
...
...
@@ -97,9 +109,13 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
2
).
float
()
input_ids_q
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
+
1
],
self
.
vocab_size
)
perm_mask
=
torch
.
zeros
(
self
.
batch_size
,
self
.
seq_length
+
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
,
device
=
torch_device
)
perm_mask
=
torch
.
zeros
(
self
.
batch_size
,
self
.
seq_length
+
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
,
device
=
torch_device
)
perm_mask
[:,
:,
-
1
]
=
1.0
# Previous tokens don't see last token
target_mapping
=
torch
.
zeros
(
self
.
batch_size
,
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
,
device
=
torch_device
)
target_mapping
=
torch
.
zeros
(
self
.
batch_size
,
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
,
device
=
torch_device
)
target_mapping
[:,
0
,
-
1
]
=
1.0
# predict last token
sequence_labels
=
None
...
...
@@ -125,17 +141,43 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
reuse_len
=
self
.
reuse_len
,
bi_data
=
self
.
bi_data
,
initializer_range
=
self
.
initializer_range
,
num_labels
=
self
.
type_sequence_label_size
)
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
)
num_labels
=
self
.
type_sequence_label_size
,
)
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
,
)
def
set_seed
(
self
):
random
.
seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
,
):
model
=
XLNetModel
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
...
...
@@ -158,14 +200,28 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
self
.
parent
.
assertEqual
(
len
(
no_mems_outputs
),
1
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"outputs"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
list
(
result
[
"outputs"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_base_model_with_att_output
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
create_and_check_xlnet_base_model_with_att_output
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
,
):
model
=
XLNetModel
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
...
...
@@ -177,15 +233,30 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
self
.
parent
.
assertEqual
(
len
(
attentions
[
0
]),
2
)
self
.
parent
.
assertTrue
(
attentions
[
0
][
0
].
shape
,
attentions
[
0
][
0
].
shape
)
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
,
):
model
=
XLNetLMHeadModel
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
loss_1
,
all_logits_1
,
mems_1
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
)
loss_2
,
all_logits_2
,
mems_2
=
model
(
input_ids_2
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
,
mems
=
mems_1
)
loss_2
,
all_logits_2
,
mems_2
=
model
(
input_ids_2
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
,
mems
=
mems_1
)
logits
,
_
=
model
(
input_ids_q
,
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
)
...
...
@@ -198,28 +269,39 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"all_logits_2"
:
all_logits_2
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_1"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_1"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"all_logits_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
list
(
result
[
"all_logits_1"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_2"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss_2"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"all_logits_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
list
(
result
[
"all_logits_2"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_qa
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
create_and_check_xlnet_qa
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
,
):
model
=
XLNetForQuestionAnswering
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
...
...
@@ -227,21 +309,26 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
outputs
=
model
(
input_ids_1
)
start_top_log_probs
,
start_top_index
,
end_top_log_probs
,
end_top_index
,
cls_logits
,
mems
=
outputs
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
cls_index
=
sequence_labels
,
is_impossible
=
is_impossible_labels
,
p_mask
=
input_mask
)
p_mask
=
input_mask
,
)
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
cls_index
=
sequence_labels
,
is_impossible
=
is_impossible_labels
)
is_impossible
=
is_impossible_labels
,
)
total_loss
,
mems
=
outputs
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
total_loss
,
mems
=
outputs
...
...
@@ -255,30 +342,42 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"mems"
:
mems
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
])
list
(
result
[
"start_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_top_index"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
)
list
(
result
[
"start_top_index"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
]
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_top_log_probs"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
])
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
],
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_top_index"
].
size
()),
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"cls_logits"
].
size
()),
[
self
.
batch_size
])
[
self
.
batch_size
,
model
.
config
.
start_n_top
*
model
.
config
.
end_n_top
],
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"cls_logits"
].
size
()),
[
self
.
batch_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_token_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
create_and_check_xlnet_token_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
,
):
model
=
XLNetForTokenClassification
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
...
...
@@ -292,26 +391,30 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
type_sequence_label_size
])
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
type_sequence_label_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
return
config
,
inputs_dict
def
create_and_check_xlnet_sequence_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
create_and_check_xlnet_sequence_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
,
):
model
=
XLNetForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
...
...
@@ -325,25 +428,34 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
type_sequence_label_size
])
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
type_sequence_label_size
]
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
,
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
,
)
=
config_and_inputs
inputs_dict
=
{
"input_ids"
:
input_ids_1
}
return
config
,
inputs_dict
def
setUp
(
self
):
self
.
model_tester
=
XLNetModelTest
.
XLNetModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
XLNetConfig
,
d_inner
=
37
)
...
...
transformers/tests/optimization_test.py
View file @
54abc67a
...
...
@@ -12,27 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
unittest
import
os
import
unittest
from
transformers
import
is_torch_available
from
.tokenization_tests_commons
import
TemporaryDirectory
from
.utils
import
require_torch
if
is_torch_available
():
import
torch
from
transformers
import
(
AdamW
,
from
transformers
import
(
AdamW
,
get_constant_schedule
,
get_constant_schedule_with_warmup
,
get_cosine_schedule_with_warmup
,
get_cosine_with_hard_restarts_schedule_with_warmup
,
get_linear_schedule_with_warmup
)
from
.tokenization_tests_commons
import
TemporaryDirectory
from
.utils
import
require_torch
get_linear_schedule_with_warmup
,
)
def
unwrap_schedule
(
scheduler
,
num_steps
=
10
):
...
...
@@ -42,6 +43,7 @@ def unwrap_schedule(scheduler, num_steps=10):
lrs
.
append
(
scheduler
.
get_lr
())
return
lrs
def
unwrap_and_save_reload_schedule
(
scheduler
,
num_steps
=
10
):
lrs
=
[]
for
step
in
range
(
num_steps
):
...
...
@@ -49,16 +51,16 @@ def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
lrs
.
append
(
scheduler
.
get_lr
())
if
step
==
num_steps
//
2
:
with
TemporaryDirectory
()
as
tmpdirname
:
file_name
=
os
.
path
.
join
(
tmpdirname
,
'
schedule.bin
'
)
file_name
=
os
.
path
.
join
(
tmpdirname
,
"
schedule.bin
"
)
torch
.
save
(
scheduler
.
state_dict
(),
file_name
)
state_dict
=
torch
.
load
(
file_name
)
scheduler
.
load_state_dict
(
state_dict
)
return
lrs
@
require_torch
class
OptimizationTest
(
unittest
.
TestCase
):
def
assertListAlmostEqual
(
self
,
list1
,
list2
,
tol
):
self
.
assertEqual
(
len
(
list1
),
len
(
list2
))
for
a
,
b
in
zip
(
list1
,
list2
):
...
...
@@ -82,7 +84,7 @@ class OptimizationTest(unittest.TestCase):
@
require_torch
class
ScheduleInitTest
(
unittest
.
TestCase
):
m
=
torch
.
nn
.
Linear
(
50
,
50
)
if
is_torch_available
()
else
None
optimizer
=
AdamW
(
m
.
parameters
(),
lr
=
10.
)
if
is_torch_available
()
else
None
optimizer
=
AdamW
(
m
.
parameters
(),
lr
=
10.
0
)
if
is_torch_available
()
else
None
num_steps
=
10
def
assertListAlmostEqual
(
self
,
list1
,
list2
,
tol
):
...
...
@@ -93,7 +95,7 @@ class ScheduleInitTest(unittest.TestCase):
def
test_constant_scheduler
(
self
):
scheduler
=
get_constant_schedule
(
self
.
optimizer
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
expected_learning_rates
=
[
10.
]
*
self
.
num_steps
expected_learning_rates
=
[
10.
0
]
*
self
.
num_steps
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
)
...
...
@@ -135,13 +137,17 @@ class ScheduleInitTest(unittest.TestCase):
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
def
test_warmup_cosine_hard_restart_scheduler
(
self
):
scheduler
=
get_cosine_with_hard_restarts_schedule_with_warmup
(
self
.
optimizer
,
num_warmup_steps
=
2
,
num_cycles
=
2
,
num_training_steps
=
10
)
scheduler
=
get_cosine_with_hard_restarts_schedule_with_warmup
(
self
.
optimizer
,
num_warmup_steps
=
2
,
num_cycles
=
2
,
num_training_steps
=
10
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
expected_learning_rates
=
[
5.0
,
10.0
,
8.53
,
5.0
,
1.46
,
10.0
,
8.53
,
5.0
,
1.46
,
0.0
]
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertListAlmostEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
,
tol
=
1e-2
)
scheduler
=
get_cosine_with_hard_restarts_schedule_with_warmup
(
self
.
optimizer
,
num_warmup_steps
=
2
,
num_cycles
=
2
,
num_training_steps
=
10
)
scheduler
=
get_cosine_with_hard_restarts_schedule_with_warmup
(
self
.
optimizer
,
num_warmup_steps
=
2
,
num_cycles
=
2
,
num_training_steps
=
10
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
...
...
transformers/tests/optimization_tf_test.py
View file @
54abc67a
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
unittest
...
...
@@ -8,11 +6,12 @@ from transformers import is_tf_available
from
.utils
import
require_tf
if
is_tf_available
():
import
tensorflow
as
tf
from
tensorflow.python.eager
import
context
from
tensorflow.python.framework
import
ops
from
transformers
import
(
create_optimizer
,
GradientAccumulator
)
from
transformers
import
create_optimizer
,
GradientAccumulator
@
require_tf
...
...
@@ -42,8 +41,8 @@ class OptimizationFTest(unittest.TestCase):
physical_devices
=
tf
.
config
.
experimental
.
list_physical_devices
(
"CPU"
)
tf
.
config
.
experimental
.
set_virtual_device_configuration
(
physical_devices
[
0
],
[
tf
.
config
.
experimental
.
VirtualDeviceConfiguration
(),
tf
.
config
.
experimental
.
VirtualDeviceConfiguration
()]
)
[
tf
.
config
.
experimental
.
VirtualDeviceConfiguration
(),
tf
.
config
.
experimental
.
VirtualDeviceConfiguration
()],
)
devices
=
tf
.
config
.
experimental
.
list_logical_devices
(
device_type
=
"CPU"
)
strategy
=
tf
.
distribute
.
MirroredStrategy
(
devices
=
[
device
.
name
for
device
in
devices
])
...
...
transformers/tests/pipelines_test.py
View file @
54abc67a
import
unittest
from
typing
import
Iterable
from
transformers
import
pipeline
from
transformers.tests.utils
import
require_tf
,
require_torch
QA_FINETUNED_MODELS
=
{
(
'
bert-base-uncased
'
,
'
bert-large-uncased-whole-word-masking-finetuned-squad
'
,
None
),
(
'
bert-base-cased
'
,
'
bert-large-cased-whole-word-masking-finetuned-squad
'
,
None
),
(
'
bert-base-uncased
'
,
'
distilbert-base-uncased-distilled-squad
'
,
None
)
(
"
bert-base-uncased
"
,
"
bert-large-uncased-whole-word-masking-finetuned-squad
"
,
None
),
(
"
bert-base-cased
"
,
"
bert-large-cased-whole-word-masking-finetuned-squad
"
,
None
),
(
"
bert-base-uncased
"
,
"
distilbert-base-uncased-distilled-squad
"
,
None
)
,
}
TF_QA_FINETUNED_MODELS
=
{
(
'
bert-base-uncased
'
,
'
bert-large-uncased-whole-word-masking-finetuned-squad
'
,
None
),
(
'
bert-base-cased
'
,
'
bert-large-cased-whole-word-masking-finetuned-squad
'
,
None
),
(
'
bert-base-uncased
'
,
'
distilbert-base-uncased-distilled-squad
'
,
None
)
(
"
bert-base-uncased
"
,
"
bert-large-uncased-whole-word-masking-finetuned-squad
"
,
None
),
(
"
bert-base-cased
"
,
"
bert-large-cased-whole-word-masking-finetuned-squad
"
,
None
),
(
"
bert-base-uncased
"
,
"
distilbert-base-uncased-distilled-squad
"
,
None
)
,
}
TF_NER_FINETUNED_MODELS
=
{
(
'
bert-base-cased
'
,
'
https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-tf_model.h5
'
,
'
https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json
'
"
bert-base-cased
"
,
"
https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-tf_model.h5
"
,
"
https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json
"
,
)
}
NER_FINETUNED_MODELS
=
{
(
'
bert-base-cased
'
,
'
https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin
'
,
'
https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json
'
"
bert-base-cased
"
,
"
https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin
"
,
"
https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json
"
,
)
}
FEATURE_EXTRACT_FINETUNED_MODELS
=
{
(
'
bert-base-cased
'
,
'
bert-base-cased
'
,
None
),
(
"
bert-base-cased
"
,
"
bert-base-cased
"
,
None
),
# ('xlnet-base-cased', 'xlnet-base-cased', None), # Disabled for now as it crash for TF2
(
'
distilbert-base-uncased
'
,
'
distilbert-base-uncased
'
,
None
)
(
"
distilbert-base-uncased
"
,
"
distilbert-base-uncased
"
,
None
)
,
}
TF_FEATURE_EXTRACT_FINETUNED_MODELS
=
{
(
'
bert-base-cased
'
,
'
bert-base-cased
'
,
None
),
(
"
bert-base-cased
"
,
"
bert-base-cased
"
,
None
),
# ('xlnet-base-cased', 'xlnet-base-cased', None), # Disabled for now as it crash for TF2
(
'
distilbert-base-uncased
'
,
'
distilbert-base-uncased
'
,
None
)
(
"
distilbert-base-uncased
"
,
"
distilbert-base-uncased
"
,
None
)
,
}
TF_TEXT_CLASSIF_FINETUNED_MODELS
=
{
(
'
bert-base-uncased
'
,
'
https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-tf_model.h5
'
,
'
https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json
'
"
bert-base-uncased
"
,
"
https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-tf_model.h5
"
,
"
https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json
"
,
)
}
TEXT_CLASSIF_FINETUNED_MODELS
=
{
(
'
bert-base-uncased
'
,
'
https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin
'
,
'
https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json
'
"
bert-base-uncased
"
,
"
https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin
"
,
"
https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json
"
,
)
}
...
...
@@ -91,54 +91,54 @@ class MonoColumnInputTestCase(unittest.TestCase):
@
require_torch
def
test_ner
(
self
):
mandatory_keys
=
{
'
entity
'
,
'
word
'
,
'
score
'
}
valid_inputs
=
[
'
HuggingFace is solving NLP one commit at a time.
'
,
'
HuggingFace is based in New-York & Paris
'
]
mandatory_keys
=
{
"
entity
"
,
"
word
"
,
"
score
"
}
valid_inputs
=
[
"
HuggingFace is solving NLP one commit at a time.
"
,
"
HuggingFace is based in New-York & Paris
"
]
invalid_inputs
=
[
None
]
for
tokenizer
,
model
,
config
in
NER_FINETUNED_MODELS
:
nlp
=
pipeline
(
task
=
'
ner
'
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
nlp
=
pipeline
(
task
=
"
ner
"
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
self
.
_test_mono_column_pipeline
(
nlp
,
valid_inputs
,
invalid_inputs
,
mandatory_keys
)
@
require_tf
def
test_tf_ner
(
self
):
mandatory_keys
=
{
'
entity
'
,
'
word
'
,
'
score
'
}
valid_inputs
=
[
'
HuggingFace is solving NLP one commit at a time.
'
,
'
HuggingFace is based in New-York & Paris
'
]
mandatory_keys
=
{
"
entity
"
,
"
word
"
,
"
score
"
}
valid_inputs
=
[
"
HuggingFace is solving NLP one commit at a time.
"
,
"
HuggingFace is based in New-York & Paris
"
]
invalid_inputs
=
[
None
]
for
tokenizer
,
model
,
config
in
TF_NER_FINETUNED_MODELS
:
nlp
=
pipeline
(
task
=
'
ner
'
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
nlp
=
pipeline
(
task
=
"
ner
"
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
self
.
_test_mono_column_pipeline
(
nlp
,
valid_inputs
,
invalid_inputs
,
mandatory_keys
)
@
require_torch
def
test_sentiment_analysis
(
self
):
mandatory_keys
=
{
'
label
'
}
valid_inputs
=
[
'
HuggingFace is solving NLP one commit at a time.
'
,
'
HuggingFace is based in New-York & Paris
'
]
mandatory_keys
=
{
"
label
"
}
valid_inputs
=
[
"
HuggingFace is solving NLP one commit at a time.
"
,
"
HuggingFace is based in New-York & Paris
"
]
invalid_inputs
=
[
None
]
for
tokenizer
,
model
,
config
in
TEXT_CLASSIF_FINETUNED_MODELS
:
nlp
=
pipeline
(
task
=
'
sentiment-analysis
'
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
nlp
=
pipeline
(
task
=
"
sentiment-analysis
"
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
self
.
_test_mono_column_pipeline
(
nlp
,
valid_inputs
,
invalid_inputs
,
mandatory_keys
)
@
require_tf
def
test_tf_sentiment_analysis
(
self
):
mandatory_keys
=
{
'
label
'
}
valid_inputs
=
[
'
HuggingFace is solving NLP one commit at a time.
'
,
'
HuggingFace is based in New-York & Paris
'
]
mandatory_keys
=
{
"
label
"
}
valid_inputs
=
[
"
HuggingFace is solving NLP one commit at a time.
"
,
"
HuggingFace is based in New-York & Paris
"
]
invalid_inputs
=
[
None
]
for
tokenizer
,
model
,
config
in
TF_TEXT_CLASSIF_FINETUNED_MODELS
:
nlp
=
pipeline
(
task
=
'
sentiment-analysis
'
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
nlp
=
pipeline
(
task
=
"
sentiment-analysis
"
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
self
.
_test_mono_column_pipeline
(
nlp
,
valid_inputs
,
invalid_inputs
,
mandatory_keys
)
@
require_torch
def
test_features_extraction
(
self
):
valid_inputs
=
[
'
HuggingFace is solving NLP one commit at a time.
'
,
'
HuggingFace is based in New-York & Paris
'
]
valid_inputs
=
[
"
HuggingFace is solving NLP one commit at a time.
"
,
"
HuggingFace is based in New-York & Paris
"
]
invalid_inputs
=
[
None
]
for
tokenizer
,
model
,
config
in
FEATURE_EXTRACT_FINETUNED_MODELS
:
nlp
=
pipeline
(
task
=
'
sentiment-analysis
'
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
nlp
=
pipeline
(
task
=
"
sentiment-analysis
"
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
self
.
_test_mono_column_pipeline
(
nlp
,
valid_inputs
,
invalid_inputs
,
{})
@
require_tf
def
test_tf_features_extraction
(
self
):
valid_inputs
=
[
'
HuggingFace is solving NLP one commit at a time.
'
,
'
HuggingFace is based in New-York & Paris
'
]
valid_inputs
=
[
"
HuggingFace is solving NLP one commit at a time.
"
,
"
HuggingFace is based in New-York & Paris
"
]
invalid_inputs
=
[
None
]
for
tokenizer
,
model
,
config
in
TF_FEATURE_EXTRACT_FINETUNED_MODELS
:
nlp
=
pipeline
(
task
=
'
sentiment-analysis
'
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
nlp
=
pipeline
(
task
=
"
sentiment-analysis
"
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
self
.
_test_mono_column_pipeline
(
nlp
,
valid_inputs
,
invalid_inputs
,
{})
...
...
@@ -165,46 +165,46 @@ class MultiColumnInputTestCase(unittest.TestCase):
@
require_torch
def
test_question_answering
(
self
):
mandatory_output_keys
=
{
'
score
'
,
'
answer
'
,
'
start
'
,
'
end
'
}
mandatory_output_keys
=
{
"
score
"
,
"
answer
"
,
"
start
"
,
"
end
"
}
valid_samples
=
[
{
'
question
'
:
'
Where was HuggingFace founded ?
'
,
'
context
'
:
'
HuggingFace was founded in Paris.
'
},
{
"
question
"
:
"
Where was HuggingFace founded ?
"
,
"
context
"
:
"
HuggingFace was founded in Paris.
"
},
{
'
question
'
:
'
In what field is HuggingFace working ?
'
,
'
context
'
:
'
HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.
'
}
"
question
"
:
"
In what field is HuggingFace working ?
"
,
"
context
"
:
"
HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.
"
,
}
,
]
invalid_samples
=
[
{
'
question
'
:
''
,
'
context
'
:
'
This is a test to try empty question edge case
'
},
{
'
question
'
:
None
,
'
context
'
:
'
This is a test to try empty question edge case
'
},
{
'
question
'
:
'
What is does with empty context ?
'
,
'
context
'
:
''
},
{
'
question
'
:
'
What is does with empty context ?
'
,
'
context
'
:
None
},
{
"
question
"
:
""
,
"
context
"
:
"
This is a test to try empty question edge case
"
},
{
"
question
"
:
None
,
"
context
"
:
"
This is a test to try empty question edge case
"
},
{
"
question
"
:
"
What is does with empty context ?
"
,
"
context
"
:
""
},
{
"
question
"
:
"
What is does with empty context ?
"
,
"
context
"
:
None
},
]
for
tokenizer
,
model
,
config
in
QA_FINETUNED_MODELS
:
nlp
=
pipeline
(
task
=
'
question-answering
'
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
nlp
=
pipeline
(
task
=
"
question-answering
"
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
self
.
_test_multicolumn_pipeline
(
nlp
,
valid_samples
,
invalid_samples
,
mandatory_output_keys
)
@
require_tf
def
test_tf_question_answering
(
self
):
mandatory_output_keys
=
{
'
score
'
,
'
answer
'
,
'
start
'
,
'
end
'
}
mandatory_output_keys
=
{
"
score
"
,
"
answer
"
,
"
start
"
,
"
end
"
}
valid_samples
=
[
{
'
question
'
:
'
Where was HuggingFace founded ?
'
,
'
context
'
:
'
HuggingFace was founded in Paris.
'
},
{
"
question
"
:
"
Where was HuggingFace founded ?
"
,
"
context
"
:
"
HuggingFace was founded in Paris.
"
},
{
'
question
'
:
'
In what field is HuggingFace working ?
'
,
'
context
'
:
'
HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.
'
}
"
question
"
:
"
In what field is HuggingFace working ?
"
,
"
context
"
:
"
HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.
"
,
}
,
]
invalid_samples
=
[
{
'
question
'
:
''
,
'
context
'
:
'
This is a test to try empty question edge case
'
},
{
'
question
'
:
None
,
'
context
'
:
'
This is a test to try empty question edge case
'
},
{
'
question
'
:
'
What is does with empty context ?
'
,
'
context
'
:
''
},
{
'
question
'
:
'
What is does with empty context ?
'
,
'
context
'
:
None
},
{
"
question
"
:
""
,
"
context
"
:
"
This is a test to try empty question edge case
"
},
{
"
question
"
:
None
,
"
context
"
:
"
This is a test to try empty question edge case
"
},
{
"
question
"
:
"
What is does with empty context ?
"
,
"
context
"
:
""
},
{
"
question
"
:
"
What is does with empty context ?
"
,
"
context
"
:
None
},
]
for
tokenizer
,
model
,
config
in
TF_QA_FINETUNED_MODELS
:
nlp
=
pipeline
(
task
=
'
question-answering
'
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
nlp
=
pipeline
(
task
=
"
question-answering
"
,
model
=
model
,
config
=
config
,
tokenizer
=
tokenizer
)
self
.
_test_multicolumn_pipeline
(
nlp
,
valid_samples
,
invalid_samples
,
mandatory_output_keys
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
unittest
.
main
()
transformers/tests/tokenization_albert_test.py
View file @
54abc67a
...
...
@@ -17,12 +17,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import
os
import
unittest
from
transformers.tokenization_albert
import
(
AlbertTokenizer
,
SPIECE_UNDERLINE
)
from
transformers.tokenization_albert
import
AlbertTokenizer
from
.tokenization_tests_commons
import
CommonTestCases
SAMPLE_VOCAB
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
'fixtures/spiece.model'
)
SAMPLE_VOCAB
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/spiece.model"
)
class
AlbertTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
...
...
@@ -39,27 +40,30 @@ class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester):
return
AlbertTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
u
"this is a test"
output_text
=
u
"this is a test"
input_text
=
"this is a test"
output_text
=
"this is a test"
return
input_text
,
output_text
def
test_full_tokenizer
(
self
):
tokenizer
=
AlbertTokenizer
(
SAMPLE_VOCAB
,
keep_accents
=
True
)
tokens
=
tokenizer
.
tokenize
(
u
'
This is a test
'
)
self
.
assertListEqual
(
tokens
,
[
u
'
▁this
'
,
u
'
▁is
'
,
u
'
▁a
'
,
u
'
▁test
'
])
tokens
=
tokenizer
.
tokenize
(
"
This is a test
"
)
self
.
assertListEqual
(
tokens
,
[
"
▁this
"
,
"
▁is
"
,
"
▁a
"
,
"
▁test
"
])
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
48
,
25
,
21
,
1289
])
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
48
,
25
,
21
,
1289
])
tokens
=
tokenizer
.
tokenize
(
u
"I was born in 92000, and this is falsé."
)
self
.
assertListEqual
(
tokens
,
[
u
'▁i'
,
u
'▁was'
,
u
'▁born'
,
u
'▁in'
,
u
'▁9'
,
u
'2000'
,
u
','
,
u
'▁and'
,
u
'▁this'
,
u
'▁is'
,
u
'▁fal'
,
u
's'
,
u
'é'
,
u
'.'
])
tokens
=
tokenizer
.
tokenize
(
"I was born in 92000, and this is falsé."
)
self
.
assertListEqual
(
tokens
,
[
"▁i"
,
"▁was"
,
"▁born"
,
"▁in"
,
"▁9"
,
"2000"
,
","
,
"▁and"
,
"▁this"
,
"▁is"
,
"▁fal"
,
"s"
,
"é"
,
"."
]
)
ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
self
.
assertListEqual
(
ids
,
[
31
,
23
,
386
,
19
,
561
,
3050
,
15
,
17
,
48
,
25
,
8256
,
18
,
1
,
9
])
back_tokens
=
tokenizer
.
convert_ids_to_tokens
(
ids
)
self
.
assertListEqual
(
back_tokens
,
[
'▁i'
,
'▁was'
,
'▁born'
,
'▁in'
,
'▁9'
,
'2000'
,
','
,
'▁and'
,
'▁this'
,
'▁is'
,
'▁fal'
,
's'
,
'<unk>'
,
'.'
])
self
.
assertListEqual
(
back_tokens
,
[
"▁i"
,
"▁was"
,
"▁born"
,
"▁in"
,
"▁9"
,
"2000"
,
","
,
"▁and"
,
"▁this"
,
"▁is"
,
"▁fal"
,
"s"
,
"<unk>"
,
"."
],
)
def
test_sequence_builders
(
self
):
tokenizer
=
AlbertTokenizer
(
SAMPLE_VOCAB
)
...
...
@@ -71,8 +75,10 @@ class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester):
encoded_pair
=
tokenizer
.
build_inputs_with_special_tokens
(
text
,
text_2
)
assert
encoded_sentence
==
[
tokenizer
.
cls_token_id
]
+
text
+
[
tokenizer
.
sep_token_id
]
assert
encoded_pair
==
[
tokenizer
.
cls_token_id
]
+
text
+
[
tokenizer
.
sep_token_id
]
+
text_2
+
[
tokenizer
.
sep_token_id
]
assert
encoded_pair
==
[
tokenizer
.
cls_token_id
]
+
text
+
[
tokenizer
.
sep_token_id
]
+
text_2
+
[
tokenizer
.
sep_token_id
]
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
unittest
.
main
()
transformers/tests/tokenization_auto_test.py
View file @
54abc67a
...
...
@@ -12,18 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
unittest
import
shutil
import
logging
import
unittest
from
transformers
import
AutoTokenizer
,
BertTokenizer
,
AutoTokenizer
,
GPT2Tokenizer
from
transformers
import
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from
transformers
import
(
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
,
AutoTokenizer
,
BertTokenizer
,
GPT2Tokenizer
,
)
from
.utils
import
slow
,
SMALL_MODEL_IDENTIFIER
from
.utils
import
SMALL_MODEL_IDENTIFIER
,
slow
class
AutoTokenizerTest
(
unittest
.
TestCase
):
...
...
@@ -48,5 +50,6 @@ class AutoTokenizerTest(unittest.TestCase):
self
.
assertIsInstance
(
tokenizer
,
BertTokenizer
)
self
.
assertEqual
(
len
(
tokenizer
),
12
)
if
__name__
==
"__main__"
:
unittest
.
main
()
transformers/tests/tokenization_bert_japanese_test.py
View file @
54abc67a
...
...
@@ -15,16 +15,18 @@
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
os
import
unittest
from
io
import
open
from
transformers.tokenization_bert
import
WordpieceTokenizer
from
transformers.tokenization_bert_japanese
import
(
BertJapaneseTokenizer
,
MecabTokenizer
,
CharacterTokenizer
,
VOCAB_FILES_NAMES
)
from
transformers.tokenization_bert_japanese
import
(
VOCAB_FILES_NAMES
,
BertJapaneseTokenizer
,
CharacterTokenizer
,
MecabTokenizer
,
)
from
.tokenization_tests_commons
import
CommonTestCases
from
.utils
import
slow
,
custom_tokenizers
from
.utils
import
custom_tokenizers
,
slow
@
custom_tokenizers
...
...
@@ -35,9 +37,24 @@ class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester):
def
setUp
(
self
):
super
(
BertJapaneseTokenizationTest
,
self
).
setUp
()
vocab_tokens
=
[
u
"[UNK]"
,
u
"[CLS]"
,
u
"[SEP]"
,
u
"こんにちは"
,
u
"こん"
,
u
"にちは"
,
u
"ばんは"
,
u
"##こん"
,
u
"##にちは"
,
u
"##ばんは"
,
u
"世界"
,
u
"##世界"
,
u
"、"
,
u
"##、"
,
u
"。"
,
u
"##。"
]
vocab_tokens
=
[
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"こんにちは"
,
"こん"
,
"にちは"
,
"ばんは"
,
"##こん"
,
"##にちは"
,
"##ばんは"
,
"世界"
,
"##世界"
,
"、"
,
"##、"
,
"。"
,
"##。"
,
]
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
"vocab_file"
])
with
open
(
self
.
vocab_file
,
"w"
,
encoding
=
"utf-8"
)
as
vocab_writer
:
...
...
@@ -47,70 +64,63 @@ class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester):
return
BertJapaneseTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
u
"こんにちは、世界。
\n
こんばんは、世界。"
output_text
=
u
"こんにちは 、 世界 。 こんばんは 、 世界 。"
input_text
=
"こんにちは、世界。
\n
こんばんは、世界。"
output_text
=
"こんにちは 、 世界 。 こんばんは 、 世界 。"
return
input_text
,
output_text
def
test_full_tokenizer
(
self
):
tokenizer
=
self
.
tokenizer_class
(
self
.
vocab_file
)
tokens
=
tokenizer
.
tokenize
(
u
"こんにちは、世界。
\n
こんばんは、世界。"
)
self
.
assertListEqual
(
tokens
,
[
u
"こんにちは"
,
u
"、"
,
u
"世界"
,
u
"。"
,
u
"こん"
,
u
"##ばんは"
,
u
"、"
,
u
"世界"
,
"。"
])
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
3
,
12
,
10
,
14
,
4
,
9
,
12
,
10
,
14
])
tokens
=
tokenizer
.
tokenize
(
"こんにちは、世界。
\n
こんばんは、世界。"
)
self
.
assertListEqual
(
tokens
,
[
"こんにちは"
,
"、"
,
"世界"
,
"。"
,
"こん"
,
"##ばんは"
,
"、"
,
"世界"
,
"。"
])
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
3
,
12
,
10
,
14
,
4
,
9
,
12
,
10
,
14
])
def
test_mecab_tokenizer
(
self
):
tokenizer
=
MecabTokenizer
()
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"
\t
アップルストアでiPhone8 が
\n
発売された 。 "
),
[
u
"アップルストア"
,
u
"で"
,
u
"iPhone"
,
u
"8"
,
u
"が"
,
u
"発売"
,
u
"さ"
,
u
"れ"
,
u
"た"
,
u
"。"
]
)
tokenizer
.
tokenize
(
"
\t
アップルストアでiPhone8 が
\n
発売された 。 "
),
[
"アップルストア"
,
"で"
,
"iPhone"
,
"8"
,
"が"
,
"発売"
,
"さ"
,
"れ"
,
"た"
,
"。"
],
)
def
test_mecab_tokenizer_lower
(
self
):
tokenizer
=
MecabTokenizer
(
do_lower_case
=
True
)
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"
\t
アップルストアでiPhone8 が
\n
発売された 。 "
),
[
u
"アップルストア"
,
u
"で"
,
u
"iphone"
,
u
"8"
,
u
"が"
,
u
"発売"
,
u
"さ"
,
u
"れ"
,
u
"た"
,
u
"。"
]
)
tokenizer
.
tokenize
(
"
\t
アップルストアでiPhone8 が
\n
発売された 。 "
),
[
"アップルストア"
,
"で"
,
"iphone"
,
"8"
,
"が"
,
"発売"
,
"さ"
,
"れ"
,
"た"
,
"。"
],
)
def
test_mecab_tokenizer_no_normalize
(
self
):
tokenizer
=
MecabTokenizer
(
normalize_text
=
False
)
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"
\t
アップルストアでiPhone8 が
\n
発売された 。 "
),
[
u
"アップルストア"
,
u
"で"
,
u
"iPhone"
,
u
"8"
,
u
"が"
,
u
"発売"
,
u
"さ"
,
u
"れ"
,
u
"た"
,
u
" "
,
u
"。"
]
)
tokenizer
.
tokenize
(
"
\t
アップルストアでiPhone8 が
\n
発売された 。 "
),
[
"アップルストア"
,
"で"
,
"iPhone"
,
"8"
,
"が"
,
"発売"
,
"さ"
,
"れ"
,
"た"
,
" "
,
"。"
],
)
def
test_wordpiece_tokenizer
(
self
):
vocab_tokens
=
[
u
"[UNK]"
,
u
"[CLS]"
,
u
"[SEP]"
,
u
"こんにちは"
,
u
"こん"
,
u
"にちは"
u
"ばんは"
,
u
"##こん"
,
u
"##にちは"
,
u
"##ばんは"
]
vocab_tokens
=
[
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"こんにちは"
,
"こん"
,
"にちは"
"ばんは"
,
"##こん"
,
"##にちは"
,
"##ばんは"
]
vocab
=
{}
for
(
i
,
token
)
in
enumerate
(
vocab_tokens
):
vocab
[
token
]
=
i
tokenizer
=
WordpieceTokenizer
(
vocab
=
vocab
,
unk_token
=
u
"[UNK]"
)
tokenizer
=
WordpieceTokenizer
(
vocab
=
vocab
,
unk_token
=
"[UNK]"
)
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
""
),
[])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
""
),
[])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"こんにちは"
),
[
u
"こんにちは"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
"こんにちは"
),
[
"こんにちは"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"こんばんは"
),
[
u
"こん"
,
u
"##ばんは"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
"こんばんは"
),
[
"こん"
,
"##ばんは"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"こんばんは こんばんにちは こんにちは"
),
[
u
"こん"
,
u
"##ばんは"
,
u
"[UNK]"
,
u
"こんにちは"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
"こんばんは こんばんにちは こんにちは"
),
[
"こん"
,
"##ばんは"
,
"[UNK]"
,
"こんにちは"
])
@
slow
def
test_sequence_builders
(
self
):
tokenizer
=
self
.
tokenizer_class
.
from_pretrained
(
"bert-base-japanese"
)
text
=
tokenizer
.
encode
(
u
"ありがとう。"
,
add_special_tokens
=
False
)
text_2
=
tokenizer
.
encode
(
u
"どういたしまして。"
,
add_special_tokens
=
False
)
text
=
tokenizer
.
encode
(
"ありがとう。"
,
add_special_tokens
=
False
)
text_2
=
tokenizer
.
encode
(
"どういたしまして。"
,
add_special_tokens
=
False
)
encoded_sentence
=
tokenizer
.
build_inputs_with_special_tokens
(
text
)
encoded_pair
=
tokenizer
.
build_inputs_with_special_tokens
(
text
,
text_2
)
...
...
@@ -127,58 +137,51 @@ class BertJapaneseCharacterTokenizationTest(CommonTestCases.CommonTokenizerTeste
def
setUp
(
self
):
super
(
BertJapaneseCharacterTokenizationTest
,
self
).
setUp
()
vocab_tokens
=
[
u
"[UNK]"
,
u
"[CLS]"
,
u
"[SEP]"
,
u
"こ"
,
u
"ん"
,
u
"に"
,
u
"ち"
,
u
"は"
,
u
"ば"
,
u
"世"
,
u
"界"
,
u
"、"
,
u
"。"
]
vocab_tokens
=
[
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"こ"
,
"ん"
,
"に"
,
"ち"
,
"は"
,
"ば"
,
"世"
,
"界"
,
"、"
,
"。"
]
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
"vocab_file"
])
with
open
(
self
.
vocab_file
,
"w"
,
encoding
=
"utf-8"
)
as
vocab_writer
:
vocab_writer
.
write
(
""
.
join
([
x
+
"
\n
"
for
x
in
vocab_tokens
]))
def
get_tokenizer
(
self
,
**
kwargs
):
return
BertJapaneseTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
subword_tokenizer_type
=
"character"
,
**
kwargs
)
return
BertJapaneseTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
subword_tokenizer_type
=
"character"
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
u
"こんにちは、世界。
\n
こんばんは、世界。"
output_text
=
u
"こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。"
input_text
=
"こんにちは、世界。
\n
こんばんは、世界。"
output_text
=
"こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。"
return
input_text
,
output_text
def
test_full_tokenizer
(
self
):
tokenizer
=
self
.
tokenizer_class
(
self
.
vocab_file
,
subword_tokenizer_type
=
"character"
)
tokenizer
=
self
.
tokenizer_class
(
self
.
vocab_file
,
subword_tokenizer_type
=
"character"
)
tokens
=
tokenizer
.
tokenize
(
u
"こんにちは、世界。
\n
こんばんは、世界。"
)
self
.
assertListEqual
(
tokens
,
[
u
"こ"
,
u
"ん"
,
u
"に"
,
u
"ち"
,
u
"は"
,
u
"、"
,
u
"世"
,
u
"界"
,
u
"。"
,
u
"こ"
,
u
"ん"
,
u
"ば"
,
u
"ん"
,
u
"は"
,
u
"、"
,
u
"世"
,
u
"界"
,
u
"。"
]
)
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
3
,
4
,
5
,
6
,
7
,
11
,
9
,
10
,
12
,
3
,
4
,
8
,
4
,
7
,
11
,
9
,
10
,
12
]
)
tokens
=
tokenizer
.
tokenize
(
"こんにちは、世界。
\n
こんばんは、世界。"
)
self
.
assertListEqual
(
tokens
,
[
"こ"
,
"ん"
,
"に"
,
"ち"
,
"は"
,
"、"
,
"世"
,
"界"
,
"。"
,
"こ"
,
"ん"
,
"ば"
,
"ん"
,
"は"
,
"、"
,
"世"
,
"界"
,
"。"
]
)
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
3
,
4
,
5
,
6
,
7
,
11
,
9
,
10
,
12
,
3
,
4
,
8
,
4
,
7
,
11
,
9
,
10
,
12
]
)
def
test_character_tokenizer
(
self
):
vocab_tokens
=
[
u
"[UNK]"
,
u
"[CLS]"
,
u
"[SEP]"
,
u
"こ"
,
u
"ん"
,
u
"に"
,
u
"ち"
,
u
"は"
,
u
"ば"
,
u
"世"
,
u
"界"
u
"、"
,
u
"。"
]
vocab_tokens
=
[
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"こ"
,
"ん"
,
"に"
,
"ち"
,
"は"
,
"ば"
,
"世"
,
"界"
"、"
,
"。"
]
vocab
=
{}
for
(
i
,
token
)
in
enumerate
(
vocab_tokens
):
vocab
[
token
]
=
i
tokenizer
=
CharacterTokenizer
(
vocab
=
vocab
,
unk_token
=
u
"[UNK]"
)
tokenizer
=
CharacterTokenizer
(
vocab
=
vocab
,
unk_token
=
"[UNK]"
)
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
""
),
[])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
""
),
[])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"こんにちは"
),
[
u
"こ"
,
u
"ん"
,
u
"に"
,
u
"ち"
,
u
"は"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
"こんにちは"
),
[
"こ"
,
"ん"
,
"に"
,
"ち"
,
"は"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"こんにちほ"
),
[
u
"こ"
,
u
"ん"
,
u
"に"
,
u
"ち"
,
u
"[UNK]"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
"こんにちほ"
),
[
"こ"
,
"ん"
,
"に"
,
"ち"
,
"[UNK]"
])
@
slow
def
test_sequence_builders
(
self
):
tokenizer
=
self
.
tokenizer_class
.
from_pretrained
(
"bert-base-japanese-char"
)
text
=
tokenizer
.
encode
(
u
"ありがとう。"
,
add_special_tokens
=
False
)
text_2
=
tokenizer
.
encode
(
u
"どういたしまして。"
,
add_special_tokens
=
False
)
text
=
tokenizer
.
encode
(
"ありがとう。"
,
add_special_tokens
=
False
)
text_2
=
tokenizer
.
encode
(
"どういたしまして。"
,
add_special_tokens
=
False
)
encoded_sentence
=
tokenizer
.
build_inputs_with_special_tokens
(
text
)
encoded_pair
=
tokenizer
.
build_inputs_with_special_tokens
(
text
,
text_2
)
...
...
@@ -186,6 +189,3 @@ class BertJapaneseCharacterTokenizationTest(CommonTestCases.CommonTokenizerTeste
# 2 is for "[CLS]", 3 is for "[SEP]"
assert
encoded_sentence
==
[
2
]
+
text
+
[
3
]
assert
encoded_pair
==
[
2
]
+
text
+
[
3
]
+
text_2
+
[
3
]
transformers/tests/tokenization_bert_test.py
View file @
54abc67a
...
...
@@ -18,15 +18,20 @@ import os
import
unittest
from
io
import
open
from
transformers.tokenization_bert
import
(
BasicTokenizer
,
from
transformers.tokenization_bert
import
(
VOCAB_FILES_NAMES
,
BasicTokenizer
,
BertTokenizer
,
WordpieceTokenizer
,
_is_control
,
_is_punctuation
,
_is_whitespace
,
VOCAB_FILES_NAMES
)
_is_control
,
_is_punctuation
,
_is_whitespace
,
)
from
.tokenization_tests_commons
import
CommonTestCases
from
.utils
import
slow
class
BertTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
tokenizer_class
=
BertTokenizer
...
...
@@ -35,55 +40,61 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
super
(
BertTokenizationTest
,
self
).
setUp
()
vocab_tokens
=
[
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"want"
,
"##want"
,
"##ed"
,
"wa"
,
"un"
,
"runn"
,
"##ing"
,
","
,
"low"
,
"lowest"
,
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"want"
,
"##want"
,
"##ed"
,
"wa"
,
"un"
,
"runn"
,
"##ing"
,
","
,
"low"
,
"lowest"
,
]
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
'
vocab_file
'
])
with
open
(
self
.
vocab_file
,
"w"
,
encoding
=
'
utf-8
'
)
as
vocab_writer
:
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
"
vocab_file
"
])
with
open
(
self
.
vocab_file
,
"w"
,
encoding
=
"
utf-8
"
)
as
vocab_writer
:
vocab_writer
.
write
(
""
.
join
([
x
+
"
\n
"
for
x
in
vocab_tokens
]))
def
get_tokenizer
(
self
,
**
kwargs
):
return
BertTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
u
"UNwant
\u00E9
d,running"
output_text
=
u
"unwanted, running"
input_text
=
"UNwant
\u00E9
d,running"
output_text
=
"unwanted, running"
return
input_text
,
output_text
def
test_full_tokenizer
(
self
):
tokenizer
=
self
.
tokenizer_class
(
self
.
vocab_file
)
tokens
=
tokenizer
.
tokenize
(
u
"UNwant
\u00E9
d,running"
)
tokens
=
tokenizer
.
tokenize
(
"UNwant
\u00E9
d,running"
)
self
.
assertListEqual
(
tokens
,
[
"un"
,
"##want"
,
"##ed"
,
","
,
"runn"
,
"##ing"
])
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
7
,
4
,
5
,
10
,
8
,
9
])
def
test_chinese
(
self
):
tokenizer
=
BasicTokenizer
()
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"ah
\u535A\u63A8
zz"
),
[
u
"ah"
,
u
"
\u535A
"
,
u
"
\u63A8
"
,
u
"zz"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
"ah
\u535A\u63A8
zz"
),
[
"ah"
,
"
\u535A
"
,
"
\u63A8
"
,
"zz"
])
def
test_basic_tokenizer_lower
(
self
):
tokenizer
=
BasicTokenizer
(
do_lower_case
=
True
)
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"
\t
HeLLo!how
\n
Are yoU? "
),
[
"hello"
,
"!"
,
"how"
,
"are"
,
"you"
,
"?"
]
)
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"H
\u00E9
llo"
),
[
"hello"
])
tokenizer
.
tokenize
(
"
\t
HeLLo!how
\n
Are yoU? "
),
[
"hello"
,
"!"
,
"how"
,
"are"
,
"you"
,
"?"
]
)
self
.
assertListEqual
(
tokenizer
.
tokenize
(
"H
\u00E9
llo"
),
[
"hello"
])
def
test_basic_tokenizer_no_lower
(
self
):
tokenizer
=
BasicTokenizer
(
do_lower_case
=
False
)
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"
\t
HeLLo!how
\n
Are yoU? "
),
[
"HeLLo"
,
"!"
,
"how"
,
"Are"
,
"yoU"
,
"?"
]
)
tokenizer
.
tokenize
(
"
\t
HeLLo!how
\n
Are yoU? "
),
[
"HeLLo"
,
"!"
,
"how"
,
"Are"
,
"yoU"
,
"?"
]
)
def
test_wordpiece_tokenizer
(
self
):
vocab_tokens
=
[
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"want"
,
"##want"
,
"##ed"
,
"wa"
,
"un"
,
"runn"
,
"##ing"
]
vocab_tokens
=
[
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"want"
,
"##want"
,
"##ed"
,
"wa"
,
"un"
,
"runn"
,
"##ing"
]
vocab
=
{}
for
(
i
,
token
)
in
enumerate
(
vocab_tokens
):
...
...
@@ -92,39 +103,36 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
self
.
assertListEqual
(
tokenizer
.
tokenize
(
""
),
[])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
"unwanted running"
),
[
"un"
,
"##want"
,
"##ed"
,
"runn"
,
"##ing"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
"unwanted running"
),
[
"un"
,
"##want"
,
"##ed"
,
"runn"
,
"##ing"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
"unwantedX running"
),
[
"[UNK]"
,
"runn"
,
"##ing"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
"unwantedX running"
),
[
"[UNK]"
,
"runn"
,
"##ing"
])
def
test_is_whitespace
(
self
):
self
.
assertTrue
(
_is_whitespace
(
u
" "
))
self
.
assertTrue
(
_is_whitespace
(
u
"
\t
"
))
self
.
assertTrue
(
_is_whitespace
(
u
"
\r
"
))
self
.
assertTrue
(
_is_whitespace
(
u
"
\n
"
))
self
.
assertTrue
(
_is_whitespace
(
u
"
\u00A0
"
))
self
.
assertTrue
(
_is_whitespace
(
" "
))
self
.
assertTrue
(
_is_whitespace
(
"
\t
"
))
self
.
assertTrue
(
_is_whitespace
(
"
\r
"
))
self
.
assertTrue
(
_is_whitespace
(
"
\n
"
))
self
.
assertTrue
(
_is_whitespace
(
"
\u00A0
"
))
self
.
assertFalse
(
_is_whitespace
(
u
"A"
))
self
.
assertFalse
(
_is_whitespace
(
u
"-"
))
self
.
assertFalse
(
_is_whitespace
(
"A"
))
self
.
assertFalse
(
_is_whitespace
(
"-"
))
def
test_is_control
(
self
):
self
.
assertTrue
(
_is_control
(
u
"
\u0005
"
))
self
.
assertTrue
(
_is_control
(
"
\u0005
"
))
self
.
assertFalse
(
_is_control
(
u
"A"
))
self
.
assertFalse
(
_is_control
(
u
" "
))
self
.
assertFalse
(
_is_control
(
u
"
\t
"
))
self
.
assertFalse
(
_is_control
(
u
"
\r
"
))
self
.
assertFalse
(
_is_control
(
"A"
))
self
.
assertFalse
(
_is_control
(
" "
))
self
.
assertFalse
(
_is_control
(
"
\t
"
))
self
.
assertFalse
(
_is_control
(
"
\r
"
))
def
test_is_punctuation
(
self
):
self
.
assertTrue
(
_is_punctuation
(
u
"-"
))
self
.
assertTrue
(
_is_punctuation
(
u
"$"
))
self
.
assertTrue
(
_is_punctuation
(
u
"`"
))
self
.
assertTrue
(
_is_punctuation
(
u
"."
))
self
.
assertTrue
(
_is_punctuation
(
"-"
))
self
.
assertTrue
(
_is_punctuation
(
"$"
))
self
.
assertTrue
(
_is_punctuation
(
"`"
))
self
.
assertTrue
(
_is_punctuation
(
"."
))
self
.
assertFalse
(
_is_punctuation
(
u
"A"
))
self
.
assertFalse
(
_is_punctuation
(
u
" "
))
self
.
assertFalse
(
_is_punctuation
(
"A"
))
self
.
assertFalse
(
_is_punctuation
(
" "
))
@
slow
def
test_sequence_builders
(
self
):
...
...
@@ -140,5 +148,5 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
assert
encoded_pair
==
[
101
]
+
text
+
[
102
]
+
text_2
+
[
102
]
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
unittest
.
main
()
transformers/tests/tokenization_ctrl_test.py
View file @
54abc67a
...
...
@@ -13,15 +13,16 @@
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
json
import
os
import
unittest
import
json
from
io
import
open
from
transformers.tokenization_ctrl
import
CTRLTokenizer
,
VOCAB_FILES_NAMES
from
transformers.tokenization_ctrl
import
VOCAB_FILES_NAMES
,
CTRLTokenizer
from
.tokenization_tests_commons
import
CommonTestCases
class
CTRLTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
tokenizer_class
=
CTRLTokenizer
...
...
@@ -30,13 +31,13 @@ class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
super
(
CTRLTokenizationTest
,
self
).
setUp
()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab
=
[
'
adapt
'
,
'
re@@
'
,
'
a@@
'
,
'
apt
'
,
'
c@@
'
,
't'
,
'
<unk>
'
]
vocab
=
[
"
adapt
"
,
"
re@@
"
,
"
a@@
"
,
"
apt
"
,
"
c@@
"
,
"t"
,
"
<unk>
"
]
vocab_tokens
=
dict
(
zip
(
vocab
,
range
(
len
(
vocab
))))
merges
=
[
"#version: 0.2"
,
'
a p
'
,
'
ap t</w>
'
,
'
r e
'
,
'
a d
'
,
'
ad apt</w>
'
,
''
]
merges
=
[
"#version: 0.2"
,
"
a p
"
,
"
ap t</w>
"
,
"
r e
"
,
"
a d
"
,
"
ad apt</w>
"
,
""
]
self
.
special_tokens_map
=
{
"unk_token"
:
"<unk>"
}
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
'
vocab_file
'
])
self
.
merges_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
'
merges_file
'
])
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
"
vocab_file
"
])
self
.
merges_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
"
merges_file
"
])
with
open
(
self
.
vocab_file
,
"w"
,
encoding
=
"utf-8"
)
as
fp
:
fp
.
write
(
json
.
dumps
(
vocab_tokens
)
+
"
\n
"
)
with
open
(
self
.
merges_file
,
"w"
,
encoding
=
"utf-8"
)
as
fp
:
...
...
@@ -47,23 +48,22 @@ class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
return
CTRLTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
u
"adapt react readapt apt"
output_text
=
u
"adapt react readapt apt"
input_text
=
"adapt react readapt apt"
output_text
=
"adapt react readapt apt"
return
input_text
,
output_text
def
test_full_tokenizer
(
self
):
tokenizer
=
CTRLTokenizer
(
self
.
vocab_file
,
self
.
merges_file
,
**
self
.
special_tokens_map
)
text
=
"adapt react readapt apt"
bpe_tokens
=
'
adapt re@@ a@@ c@@ t re@@ adapt apt
'
.
split
()
bpe_tokens
=
"
adapt re@@ a@@ c@@ t re@@ adapt apt
"
.
split
()
tokens
=
tokenizer
.
tokenize
(
text
)
self
.
assertListEqual
(
tokens
,
bpe_tokens
)
input_tokens
=
tokens
+
[
tokenizer
.
unk_token
]
input_bpe_tokens
=
[
0
,
1
,
2
,
4
,
5
,
1
,
0
,
3
,
6
]
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
unittest
.
main
()
transformers/tests/tokenization_distilbert_test.py
View file @
54abc67a
...
...
@@ -14,16 +14,14 @@
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
os
import
unittest
from
io
import
open
from
transformers.tokenization_distilbert
import
(
DistilBertTokenizer
)
from
transformers.tokenization_distilbert
import
DistilBertTokenizer
from
.tokenization_tests_commons
import
CommonTestCases
from
.tokenization_bert_test
import
BertTokenizationTest
from
.utils
import
slow
class
DistilBertTokenizationTest
(
BertTokenizationTest
):
tokenizer_class
=
DistilBertTokenizer
...
...
@@ -42,9 +40,10 @@ class DistilBertTokenizationTest(BertTokenizationTest):
encoded_pair
=
tokenizer
.
build_inputs_with_special_tokens
(
text
,
text_2
)
assert
encoded_sentence
==
[
tokenizer
.
cls_token_id
]
+
text
+
[
tokenizer
.
sep_token_id
]
assert
encoded_pair
==
[
tokenizer
.
cls_token_id
]
+
text
+
[
tokenizer
.
sep_token_id
]
+
\
text_2
+
[
tokenizer
.
sep_token_id
]
assert
encoded_pair
==
[
tokenizer
.
cls_token_id
]
+
text
+
[
tokenizer
.
sep_token_id
]
+
text_2
+
[
tokenizer
.
sep_token_id
]
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
unittest
.
main
()
transformers/tests/tokenization_gpt2_test.py
View file @
54abc67a
...
...
@@ -14,15 +14,16 @@
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
json
import
os
import
unittest
import
json
from
io
import
open
from
transformers.tokenization_gpt2
import
GPT2Tokenizer
,
VOCAB_FILES_NAMES
from
transformers.tokenization_gpt2
import
VOCAB_FILES_NAMES
,
GPT2Tokenizer
from
.tokenization_tests_commons
import
CommonTestCases
class
GPT2TokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
tokenizer_class
=
GPT2Tokenizer
...
...
@@ -31,16 +32,34 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
super
(
GPT2TokenizationTest
,
self
).
setUp
()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab
=
[
"l"
,
"o"
,
"w"
,
"e"
,
"r"
,
"s"
,
"t"
,
"i"
,
"d"
,
"n"
,
"
\u0120
"
,
"
\u0120
l"
,
"
\u0120
n"
,
"
\u0120
lo"
,
"
\u0120
low"
,
"er"
,
"
\u0120
lowest"
,
"
\u0120
newer"
,
"
\u0120
wider"
,
"<unk>"
]
vocab
=
[
"l"
,
"o"
,
"w"
,
"e"
,
"r"
,
"s"
,
"t"
,
"i"
,
"d"
,
"n"
,
"
\u0120
"
,
"
\u0120
l"
,
"
\u0120
n"
,
"
\u0120
lo"
,
"
\u0120
low"
,
"er"
,
"
\u0120
lowest"
,
"
\u0120
newer"
,
"
\u0120
wider"
,
"<unk>"
,
]
vocab_tokens
=
dict
(
zip
(
vocab
,
range
(
len
(
vocab
))))
merges
=
[
"#version: 0.2"
,
"
\u0120
l"
,
"
\u0120
l o"
,
"
\u0120
lo w"
,
"e r"
,
""
]
self
.
special_tokens_map
=
{
"unk_token"
:
"<unk>"
}
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
'
vocab_file
'
])
self
.
merges_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
'
merges_file
'
])
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
"
vocab_file
"
])
self
.
merges_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
"
merges_file
"
])
with
open
(
self
.
vocab_file
,
"w"
,
encoding
=
"utf-8"
)
as
fp
:
fp
.
write
(
json
.
dumps
(
vocab_tokens
)
+
"
\n
"
)
with
open
(
self
.
merges_file
,
"w"
,
encoding
=
"utf-8"
)
as
fp
:
...
...
@@ -51,8 +70,8 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
return
GPT2Tokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
u
"lower newer"
output_text
=
u
"lower newer"
input_text
=
"lower newer"
output_text
=
"lower newer"
return
input_text
,
output_text
def
test_full_tokenizer
(
self
):
...
...
@@ -64,8 +83,8 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
input_tokens
=
tokens
+
[
tokenizer
.
unk_token
]
input_bpe_tokens
=
[
14
,
15
,
10
,
9
,
3
,
2
,
15
,
19
]
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
unittest
.
main
()
transformers/tests/tokenization_openai_test.py
View file @
54abc67a
...
...
@@ -14,11 +14,11 @@
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
json
import
os
import
unittest
import
json
from
transformers.tokenization_openai
import
OpenAIGPTTokenizer
,
VOCAB_FILES_NAMES
from
transformers.tokenization_openai
import
VOCAB_FILES_NAMES
,
OpenAIGPTTokenizer
from
.tokenization_tests_commons
import
CommonTestCases
...
...
@@ -31,15 +31,34 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
super
(
OpenAIGPTTokenizationTest
,
self
).
setUp
()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab
=
[
"l"
,
"o"
,
"w"
,
"e"
,
"r"
,
"s"
,
"t"
,
"i"
,
"d"
,
"n"
,
"w</w>"
,
"r</w>"
,
"t</w>"
,
"lo"
,
"low"
,
"er</w>"
,
"low</w>"
,
"lowest</w>"
,
"newer</w>"
,
"wider</w>"
,
"<unk>"
]
vocab
=
[
"l"
,
"o"
,
"w"
,
"e"
,
"r"
,
"s"
,
"t"
,
"i"
,
"d"
,
"n"
,
"w</w>"
,
"r</w>"
,
"t</w>"
,
"lo"
,
"low"
,
"er</w>"
,
"low</w>"
,
"lowest</w>"
,
"newer</w>"
,
"wider</w>"
,
"<unk>"
,
]
vocab_tokens
=
dict
(
zip
(
vocab
,
range
(
len
(
vocab
))))
merges
=
[
"#version: 0.2"
,
"l o"
,
"lo w"
,
"e r</w>"
,
""
]
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
'
vocab_file
'
])
self
.
merges_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
'
merges_file
'
])
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
"
vocab_file
"
])
self
.
merges_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
"
merges_file
"
])
with
open
(
self
.
vocab_file
,
"w"
)
as
fp
:
fp
.
write
(
json
.
dumps
(
vocab_tokens
))
with
open
(
self
.
merges_file
,
"w"
)
as
fp
:
...
...
@@ -49,11 +68,10 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
return
OpenAIGPTTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
u
"lower newer"
output_text
=
u
"lower newer"
input_text
=
"lower newer"
output_text
=
"lower newer"
return
input_text
,
output_text
def
test_full_tokenizer
(
self
):
tokenizer
=
OpenAIGPTTokenizer
(
self
.
vocab_file
,
self
.
merges_file
)
...
...
@@ -64,9 +82,8 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
input_tokens
=
tokens
+
[
"<unk>"
]
input_bpe_tokens
=
[
14
,
15
,
20
]
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
unittest
.
main
()
transformers/tests/tokenization_roberta_test.py
View file @
54abc67a
...
...
@@ -14,12 +14,13 @@
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
os
import
json
import
os
import
unittest
from
io
import
open
from
transformers.tokenization_roberta
import
RobertaTokenizer
,
VOCAB_FILES_NAMES
from
transformers.tokenization_roberta
import
VOCAB_FILES_NAMES
,
RobertaTokenizer
from
.tokenization_tests_commons
import
CommonTestCases
from
.utils
import
slow
...
...
@@ -31,16 +32,34 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
super
(
RobertaTokenizationTest
,
self
).
setUp
()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab
=
[
"l"
,
"o"
,
"w"
,
"e"
,
"r"
,
"s"
,
"t"
,
"i"
,
"d"
,
"n"
,
"
\u0120
"
,
"
\u0120
l"
,
"
\u0120
n"
,
"
\u0120
lo"
,
"
\u0120
low"
,
"er"
,
"
\u0120
lowest"
,
"
\u0120
newer"
,
"
\u0120
wider"
,
"<unk>"
]
vocab
=
[
"l"
,
"o"
,
"w"
,
"e"
,
"r"
,
"s"
,
"t"
,
"i"
,
"d"
,
"n"
,
"
\u0120
"
,
"
\u0120
l"
,
"
\u0120
n"
,
"
\u0120
lo"
,
"
\u0120
low"
,
"er"
,
"
\u0120
lowest"
,
"
\u0120
newer"
,
"
\u0120
wider"
,
"<unk>"
,
]
vocab_tokens
=
dict
(
zip
(
vocab
,
range
(
len
(
vocab
))))
merges
=
[
"#version: 0.2"
,
"
\u0120
l"
,
"
\u0120
l o"
,
"
\u0120
lo w"
,
"e r"
,
""
]
self
.
special_tokens_map
=
{
"unk_token"
:
"<unk>"
}
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
'
vocab_file
'
])
self
.
merges_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
'
merges_file
'
])
self
.
vocab_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
"
vocab_file
"
])
self
.
merges_file
=
os
.
path
.
join
(
self
.
tmpdirname
,
VOCAB_FILES_NAMES
[
"
merges_file
"
])
with
open
(
self
.
vocab_file
,
"w"
,
encoding
=
"utf-8"
)
as
fp
:
fp
.
write
(
json
.
dumps
(
vocab_tokens
)
+
"
\n
"
)
with
open
(
self
.
merges_file
,
"w"
,
encoding
=
"utf-8"
)
as
fp
:
...
...
@@ -51,8 +70,8 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
return
RobertaTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
u
"lower newer"
output_text
=
u
"lower newer"
input_text
=
"lower newer"
output_text
=
"lower newer"
return
input_text
,
output_text
def
test_full_tokenizer
(
self
):
...
...
@@ -64,19 +83,15 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
input_tokens
=
tokens
+
[
tokenizer
.
unk_token
]
input_bpe_tokens
=
[
14
,
15
,
10
,
9
,
3
,
2
,
15
,
19
]
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
def
roberta_dict_integration_testing
(
self
):
tokenizer
=
self
.
get_tokenizer
()
self
.
assertListEqual
(
tokenizer
.
encode
(
"Hello world!"
,
add_special_tokens
=
False
),
[
0
,
31414
,
232
,
328
,
2
])
self
.
assertListEqual
(
tokenizer
.
encode
(
'Hello world!'
,
add_special_tokens
=
False
),
[
0
,
31414
,
232
,
328
,
2
]
)
self
.
assertListEqual
(
tokenizer
.
encode
(
'Hello world! cécé herlolip 418'
,
add_special_tokens
=
False
),
[
0
,
31414
,
232
,
328
,
740
,
1140
,
12695
,
69
,
46078
,
1588
,
2
]
tokenizer
.
encode
(
"Hello world! cécé herlolip 418"
,
add_special_tokens
=
False
),
[
0
,
31414
,
232
,
328
,
740
,
1140
,
12695
,
69
,
46078
,
1588
,
2
],
)
@
slow
...
...
@@ -87,7 +102,9 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
text_2
=
tokenizer
.
encode
(
"multi-sequence build"
,
add_special_tokens
=
False
)
encoded_text_from_decode
=
tokenizer
.
encode
(
"sequence builders"
,
add_special_tokens
=
True
)
encoded_pair_from_decode
=
tokenizer
.
encode
(
"sequence builders"
,
"multi-sequence build"
,
add_special_tokens
=
True
)
encoded_pair_from_decode
=
tokenizer
.
encode
(
"sequence builders"
,
"multi-sequence build"
,
add_special_tokens
=
True
)
encoded_sentence
=
tokenizer
.
build_inputs_with_special_tokens
(
text
)
encoded_pair
=
tokenizer
.
build_inputs_with_special_tokens
(
text
,
text_2
)
...
...
@@ -96,5 +113,5 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
assert
encoded_pair
==
encoded_pair_from_decode
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
unittest
.
main
()
transformers/tests/tokenization_t5_test.py
View file @
54abc67a
...
...
@@ -17,13 +17,14 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import
os
import
unittest
from
transformers.tokenization_t5
import
(
T5Tokenizer
)
from
transformers.tokenization_t5
import
T5Tokenizer
from
transformers.tokenization_xlnet
import
SPIECE_UNDERLINE
from
.tokenization_tests_commons
import
CommonTestCases
SAMPLE_VOCAB
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
'fixtures/test_sentencepiece.model'
)
SAMPLE_VOCAB
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/test_sentencepiece.model"
)
class
T5TokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
...
...
@@ -40,38 +41,76 @@ class T5TokenizationTest(CommonTestCases.CommonTokenizerTester):
return
T5Tokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
u
"This is a test"
output_text
=
u
"This is a test"
input_text
=
"This is a test"
output_text
=
"This is a test"
return
input_text
,
output_text
def
test_full_tokenizer
(
self
):
tokenizer
=
T5Tokenizer
(
SAMPLE_VOCAB
)
tokens
=
tokenizer
.
tokenize
(
u
'
This is a test
'
)
self
.
assertListEqual
(
tokens
,
[
u
'
▁This
'
,
u
'
▁is
'
,
u
'
▁a
'
,
u
'
▁t
'
,
u
'
est
'
])
tokens
=
tokenizer
.
tokenize
(
"
This is a test
"
)
self
.
assertListEqual
(
tokens
,
[
"
▁This
"
,
"
▁is
"
,
"
▁a
"
,
"
▁t
"
,
"
est
"
])
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
285
,
46
,
10
,
170
,
382
])
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
285
,
46
,
10
,
170
,
382
])
tokens
=
tokenizer
.
tokenize
(
u
"I was born in 92000, and this is falsé."
)
self
.
assertListEqual
(
tokens
,
[
SPIECE_UNDERLINE
+
u
'I'
,
SPIECE_UNDERLINE
+
u
'was'
,
SPIECE_UNDERLINE
+
u
'b'
,
u
'or'
,
u
'n'
,
SPIECE_UNDERLINE
+
u
'in'
,
SPIECE_UNDERLINE
+
u
''
,
u
'9'
,
u
'2'
,
u
'0'
,
u
'0'
,
u
'0'
,
u
','
,
SPIECE_UNDERLINE
+
u
'and'
,
SPIECE_UNDERLINE
+
u
'this'
,
SPIECE_UNDERLINE
+
u
'is'
,
SPIECE_UNDERLINE
+
u
'f'
,
u
'al'
,
u
's'
,
u
'é'
,
u
'.'
])
ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
tokens
=
tokenizer
.
tokenize
(
"I was born in 92000, and this is falsé."
)
self
.
assertListEqual
(
ids
,
[
8
,
21
,
84
,
55
,
24
,
19
,
7
,
0
,
602
,
347
,
347
,
347
,
3
,
12
,
66
,
46
,
72
,
80
,
6
,
0
,
4
])
tokens
,
[
SPIECE_UNDERLINE
+
"I"
,
SPIECE_UNDERLINE
+
"was"
,
SPIECE_UNDERLINE
+
"b"
,
"or"
,
"n"
,
SPIECE_UNDERLINE
+
"in"
,
SPIECE_UNDERLINE
+
""
,
"9"
,
"2"
,
"0"
,
"0"
,
"0"
,
","
,
SPIECE_UNDERLINE
+
"and"
,
SPIECE_UNDERLINE
+
"this"
,
SPIECE_UNDERLINE
+
"is"
,
SPIECE_UNDERLINE
+
"f"
,
"al"
,
"s"
,
"é"
,
"."
,
],
)
ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
self
.
assertListEqual
(
ids
,
[
8
,
21
,
84
,
55
,
24
,
19
,
7
,
0
,
602
,
347
,
347
,
347
,
3
,
12
,
66
,
46
,
72
,
80
,
6
,
0
,
4
])
back_tokens
=
tokenizer
.
convert_ids_to_tokens
(
ids
)
self
.
assertListEqual
(
back_tokens
,
[
SPIECE_UNDERLINE
+
u
'I'
,
SPIECE_UNDERLINE
+
u
'was'
,
SPIECE_UNDERLINE
+
u
'b'
,
u
'or'
,
u
'n'
,
SPIECE_UNDERLINE
+
u
'in'
,
SPIECE_UNDERLINE
+
u
''
,
u
'<unk>'
,
u
'2'
,
u
'0'
,
u
'0'
,
u
'0'
,
u
','
,
SPIECE_UNDERLINE
+
u
'and'
,
SPIECE_UNDERLINE
+
u
'this'
,
SPIECE_UNDERLINE
+
u
'is'
,
SPIECE_UNDERLINE
+
u
'f'
,
u
'al'
,
u
's'
,
u
'<unk>'
,
u
'.'
])
if
__name__
==
'__main__'
:
self
.
assertListEqual
(
back_tokens
,
[
SPIECE_UNDERLINE
+
"I"
,
SPIECE_UNDERLINE
+
"was"
,
SPIECE_UNDERLINE
+
"b"
,
"or"
,
"n"
,
SPIECE_UNDERLINE
+
"in"
,
SPIECE_UNDERLINE
+
""
,
"<unk>"
,
"2"
,
"0"
,
"0"
,
"0"
,
","
,
SPIECE_UNDERLINE
+
"and"
,
SPIECE_UNDERLINE
+
"this"
,
SPIECE_UNDERLINE
+
"is"
,
SPIECE_UNDERLINE
+
"f"
,
"al"
,
"s"
,
"<unk>"
,
"."
,
],
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Prev
1
…
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment