Commit fa84ae26 authored by Aymeric Augustin

Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length: 119 characters.

This is also Thomas' preference, because it allows for explicit variable
names, making the code easier to understand.
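For anyone reproducing or verifying this locally, black's --check mode reports
files that would be reformatted without modifying them (a sketch reusing the
same arguments as the command above; the file list may differ on other
branches):

    $ black --check --line-length 119 examples templates transformers utils hubconf.py setup.py

It exits with a non-zero status if any file still needs reformatting, which
makes it straightforward to wire into CI.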
parent 63e3827c
@@ -23,10 +23,10 @@ from transformers import is_torch_available
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel) from transformers import TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel
from transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP from transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import CommonTestCases, ids_tensor
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
@@ -40,27 +40,27 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
test_resize_embeddings = False test_resize_embeddings = False
class TransfoXLModelTester(object): class TransfoXLModelTester(object):
def __init__(
def __init__(self, self,
parent, parent,
batch_size=13, batch_size=13,
seq_length=7, seq_length=7,
mem_len=30, mem_len=30,
clamp_len=15, clamp_len=15,
is_training=True, is_training=True,
use_labels=True, use_labels=True,
vocab_size=99, vocab_size=99,
cutoffs=[10, 50, 80], cutoffs=[10, 50, 80],
hidden_size=32, hidden_size=32,
d_embed=32, d_embed=32,
num_attention_heads=4, num_attention_heads=4,
d_head=8, d_head=8,
d_inner=128, d_inner=128,
div_val=2, div_val=2,
num_hidden_layers=5, num_hidden_layers=5,
scope=None, scope=None,
seed=1, seed=1,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.seq_length = seq_length self.seq_length = seq_length
@@ -100,7 +100,8 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
d_head=self.d_head, d_head=self.d_head,
d_inner=self.d_inner, d_inner=self.d_inner,
div_val=self.div_val, div_val=self.div_val,
n_layer=self.num_hidden_layers) n_layer=self.num_hidden_layers,
)
return (config, input_ids_1, input_ids_2, lm_labels) return (config, input_ids_1, input_ids_2, lm_labels)
@@ -125,18 +126,19 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
def check_transfo_xl_model_output(self, result): def check_transfo_xl_model_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["hidden_states_1"].size()), list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size]
[self.batch_size, self.seq_length, self.hidden_size]) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["hidden_states_2"].size()), list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size]
[self.batch_size, self.seq_length, self.hidden_size]) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]), list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels): def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
model = TransfoXLLMHeadModel(config) model = TransfoXLLMHeadModel(config)
@@ -159,33 +161,30 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
return outputs return outputs
def check_transfo_xl_lm_head_output(self, result): def check_transfo_xl_lm_head_output(self, result):
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss_1"].size()), list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
[self.batch_size, self.seq_length]) )
self.parent.assertListEqual(
list(result["lm_logits_1"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss_2"].size()), list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
[self.batch_size, self.seq_length]) )
self.parent.assertListEqual(
list(result["lm_logits_2"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]), list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids_1} inputs_dict = {"input_ids": input_ids_1}
return config, inputs_dict return config, inputs_dict
def setUp(self): def setUp(self):
self.model_tester = TransfoXLModelTest.TransfoXLModelTester(self) self.model_tester = TransfoXLModelTest.TransfoXLModelTester(self)
self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37) self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
...
@@ -21,11 +21,17 @@ import unittest
from transformers import is_torch_available from transformers import is_torch_available
if is_torch_available(): if is_torch_available():
from transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, from transformers import (
XLMForSequenceClassification, XLMForQuestionAnsweringSimple) XLMConfig,
XLMModel,
XLMWithLMHeadModel,
XLMForQuestionAnswering,
XLMForSequenceClassification,
XLMForQuestionAnsweringSimple,
)
from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import CommonTestCases, ids_tensor
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
@@ -33,42 +39,50 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch @require_torch
class XLMModelTest(CommonTestCases.CommonModelTester): class XLMModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, all_model_classes = (
XLMForSequenceClassification, XLMForQuestionAnsweringSimple) if is_torch_available() else () (
XLMModel,
XLMWithLMHeadModel,
XLMForQuestionAnswering,
XLMForSequenceClassification,
XLMForQuestionAnsweringSimple,
)
if is_torch_available()
else ()
)
class XLMModelTester(object): class XLMModelTester(object):
def __init__(
def __init__(self, self,
parent, parent,
batch_size=13, batch_size=13,
seq_length=7, seq_length=7,
is_training=True, is_training=True,
use_input_lengths=True, use_input_lengths=True,
use_token_type_ids=True, use_token_type_ids=True,
use_labels=True, use_labels=True,
gelu_activation=True, gelu_activation=True,
sinusoidal_embeddings=False, sinusoidal_embeddings=False,
causal=False, causal=False,
asm=False, asm=False,
n_langs=2, n_langs=2,
vocab_size=99, vocab_size=99,
n_special=0, n_special=0,
hidden_size=32, hidden_size=32,
num_hidden_layers=5, num_hidden_layers=5,
num_attention_heads=4, num_attention_heads=4,
hidden_dropout_prob=0.1, hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
max_position_embeddings=512, max_position_embeddings=512,
type_vocab_size=16, type_vocab_size=16,
type_sequence_label_size=2, type_sequence_label_size=2,
initializer_range=0.02, initializer_range=0.02,
num_labels=3, num_labels=3,
num_choices=4, num_choices=4,
summary_type="last", summary_type="last",
use_proj=True, use_proj=True,
scope=None, scope=None,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.seq_length = seq_length self.seq_length = seq_length
@@ -105,7 +119,9 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
input_lengths = None input_lengths = None
if self.use_input_lengths: if self.use_input_lengths:
input_lengths = ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2 # small variation of seq_length input_lengths = (
ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2
) # small variation of seq_length
token_type_ids = None token_type_ids = None
if self.use_token_type_ids: if self.use_token_type_ids:
@@ -120,31 +136,49 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
is_impossible_labels = ids_tensor([self.batch_size], 2).float() is_impossible_labels = ids_tensor([self.batch_size], 2).float()
config = XLMConfig( config = XLMConfig(
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
n_special=self.n_special, n_special=self.n_special,
emb_dim=self.hidden_size, emb_dim=self.hidden_size,
n_layers=self.num_hidden_layers, n_layers=self.num_hidden_layers,
n_heads=self.num_attention_heads, n_heads=self.num_attention_heads,
dropout=self.hidden_dropout_prob, dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob, attention_dropout=self.attention_probs_dropout_prob,
gelu_activation=self.gelu_activation, gelu_activation=self.gelu_activation,
sinusoidal_embeddings=self.sinusoidal_embeddings, sinusoidal_embeddings=self.sinusoidal_embeddings,
asm=self.asm, asm=self.asm,
causal=self.causal, causal=self.causal,
n_langs=self.n_langs, n_langs=self.n_langs,
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
summary_type=self.summary_type, summary_type=self.summary_type,
use_proj=self.use_proj) use_proj=self.use_proj,
)
return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask
return (
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
)
def check_loss_output(self, result): def check_loss_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(list(result["loss"].size()), [])
list(result["loss"].size()),
[]) def create_and_check_xlm_model(
self,
def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
):
model = XLMModel(config=config) model = XLMModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
@@ -156,11 +190,20 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
"sequence_output": sequence_output, "sequence_output": sequence_output,
} }
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
[self.batch_size, self.seq_length, self.hidden_size]) )
def create_and_check_xlm_lm_head(
def create_and_check_xlm_lm_head(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
):
model = XLMWithLMHeadModel(config) model = XLMWithLMHeadModel(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
@@ -172,23 +215,29 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
"logits": logits, "logits": logits,
} }
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss"].size()), list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
[]) )
self.parent.assertListEqual(
list(result["logits"].size()), def create_and_check_xlm_simple_qa(
[self.batch_size, self.seq_length, self.vocab_size]) self,
config,
input_ids,
def create_and_check_xlm_simple_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
):
model = XLMForQuestionAnsweringSimple(config) model = XLMForQuestionAnsweringSimple(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
outputs = model(input_ids) outputs = model(input_ids)
outputs = model(input_ids, start_positions=sequence_labels, outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
end_positions=sequence_labels)
loss, start_logits, end_logits = outputs loss, start_logits, end_logits = outputs
result = { result = {
@@ -196,16 +245,21 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
"start_logits": start_logits, "start_logits": start_logits,
"end_logits": end_logits, "end_logits": end_logits,
} }
self.parent.assertListEqual( self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
list(result["start_logits"].size()), self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_xlm_qa(
def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
):
model = XLMForQuestionAnswering(config) model = XLMForQuestionAnswering(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
@@ -213,21 +267,26 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
outputs = model(input_ids) outputs = model(input_ids)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs
outputs = model(input_ids, start_positions=sequence_labels, outputs = model(
end_positions=sequence_labels, input_ids,
cls_index=sequence_labels, start_positions=sequence_labels,
is_impossible=is_impossible_labels, end_positions=sequence_labels,
p_mask=input_mask) cls_index=sequence_labels,
is_impossible=is_impossible_labels,
outputs = model(input_ids, start_positions=sequence_labels, p_mask=input_mask,
end_positions=sequence_labels, )
cls_index=sequence_labels,
is_impossible=is_impossible_labels) outputs = model(
input_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels,
)
(total_loss,) = outputs (total_loss,) = outputs
outputs = model(input_ids, start_positions=sequence_labels, outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
end_positions=sequence_labels)
(total_loss,) = outputs (total_loss,) = outputs
@@ -240,27 +299,34 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
"cls_logits": cls_logits, "cls_logits": cls_logits,
} }
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss"].size()), list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
[]) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["start_top_log_probs"].size()), list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
[self.batch_size, model.config.start_n_top]) )
self.parent.assertListEqual(
list(result["start_top_index"].size()),
[self.batch_size, model.config.start_n_top])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["end_top_log_probs"].size()), list(result["end_top_log_probs"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top]) [self.batch_size, model.config.start_n_top * model.config.end_n_top],
)
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["end_top_index"].size()), list(result["end_top_index"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top]) [self.batch_size, model.config.start_n_top * model.config.end_n_top],
self.parent.assertListEqual( )
list(result["cls_logits"].size()), self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
[self.batch_size])
def create_and_check_xlm_sequence_classif(
self,
def create_and_check_xlm_sequence_classif(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
):
model = XLMForSequenceClassification(config) model = XLMForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
@@ -273,19 +339,24 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
"logits": logits, "logits": logits,
} }
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss"].size()), list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
[]) )
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.type_sequence_label_size])
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_lengths, (
sequence_labels, token_labels, is_impossible_labels, input_mask) = config_and_inputs config,
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths} input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "lengths": input_lengths}
return config, inputs_dict return config, inputs_dict
def setUp(self): def setUp(self):
...
@@ -26,11 +26,17 @@ from transformers import is_torch_available
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, from transformers import (
XLNetForTokenClassification, XLNetForQuestionAnswering) XLNetConfig,
XLNetModel,
XLNetLMHeadModel,
XLNetForSequenceClassification,
XLNetForTokenClassification,
XLNetForQuestionAnswering,
)
from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import CommonTestCases, ids_tensor
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
@@ -38,35 +44,44 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch @require_torch
class XLNetModelTest(CommonTestCases.CommonModelTester): class XLNetModelTest(CommonTestCases.CommonModelTester):
all_model_classes=(XLNetModel, XLNetLMHeadModel, XLNetForTokenClassification, all_model_classes = (
XLNetForSequenceClassification, XLNetForQuestionAnswering) if is_torch_available() else () (
XLNetModel,
XLNetLMHeadModel,
XLNetForTokenClassification,
XLNetForSequenceClassification,
XLNetForQuestionAnswering,
)
if is_torch_available()
else ()
)
test_pruning = False test_pruning = False
class XLNetModelTester(object): class XLNetModelTester(object):
def __init__(
def __init__(self, self,
parent, parent,
batch_size=13, batch_size=13,
seq_length=7, seq_length=7,
mem_len=10, mem_len=10,
clamp_len=-1, clamp_len=-1,
reuse_len=15, reuse_len=15,
is_training=True, is_training=True,
use_labels=True, use_labels=True,
vocab_size=99, vocab_size=99,
cutoffs=[10, 50, 80], cutoffs=[10, 50, 80],
hidden_size=32, hidden_size=32,
num_attention_heads=4, num_attention_heads=4,
d_inner=128, d_inner=128,
num_hidden_layers=5, num_hidden_layers=5,
type_sequence_label_size=2, type_sequence_label_size=2,
untie_r=True, untie_r=True,
bi_data=False, bi_data=False,
same_length=False, same_length=False,
initializer_range=0.05, initializer_range=0.05,
seed=1, seed=1,
type_vocab_size=2, type_vocab_size=2,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.seq_length = seq_length self.seq_length = seq_length
@@ -97,9 +112,13 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float() input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size) input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device) perm_mask = torch.zeros(
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device
)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device) target_mapping = torch.zeros(
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device
)
target_mapping[:, 0, -1] = 1.0 # predict last token target_mapping[:, 0, -1] = 1.0 # predict last token
sequence_labels = None sequence_labels = None
@@ -125,17 +144,43 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
reuse_len=self.reuse_len, reuse_len=self.reuse_len,
bi_data=self.bi_data, bi_data=self.bi_data,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
num_labels=self.type_sequence_label_size) num_labels=self.type_sequence_label_size,
)
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels) return (
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
token_labels,
)
def set_seed(self): def set_seed(self):
random.seed(self.seed) random.seed(self.seed)
torch.manual_seed(self.seed) torch.manual_seed(self.seed)
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, def create_and_check_xlnet_base_model(
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): self,
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
token_labels,
):
model = XLNetModel(config) model = XLNetModel(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
@@ -158,14 +203,28 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
self.parent.assertEqual(len(no_mems_outputs), 1) self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["outputs"].size()), list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size]
[self.batch_size, self.seq_length, self.hidden_size]) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): def create_and_check_xlnet_base_model_with_att_output(
self,
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
token_labels,
):
model = XLNetModel(config) model = XLNetModel(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
@@ -177,15 +236,30 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
self.parent.assertEqual(len(attentions[0]), 2) self.parent.assertEqual(len(attentions[0]), 2)
self.parent.assertTrue(attentions[0][0].shape, attentions[0][0].shape) self.parent.assertTrue(attentions[0][0].shape, attentions[0][0].shape)
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, def create_and_check_xlnet_lm_head(
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): self,
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
token_labels,
):
model = XLNetLMHeadModel(config) model = XLNetLMHeadModel(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels) loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1) loss_2, all_logits_2, mems_2 = model(
input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1
)
logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping) logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
@@ -198,28 +272,39 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"all_logits_2": all_logits_2, "all_logits_2": all_logits_2,
} }
self.parent.assertListEqual(list(result["loss_1"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss_1"].size()), list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
[]) )
self.parent.assertListEqual(
list(result["all_logits_1"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
self.parent.assertListEqual(list(result["loss_2"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss_2"].size()), list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
[]) )
self.parent.assertListEqual(
list(result["all_logits_2"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]), list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): def create_and_check_xlnet_qa(
self,
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
token_labels,
):
model = XLNetForQuestionAnswering(config) model = XLNetForQuestionAnswering(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
@@ -227,21 +312,26 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
outputs = model(input_ids_1) outputs = model(input_ids_1)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels, outputs = model(
end_positions=sequence_labels, input_ids_1,
cls_index=sequence_labels, start_positions=sequence_labels,
is_impossible=is_impossible_labels, end_positions=sequence_labels,
p_mask=input_mask) cls_index=sequence_labels,
is_impossible=is_impossible_labels,
outputs = model(input_ids_1, start_positions=sequence_labels, p_mask=input_mask,
end_positions=sequence_labels, )
cls_index=sequence_labels,
is_impossible=is_impossible_labels) outputs = model(
input_ids_1,
start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels,
)
total_loss, mems = outputs total_loss, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels, outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels)
end_positions=sequence_labels)
total_loss, mems = outputs total_loss, mems = outputs
@@ -255,30 +345,42 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"mems": mems, "mems": mems,
} }
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss"].size()), list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
[]) )
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["start_top_log_probs"].size()), list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
[self.batch_size, model.config.start_n_top]) )
self.parent.assertListEqual(
list(result["start_top_index"].size()),
[self.batch_size, model.config.start_n_top])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["end_top_log_probs"].size()), list(result["end_top_log_probs"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top]) [self.batch_size, model.config.start_n_top * model.config.end_n_top],
)
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["end_top_index"].size()), list(result["end_top_index"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top]) [self.batch_size, model.config.start_n_top * model.config.end_n_top],
self.parent.assertListEqual( )
list(result["cls_logits"].size()), self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
[self.batch_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]), list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
def create_and_check_xlnet_token_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): def create_and_check_xlnet_token_classif(
self,
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
token_labels,
):
model = XLNetForTokenClassification(config) model = XLNetForTokenClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
@@ -292,26 +394,48 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"logits": logits, "logits": logits,
} }
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss"].size()), list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size]
[]) )
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.seq_length, self.type_sequence_label_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, (
target_mapping, segment_ids, lm_labels, config,
sequence_labels, is_impossible_labels) = config_and_inputs input_ids_1,
inputs_dict = {'input_ids': input_ids_1} input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids_1}
return config, inputs_dict return config, inputs_dict
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, def create_and_check_xlnet_sequence_classif(
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels): self,
config,
input_ids_1,
input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
token_labels,
):
model = XLNetForSequenceClassification(config) model = XLNetForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
@@ -325,25 +449,34 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"logits": logits, "logits": logits,
} }
self.parent.assertListEqual(list(result["loss"].size()), [])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss"].size()), list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
[]) )
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.type_sequence_label_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, (
target_mapping, segment_ids, lm_labels, config,
sequence_labels, is_impossible_labels, token_labels) = config_and_inputs input_ids_1,
inputs_dict = {'input_ids': input_ids_1} input_ids_2,
input_ids_q,
perm_mask,
input_mask,
target_mapping,
segment_ids,
lm_labels,
sequence_labels,
is_impossible_labels,
token_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids_1}
return config, inputs_dict return config, inputs_dict
def setUp(self): def setUp(self):
self.model_tester = XLNetModelTest.XLNetModelTester(self) self.model_tester = XLNetModelTest.XLNetModelTester(self)
self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37) self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
...
@@ -24,12 +24,14 @@ from transformers import is_torch_available
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import (AdamW, from transformers import (
get_constant_schedule, AdamW,
get_constant_schedule_with_warmup, get_constant_schedule,
get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup, get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup) get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from .tokenization_tests_commons import TemporaryDirectory from .tokenization_tests_commons import TemporaryDirectory
from .utils import require_torch from .utils import require_torch
@@ -42,6 +44,7 @@ def unwrap_schedule(scheduler, num_steps=10):
lrs.append(scheduler.get_lr()) lrs.append(scheduler.get_lr())
return lrs return lrs
def unwrap_and_save_reload_schedule(scheduler, num_steps=10): def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
lrs = [] lrs = []
for step in range(num_steps): for step in range(num_steps):
@@ -49,16 +52,16 @@ def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
lrs.append(scheduler.get_lr()) lrs.append(scheduler.get_lr())
if step == num_steps // 2: if step == num_steps // 2:
with TemporaryDirectory() as tmpdirname: with TemporaryDirectory() as tmpdirname:
file_name = os.path.join(tmpdirname, 'schedule.bin') file_name = os.path.join(tmpdirname, "schedule.bin")
torch.save(scheduler.state_dict(), file_name) torch.save(scheduler.state_dict(), file_name)
state_dict = torch.load(file_name) state_dict = torch.load(file_name)
scheduler.load_state_dict(state_dict) scheduler.load_state_dict(state_dict)
return lrs return lrs
@require_torch @require_torch
class OptimizationTest(unittest.TestCase): class OptimizationTest(unittest.TestCase):
def assertListAlmostEqual(self, list1, list2, tol): def assertListAlmostEqual(self, list1, list2, tol):
self.assertEqual(len(list1), len(list2)) self.assertEqual(len(list1), len(list2))
for a, b in zip(list1, list2): for a, b in zip(list1, list2):
@@ -74,7 +77,7 @@ class OptimizationTest(unittest.TestCase):
loss = criterion(w, target) loss = criterion(w, target)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves. w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves.
w.grad.zero_() w.grad.zero_()
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
@@ -82,7 +85,7 @@
@require_torch @require_torch
class ScheduleInitTest(unittest.TestCase): class ScheduleInitTest(unittest.TestCase):
m = torch.nn.Linear(50, 50) if is_torch_available() else None m = torch.nn.Linear(50, 50) if is_torch_available() else None
optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
num_steps = 10 num_steps = 10
def assertListAlmostEqual(self, list1, list2, tol): def assertListAlmostEqual(self, list1, list2, tol):
@@ -93,7 +96,7 @@ class ScheduleInitTest(unittest.TestCase):
def test_constant_scheduler(self): def test_constant_scheduler(self):
scheduler = get_constant_schedule(self.optimizer) scheduler = get_constant_schedule(self.optimizer)
lrs = unwrap_schedule(scheduler, self.num_steps) lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [10.] * self.num_steps expected_learning_rates = [10.0] * self.num_steps
self.assertEqual(len(lrs[0]), 1) self.assertEqual(len(lrs[0]), 1)
self.assertListEqual([l[0] for l in lrs], expected_learning_rates) self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
@@ -135,13 +138,17 @@ class ScheduleInitTest(unittest.TestCase):
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
def test_warmup_cosine_hard_restart_scheduler(self): def test_warmup_cosine_hard_restart_scheduler(self):
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10) scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10
)
lrs = unwrap_schedule(scheduler, self.num_steps) lrs = unwrap_schedule(scheduler, self.num_steps)
expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
self.assertEqual(len(lrs[0]), 1) self.assertEqual(len(lrs[0]), 1)
self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10) scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10
)
lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
...
@@ -12,7 +12,7 @@ if is_tf_available():
import tensorflow as tf import tensorflow as tf
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from transformers import (create_optimizer, GradientAccumulator) from transformers import create_optimizer, GradientAccumulator
@require_tf @require_tf
@@ -21,7 +21,7 @@ class OptimizationFTest(unittest.TestCase):
self.assertEqual(len(list1), len(list2)) self.assertEqual(len(list1), len(list2))
for a, b in zip(list1, list2): for a, b in zip(list1, list2):
self.assertAlmostEqual(a, b, delta=tol) self.assertAlmostEqual(a, b, delta=tol)
def testGradientAccumulator(self): def testGradientAccumulator(self):
accumulator = GradientAccumulator() accumulator = GradientAccumulator()
accumulator([tf.constant([1.0, 2.0])]) accumulator([tf.constant([1.0, 2.0])])
@@ -42,8 +42,8 @@ class OptimizationFTest(unittest.TestCase):
physical_devices = tf.config.experimental.list_physical_devices("CPU") physical_devices = tf.config.experimental.list_physical_devices("CPU")
tf.config.experimental.set_virtual_device_configuration( tf.config.experimental.set_virtual_device_configuration(
physical_devices[0], physical_devices[0],
[tf.config.experimental.VirtualDeviceConfiguration(), [tf.config.experimental.VirtualDeviceConfiguration(), tf.config.experimental.VirtualDeviceConfiguration()],
tf.config.experimental.VirtualDeviceConfiguration()]) )
devices = tf.config.experimental.list_logical_devices(device_type="CPU") devices = tf.config.experimental.list_logical_devices(device_type="CPU")
strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices]) strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices])
@@ -87,4 +87,4 @@ class OptimizationFTest(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
\ No newline at end of file
@@ -6,58 +6,58 @@ from transformers import pipeline
from transformers.tests.utils import require_tf, require_torch from transformers.tests.utils import require_tf, require_torch
QA_FINETUNED_MODELS = { QA_FINETUNED_MODELS = {
('bert-base-uncased', 'bert-large-uncased-whole-word-masking-finetuned-squad', None), ("bert-base-uncased", "bert-large-uncased-whole-word-masking-finetuned-squad", None),
('bert-base-cased', 'bert-large-cased-whole-word-masking-finetuned-squad', None), ("bert-base-cased", "bert-large-cased-whole-word-masking-finetuned-squad", None),
('bert-base-uncased', 'distilbert-base-uncased-distilled-squad', None) ("bert-base-uncased", "distilbert-base-uncased-distilled-squad", None),
} }
TF_QA_FINETUNED_MODELS = { TF_QA_FINETUNED_MODELS = {
('bert-base-uncased', 'bert-large-uncased-whole-word-masking-finetuned-squad', None), ("bert-base-uncased", "bert-large-uncased-whole-word-masking-finetuned-squad", None),
('bert-base-cased', 'bert-large-cased-whole-word-masking-finetuned-squad', None), ("bert-base-cased", "bert-large-cased-whole-word-masking-finetuned-squad", None),
('bert-base-uncased', 'distilbert-base-uncased-distilled-squad', None) ("bert-base-uncased", "distilbert-base-uncased-distilled-squad", None),
} }
TF_NER_FINETUNED_MODELS = { TF_NER_FINETUNED_MODELS = {
( (
'bert-base-cased', "bert-base-cased",
'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-tf_model.h5', "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-tf_model.h5",
'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json' "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json",
) )
} }
NER_FINETUNED_MODELS = { NER_FINETUNED_MODELS = {
( (
'bert-base-cased', "bert-base-cased",
'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin', "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin",
'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json' "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json",
) )
} }
FEATURE_EXTRACT_FINETUNED_MODELS = { FEATURE_EXTRACT_FINETUNED_MODELS = {
('bert-base-cased', 'bert-base-cased', None), ("bert-base-cased", "bert-base-cased", None),
# ('xlnet-base-cased', 'xlnet-base-cased', None), # Disabled for now as it crash for TF2 # ('xlnet-base-cased', 'xlnet-base-cased', None), # Disabled for now as it crash for TF2
('distilbert-base-uncased', 'distilbert-base-uncased', None) ("distilbert-base-uncased", "distilbert-base-uncased", None),
} }
TF_FEATURE_EXTRACT_FINETUNED_MODELS = { TF_FEATURE_EXTRACT_FINETUNED_MODELS = {
('bert-base-cased', 'bert-base-cased', None), ("bert-base-cased", "bert-base-cased", None),
# ('xlnet-base-cased', 'xlnet-base-cased', None), # Disabled for now as it crash for TF2 # ('xlnet-base-cased', 'xlnet-base-cased', None), # Disabled for now as it crash for TF2
('distilbert-base-uncased', 'distilbert-base-uncased', None) ("distilbert-base-uncased", "distilbert-base-uncased", None),
} }
TF_TEXT_CLASSIF_FINETUNED_MODELS = { TF_TEXT_CLASSIF_FINETUNED_MODELS = {
( (
'bert-base-uncased', "bert-base-uncased",
'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-tf_model.h5', "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-tf_model.h5",
'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json' "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json",
) )
} }
TEXT_CLASSIF_FINETUNED_MODELS = { TEXT_CLASSIF_FINETUNED_MODELS = {
( (
'bert-base-uncased', "bert-base-uncased",
'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin', "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin",
'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json' "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json",
) )
} }
@@ -91,54 +91,54 @@ class MonoColumnInputTestCase(unittest.TestCase):
@require_torch @require_torch
def test_ner(self): def test_ner(self):
mandatory_keys = {'entity', 'word', 'score'} mandatory_keys = {"entity", "word", "score"}
valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris'] valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None] invalid_inputs = [None]
for tokenizer, model, config in NER_FINETUNED_MODELS: for tokenizer, model, config in NER_FINETUNED_MODELS:
nlp = pipeline(task='ner', model=model, config=config, tokenizer=tokenizer) nlp = pipeline(task="ner", model=model, config=config, tokenizer=tokenizer)
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys) self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
@require_tf @require_tf
def test_tf_ner(self): def test_tf_ner(self):
mandatory_keys = {'entity', 'word', 'score'} mandatory_keys = {"entity", "word", "score"}
valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris'] valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None] invalid_inputs = [None]
for tokenizer, model, config in TF_NER_FINETUNED_MODELS: for tokenizer, model, config in TF_NER_FINETUNED_MODELS:
nlp = pipeline(task='ner', model=model, config=config, tokenizer=tokenizer) nlp = pipeline(task="ner", model=model, config=config, tokenizer=tokenizer)
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys) self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
@require_torch @require_torch
def test_sentiment_analysis(self): def test_sentiment_analysis(self):
mandatory_keys = {'label'} mandatory_keys = {"label"}
valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris'] valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None] invalid_inputs = [None]
for tokenizer, model, config in TEXT_CLASSIF_FINETUNED_MODELS: for tokenizer, model, config in TEXT_CLASSIF_FINETUNED_MODELS:
nlp = pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer) nlp = pipeline(task="sentiment-analysis", model=model, config=config, tokenizer=tokenizer)
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys) self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
@require_tf @require_tf
def test_tf_sentiment_analysis(self): def test_tf_sentiment_analysis(self):
mandatory_keys = {'label'} mandatory_keys = {"label"}
valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris'] valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None] invalid_inputs = [None]
for tokenizer, model, config in TF_TEXT_CLASSIF_FINETUNED_MODELS: for tokenizer, model, config in TF_TEXT_CLASSIF_FINETUNED_MODELS:
nlp = pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer) nlp = pipeline(task="sentiment-analysis", model=model, config=config, tokenizer=tokenizer)
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys) self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
@require_torch @require_torch
def test_features_extraction(self): def test_features_extraction(self):
valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris'] valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None] invalid_inputs = [None]
for tokenizer, model, config in FEATURE_EXTRACT_FINETUNED_MODELS: for tokenizer, model, config in FEATURE_EXTRACT_FINETUNED_MODELS:
nlp = pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer) nlp = pipeline(task="sentiment-analysis", model=model, config=config, tokenizer=tokenizer)
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {}) self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})
@require_tf @require_tf
def test_tf_features_extraction(self): def test_tf_features_extraction(self):
valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris'] valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None] invalid_inputs = [None]
for tokenizer, model, config in TF_FEATURE_EXTRACT_FINETUNED_MODELS: for tokenizer, model, config in TF_FEATURE_EXTRACT_FINETUNED_MODELS:
nlp = pipeline(task='feature-extraction', model=model, config=config, tokenizer=tokenizer) nlp = pipeline(task="feature-extraction", model=model, config=config, tokenizer=tokenizer)
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {}) self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})
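The mono-column tests above all go through the same pipeline() factory. Below is a minimal, hedged sketch of calling it directly; it assumes the factory's built-in default checkpoints can be downloaded, since no explicit model, config, or tokenizer is passed (the tests themselves always pass an explicit triple).

from transformers import pipeline

# Token classification (NER): each prediction is a dict carrying at least the
# "entity", "word" and "score" keys checked by the tests above.
ner = pipeline(task="ner")
print(ner("HuggingFace is based in New-York & Paris"))

# Sequence classification: each prediction carries at least a "label" key.
classifier = pipeline(task="sentiment-analysis")
print(classifier("HuggingFace is solving NLP one commit at a time."))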
...@@ -165,46 +165,46 @@ class MultiColumnInputTestCase(unittest.TestCase): ...@@ -165,46 +165,46 @@ class MultiColumnInputTestCase(unittest.TestCase):
@require_torch @require_torch
def test_question_answering(self): def test_question_answering(self):
mandatory_output_keys = {'score', 'answer', 'start', 'end'} mandatory_output_keys = {"score", "answer", "start", "end"}
valid_samples = [ valid_samples = [
{'question': 'Where was HuggingFace founded ?', 'context': 'HuggingFace was founded in Paris.'}, {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
{ {
'question': 'In what field is HuggingFace working ?', "question": "In what field is HuggingFace working ?",
'context': 'HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.' "context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
} },
] ]
invalid_samples = [ invalid_samples = [
{'question': '', 'context': 'This is a test to try empty question edge case'}, {"question": "", "context": "This is a test to try empty question edge case"},
{'question': None, 'context': 'This is a test to try empty question edge case'}, {"question": None, "context": "This is a test to try empty question edge case"},
{'question': 'What does it do with an empty context ?', 'context': ''}, {"question": "What does it do with an empty context ?", "context": ""},
{'question': 'What does it do with an empty context ?', 'context': None}, {"question": "What does it do with an empty context ?", "context": None},
] ]
for tokenizer, model, config in QA_FINETUNED_MODELS: for tokenizer, model, config in QA_FINETUNED_MODELS:
nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer) nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer)
self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys) self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
@require_tf @require_tf
def test_tf_question_answering(self): def test_tf_question_answering(self):
mandatory_output_keys = {'score', 'answer', 'start', 'end'} mandatory_output_keys = {"score", "answer", "start", "end"}
valid_samples = [ valid_samples = [
{'question': 'Where was HuggingFace founded ?', 'context': 'HuggingFace was founded in Paris.'}, {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
{ {
'question': 'In what field is HuggingFace working ?', "question": "In what field is HuggingFace working ?",
'context': 'HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.' "context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
} },
] ]
invalid_samples = [ invalid_samples = [
{'question': '', 'context': 'This is a test to try empty question edge case'}, {"question": "", "context": "This is a test to try empty question edge case"},
{'question': None, 'context': 'This is a test to try empty question edge case'}, {"question": None, "context": "This is a test to try empty question edge case"},
{'question': 'What does it do with an empty context ?', 'context': ''}, {"question": "What does it do with an empty context ?", "context": ""},
{'question': 'What does it do with an empty context ?', 'context': None}, {"question": "What does it do with an empty context ?", "context": None},
] ]
for tokenizer, model, config in TF_QA_FINETUNED_MODELS: for tokenizer, model, config in TF_QA_FINETUNED_MODELS:
nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer) nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer)
self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys) self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
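A minimal sketch of the question-answering pipeline exercised above, using the question/context pair from valid_samples. Letting pipeline() fall back to its default checkpoint is an assumption; the tests pass explicit model/config/tokenizer triples instead.

from transformers import pipeline

qa = pipeline(task="question-answering")
result = qa(question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris.")
# The answer dict exposes the same keys the tests assert on:
# "score", "answer", "start" and "end".
print(result)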
...@@ -17,12 +17,12 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -17,12 +17,12 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os import os
import unittest import unittest
from transformers.tokenization_albert import (AlbertTokenizer, SPIECE_UNDERLINE) from transformers.tokenization_albert import AlbertTokenizer, SPIECE_UNDERLINE
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/spiece.model")
'fixtures/spiece.model')
class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester): class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester):
...@@ -39,27 +39,30 @@ class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -39,27 +39,30 @@ class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester):
return AlbertTokenizer.from_pretrained(self.tmpdirname, **kwargs) return AlbertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"this is a test" input_text = "this is a test"
output_text = u"this is a test" output_text = "this is a test"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = AlbertTokenizer(SAMPLE_VOCAB, keep_accents=True) tokenizer = AlbertTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokens = tokenizer.tokenize(u'This is a test') tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, [u'▁this', u'▁is', u'▁a', u'▁test']) self.assertListEqual(tokens, ["▁this", "▁is", "▁a", "▁test"])
self.assertListEqual( self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [48, 25, 21, 1289])
tokenizer.convert_tokens_to_ids(tokens), [48, 25, 21, 1289])
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, [u'▁i', u'▁was', u'▁born', u'▁in', u'▁9', u'2000', u',', u'▁and', u'▁this', u'▁is', u'▁fal', u's', u'é', u'.']) self.assertListEqual(
tokens, ["▁i", "▁was", "▁born", "▁in", "▁9", "2000", ",", "▁and", "▁this", "▁is", "▁fal", "s", "é", "."]
)
ids = tokenizer.convert_tokens_to_ids(tokens) ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [31, 23, 386, 19, 561, 3050, 15, 17, 48, 25, 8256, 18, 1, 9]) self.assertListEqual(ids, [31, 23, 386, 19, 561, 3050, 15, 17, 48, 25, 8256, 18, 1, 9])
back_tokens = tokenizer.convert_ids_to_tokens(ids) back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(back_tokens, ['▁i', '▁was', '▁born', '▁in', '▁9', '2000', ',', '▁and', '▁this', '▁is', '▁fal', 's', '<unk>', '.']) self.assertListEqual(
back_tokens,
["▁i", "▁was", "▁born", "▁in", "▁9", "2000", ",", "▁and", "▁this", "▁is", "▁fal", "s", "<unk>", "."],
)
def test_sequence_builders(self): def test_sequence_builders(self):
tokenizer = AlbertTokenizer(SAMPLE_VOCAB) tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
...@@ -71,8 +74,10 @@ class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -71,8 +74,10 @@ class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester):
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id]
assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + text_2 + [tokenizer.sep_token_id] assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + text_2 + [
tokenizer.sep_token_id
]
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -48,5 +48,6 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -48,5 +48,6 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertIsInstance(tokenizer, BertTokenizer) self.assertIsInstance(tokenizer, BertTokenizer)
self.assertEqual(len(tokenizer), 12) self.assertEqual(len(tokenizer), 12)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -19,9 +19,12 @@ import unittest ...@@ -19,9 +19,12 @@ import unittest
from io import open from io import open
from transformers.tokenization_bert import WordpieceTokenizer from transformers.tokenization_bert import WordpieceTokenizer
from transformers.tokenization_bert_japanese import (BertJapaneseTokenizer, from transformers.tokenization_bert_japanese import (
MecabTokenizer, CharacterTokenizer, BertJapaneseTokenizer,
VOCAB_FILES_NAMES) MecabTokenizer,
CharacterTokenizer,
VOCAB_FILES_NAMES,
)
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
from .utils import slow, custom_tokenizers from .utils import slow, custom_tokenizers
...@@ -35,9 +38,24 @@ class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -35,9 +38,24 @@ class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester):
def setUp(self): def setUp(self):
super(BertJapaneseTokenizationTest, self).setUp() super(BertJapaneseTokenizationTest, self).setUp()
vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", vocab_tokens = [
u"こんにちは", u"こん", u"にちは", u"ばんは", u"##こん", u"##にちは", u"##ばんは", "[UNK]",
u"世界", u"##世界", u"、", u"##、", u"。", u"##。"] "[CLS]",
"[SEP]",
"こんにちは",
"こん",
"にちは",
"ばんは",
"##こん",
"##にちは",
"##ばんは",
"世界",
"##世界",
"、",
"##、",
"。",
"##。",
]
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
...@@ -47,70 +65,63 @@ class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -47,70 +65,63 @@ class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester):
return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs) return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"こんにちは、世界。 \nこんばんは、世界。" input_text = "こんにちは、世界。 \nこんばんは、世界。"
output_text = u"こんにちは 、 世界 。 こんばんは 、 世界 。" output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file) tokenizer = self.tokenizer_class(self.vocab_file)
tokens = tokenizer.tokenize(u"こんにちは、世界。\nこんばんは、世界。") tokens = tokenizer.tokenize("こんにちは、世界。\nこんばんは、世界。")
self.assertListEqual(tokens, self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"])
[u"こんにちは", u"、", u"世界", u"。", self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14])
u"こん", u"##ばんは", u"、", u"世界", "。"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens),
[3, 12, 10, 14, 4, 9, 12, 10, 14])
def test_mecab_tokenizer(self): def test_mecab_tokenizer(self):
tokenizer = MecabTokenizer() tokenizer = MecabTokenizer()
self.assertListEqual( self.assertListEqual(
tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "), tokenizer.tokenize(" \tアップルストアでiPhone8 が \n 発売された 。 "),
[u"アップルストア", u"で", u"iPhone", u"8", u"が", ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
u"発売", u"さ", u"れ", u"た", u"。"]) )
def test_mecab_tokenizer_lower(self): def test_mecab_tokenizer_lower(self):
tokenizer = MecabTokenizer(do_lower_case=True) tokenizer = MecabTokenizer(do_lower_case=True)
self.assertListEqual( self.assertListEqual(
tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "), tokenizer.tokenize(" \tアップルストアでiPhone8 が \n 発売された 。 "),
[u"アップルストア", u"で", u"iphone", u"8", u"が", ["アップルストア", "で", "iphone", "8", "が", "発売", "さ", "れ", "た", "。"],
u"発売", u"さ", u"れ", u"た", u"。"]) )
def test_mecab_tokenizer_no_normalize(self): def test_mecab_tokenizer_no_normalize(self):
tokenizer = MecabTokenizer(normalize_text=False) tokenizer = MecabTokenizer(normalize_text=False)
self.assertListEqual( self.assertListEqual(
tokenizer.tokenize(u" \tアップルストアでiPhone8 が \n 発売された 。 "), tokenizer.tokenize(" \tアップルストアでiPhone8 が \n 発売された 。 "),
[u"アップルストア", u"で", u"iPhone", u"8", u"が", ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", " ", "。"],
u"発売", u"さ", u"れ", u"た", u" ", u"。"]) )
def test_wordpiece_tokenizer(self): def test_wordpiece_tokenizer(self):
vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こんにちは", "こん", "にちは" "ばんは", "##こん", "##にちは", "##ばんは"]
u"こんにちは", u"こん", u"にちは" u"ばんは", u"##こん", u"##にちは", u"##ばんは"]
vocab = {} vocab = {}
for (i, token) in enumerate(vocab_tokens): for (i, token) in enumerate(vocab_tokens):
vocab[token] = i vocab[token] = i
tokenizer = WordpieceTokenizer(vocab=vocab, unk_token=u"[UNK]") tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
self.assertListEqual(tokenizer.tokenize(u""), []) self.assertListEqual(tokenizer.tokenize(""), [])
self.assertListEqual(tokenizer.tokenize(u"こんにちは"), self.assertListEqual(tokenizer.tokenize("こんにちは"), ["こんにちは"])
[u"こんにちは"])
self.assertListEqual(tokenizer.tokenize(u"こんばんは"), self.assertListEqual(tokenizer.tokenize("こんばんは"), ["こん", "##ばんは"])
[u"こん", u"##ばんは"])
self.assertListEqual(tokenizer.tokenize(u"こんばんは こんばんにちは こんにちは"), self.assertListEqual(tokenizer.tokenize("こんばんは こんばんにちは こんにちは"), ["こん", "##ばんは", "[UNK]", "こんにちは"])
[u"こん", u"##ばんは", u"[UNK]", u"こんにちは"])
@slow @slow
def test_sequence_builders(self): def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("bert-base-japanese") tokenizer = self.tokenizer_class.from_pretrained("bert-base-japanese")
text = tokenizer.encode(u"ありがとう。", add_special_tokens=False) text = tokenizer.encode("ありがとう。", add_special_tokens=False)
text_2 = tokenizer.encode(u"どういたしまして。", add_special_tokens=False) text_2 = tokenizer.encode("どういたしまして。", add_special_tokens=False)
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
...@@ -127,58 +138,51 @@ class BertJapaneseCharacterTokenizationTest(CommonTestCases.CommonTokenizerTeste ...@@ -127,58 +138,51 @@ class BertJapaneseCharacterTokenizationTest(CommonTestCases.CommonTokenizerTeste
def setUp(self): def setUp(self):
super(BertJapaneseCharacterTokenizationTest, self).setUp() super(BertJapaneseCharacterTokenizationTest, self).setUp()
vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こ", "ん", "に", "ち", "は", "ば", "世", "界", "、", "。"]
u"こ", u"ん", u"に", u"ち", u"は", u"ば", u"世", u"界", u"、", u"。"]
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def get_tokenizer(self, **kwargs): def get_tokenizer(self, **kwargs):
return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, subword_tokenizer_type="character", **kwargs)
subword_tokenizer_type="character",
**kwargs)
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"こんにちは、世界。 \nこんばんは、世界。" input_text = "こんにちは、世界。 \nこんばんは、世界。"
output_text = u"こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。" output_text = "こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file, tokenizer = self.tokenizer_class(self.vocab_file, subword_tokenizer_type="character")
subword_tokenizer_type="character")
tokens = tokenizer.tokenize(u"こんにちは、世界。 \nこんばんは、世界。") tokens = tokenizer.tokenize("こんにちは、世界。 \nこんばんは、世界。")
self.assertListEqual(tokens, self.assertListEqual(
[u"こ", u"ん", u"に", u"ち", u"は", u"、", u"世", u"界", u"。", tokens, ["こ", "ん", "に", "ち", "は", "、", "世", "界", "。", "こ", "ん", "ば", "ん", "は", "、", "世", "界", "。"]
u"こ", u"ん", u"ば", u"ん", u"は", u"、", u"世", u"界", u"。"]) )
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), self.assertListEqual(
[3, 4, 5, 6, 7, 11, 9, 10, 12, tokenizer.convert_tokens_to_ids(tokens), [3, 4, 5, 6, 7, 11, 9, 10, 12, 3, 4, 8, 4, 7, 11, 9, 10, 12]
3, 4, 8, 4, 7, 11, 9, 10, 12]) )
def test_character_tokenizer(self): def test_character_tokenizer(self):
vocab_tokens = [u"[UNK]", u"[CLS]", u"[SEP]", vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こ", "ん", "に", "ち", "は", "ば", "世", "界" "、", "。"]
u"こ", u"ん", u"に", u"ち", u"は", u"ば", u"世", u"界"u"、", u"。"]
vocab = {} vocab = {}
for (i, token) in enumerate(vocab_tokens): for (i, token) in enumerate(vocab_tokens):
vocab[token] = i vocab[token] = i
tokenizer = CharacterTokenizer(vocab=vocab, unk_token=u"[UNK]") tokenizer = CharacterTokenizer(vocab=vocab, unk_token="[UNK]")
self.assertListEqual(tokenizer.tokenize(u""), []) self.assertListEqual(tokenizer.tokenize(""), [])
self.assertListEqual(tokenizer.tokenize(u"こんにちは"), self.assertListEqual(tokenizer.tokenize("こんにちは"), ["こ", "ん", "に", "ち", "は"])
[u"こ", u"ん", u"に", u"ち", u"は"])
self.assertListEqual(tokenizer.tokenize(u"こんにちほ"), self.assertListEqual(tokenizer.tokenize("こんにちほ"), ["こ", "ん", "に", "ち", "[UNK]"])
[u"こ", u"ん", u"に", u"ち", u"[UNK]"])
@slow @slow
def test_sequence_builders(self): def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("bert-base-japanese-char") tokenizer = self.tokenizer_class.from_pretrained("bert-base-japanese-char")
text = tokenizer.encode(u"ありがとう。", add_special_tokens=False) text = tokenizer.encode("ありがとう。", add_special_tokens=False)
text_2 = tokenizer.encode(u"どういたしまして。", add_special_tokens=False) text_2 = tokenizer.encode("どういたしまして。", add_special_tokens=False)
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
...@@ -186,6 +190,3 @@ class BertJapaneseCharacterTokenizationTest(CommonTestCases.CommonTokenizerTeste ...@@ -186,6 +190,3 @@ class BertJapaneseCharacterTokenizationTest(CommonTestCases.CommonTokenizerTeste
# 2 is for "[CLS]", 3 is for "[SEP]" # 2 is for "[CLS]", 3 is for "[SEP]"
assert encoded_sentence == [2] + text + [3] assert encoded_sentence == [2] + text + [3]
assert encoded_pair == [2] + text + [3] + text_2 + [3] assert encoded_pair == [2] + text + [3] + text_2 + [3]
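The two slow test_sequence_builders methods above load published Japanese checkpoints. The sketch below shows the same loading path outside the test harness; it assumes the checkpoints can be downloaded and that MeCab (mecab-python3) is installed, which the default word-level tokenizer requires.

from transformers.tokenization_bert_japanese import BertJapaneseTokenizer

# WordPiece subwords on top of MeCab word segmentation (default settings).
wordpiece_tokenizer = BertJapaneseTokenizer.from_pretrained("bert-base-japanese")
print(wordpiece_tokenizer.tokenize("こんにちは、世界。"))

# Character-level subwords, matching subword_tokenizer_type="character" above.
character_tokenizer = BertJapaneseTokenizer.from_pretrained("bert-base-japanese-char")
print(character_tokenizer.tokenize("こんにちは、世界。"))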
...@@ -18,15 +18,20 @@ import os ...@@ -18,15 +18,20 @@ import os
import unittest import unittest
from io import open from io import open
from transformers.tokenization_bert import (BasicTokenizer, from transformers.tokenization_bert import (
BertTokenizer, BasicTokenizer,
WordpieceTokenizer, BertTokenizer,
_is_control, _is_punctuation, WordpieceTokenizer,
_is_whitespace, VOCAB_FILES_NAMES) _is_control,
_is_punctuation,
_is_whitespace,
VOCAB_FILES_NAMES,
)
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
from .utils import slow from .utils import slow
class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = BertTokenizer tokenizer_class = BertTokenizer
...@@ -35,55 +40,61 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -35,55 +40,61 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
super(BertTokenizationTest, self).setUp() super(BertTokenizationTest, self).setUp()
vocab_tokens = [ vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "[UNK]",
"##ing", ",", "low", "lowest", "[CLS]",
"[SEP]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
] ]
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def get_tokenizer(self, **kwargs): def get_tokenizer(self, **kwargs):
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"UNwant\u00E9d,running" input_text = "UNwant\u00E9d,running"
output_text = u"unwanted, running" output_text = "unwanted, running"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file) tokenizer = self.tokenizer_class(self.vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") tokens = tokenizer.tokenize("UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_chinese(self): def test_chinese(self):
tokenizer = BasicTokenizer() tokenizer = BasicTokenizer()
self.assertListEqual( self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
[u"ah", u"\u535A", u"\u63A8", u"zz"])
def test_basic_tokenizer_lower(self): def test_basic_tokenizer_lower(self):
tokenizer = BasicTokenizer(do_lower_case=True) tokenizer = BasicTokenizer(do_lower_case=True)
self.assertListEqual( self.assertListEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["hello", "!", "how", "are", "you", "?"]
["hello", "!", "how", "are", "you", "?"]) )
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
def test_basic_tokenizer_no_lower(self): def test_basic_tokenizer_no_lower(self):
tokenizer = BasicTokenizer(do_lower_case=False) tokenizer = BasicTokenizer(do_lower_case=False)
self.assertListEqual( self.assertListEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
["HeLLo", "!", "how", "Are", "yoU", "?"]) )
def test_wordpiece_tokenizer(self): def test_wordpiece_tokenizer(self):
vocab_tokens = [ vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing"
]
vocab = {} vocab = {}
for (i, token) in enumerate(vocab_tokens): for (i, token) in enumerate(vocab_tokens):
...@@ -92,39 +103,36 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -92,39 +103,36 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
self.assertListEqual(tokenizer.tokenize(""), []) self.assertListEqual(tokenizer.tokenize(""), [])
self.assertListEqual( self.assertListEqual(tokenizer.tokenize("unwanted running"), ["un", "##want", "##ed", "runn", "##ing"])
tokenizer.tokenize("unwanted running"),
["un", "##want", "##ed", "runn", "##ing"])
self.assertListEqual( self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
def test_is_whitespace(self): def test_is_whitespace(self):
self.assertTrue(_is_whitespace(u" ")) self.assertTrue(_is_whitespace(" "))
self.assertTrue(_is_whitespace(u"\t")) self.assertTrue(_is_whitespace("\t"))
self.assertTrue(_is_whitespace(u"\r")) self.assertTrue(_is_whitespace("\r"))
self.assertTrue(_is_whitespace(u"\n")) self.assertTrue(_is_whitespace("\n"))
self.assertTrue(_is_whitespace(u"\u00A0")) self.assertTrue(_is_whitespace("\u00A0"))
self.assertFalse(_is_whitespace(u"A")) self.assertFalse(_is_whitespace("A"))
self.assertFalse(_is_whitespace(u"-")) self.assertFalse(_is_whitespace("-"))
def test_is_control(self): def test_is_control(self):
self.assertTrue(_is_control(u"\u0005")) self.assertTrue(_is_control("\u0005"))
self.assertFalse(_is_control(u"A")) self.assertFalse(_is_control("A"))
self.assertFalse(_is_control(u" ")) self.assertFalse(_is_control(" "))
self.assertFalse(_is_control(u"\t")) self.assertFalse(_is_control("\t"))
self.assertFalse(_is_control(u"\r")) self.assertFalse(_is_control("\r"))
def test_is_punctuation(self): def test_is_punctuation(self):
self.assertTrue(_is_punctuation(u"-")) self.assertTrue(_is_punctuation("-"))
self.assertTrue(_is_punctuation(u"$")) self.assertTrue(_is_punctuation("$"))
self.assertTrue(_is_punctuation(u"`")) self.assertTrue(_is_punctuation("`"))
self.assertTrue(_is_punctuation(u".")) self.assertTrue(_is_punctuation("."))
self.assertFalse(_is_punctuation(u"A")) self.assertFalse(_is_punctuation("A"))
self.assertFalse(_is_punctuation(u" ")) self.assertFalse(_is_punctuation(" "))
@slow @slow
def test_sequence_builders(self): def test_sequence_builders(self):
...@@ -140,5 +148,5 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -140,5 +148,5 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
assert encoded_pair == [101] + text + [102] + text_2 + [102] assert encoded_pair == [101] + text + [102] + text_2 + [102]
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
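The final assertion above encodes the BERT convention that 101 and 102 are the [CLS] and [SEP] ids. A short sketch of the same behaviour follows; the "bert-base-uncased" checkpoint name and the sample sentence are assumptions, since the test's own from_pretrained call falls outside this excerpt.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
text = tokenizer.encode("sequence builders", add_special_tokens=False)
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
# Expected shape: [101] + text + [102], i.e. [CLS] ... [SEP].
print(encoded_sentence)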
...@@ -22,6 +22,7 @@ from transformers.tokenization_ctrl import CTRLTokenizer, VOCAB_FILES_NAMES ...@@ -22,6 +22,7 @@ from transformers.tokenization_ctrl import CTRLTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester): class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = CTRLTokenizer tokenizer_class = CTRLTokenizer
...@@ -30,13 +31,13 @@ class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -30,13 +31,13 @@ class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
super(CTRLTokenizationTest, self).setUp() super(CTRLTokenizationTest, self).setUp()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab = ['adapt', 're@@', 'a@@', 'apt', 'c@@', 't', '<unk>'] vocab = ["adapt", "re@@", "a@@", "apt", "c@@", "t", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab)))) vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", 'a p', 'ap t</w>', 'r e', 'a d', 'ad apt</w>', ''] merges = ["#version: 0.2", "a p", "ap t</w>", "r e", "a d", "ad apt</w>", ""]
self.special_tokens_map = {"unk_token": "<unk>"} self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp: with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n") fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp: with open(self.merges_file, "w", encoding="utf-8") as fp:
...@@ -47,23 +48,22 @@ class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -47,23 +48,22 @@ class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs) return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"adapt react readapt apt" input_text = "adapt react readapt apt"
output_text = u"adapt react readapt apt" output_text = "adapt react readapt apt"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = CTRLTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) tokenizer = CTRLTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "adapt react readapt apt" text = "adapt react readapt apt"
bpe_tokens = 'adapt re@@ a@@ c@@ t re@@ adapt apt'.split() bpe_tokens = "adapt re@@ a@@ c@@ t re@@ adapt apt".split()
tokens = tokenizer.tokenize(text) tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens) self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token] input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [0, 1, 2, 4, 5, 1, 0, 3, 6] input_bpe_tokens = [0, 1, 2, 4, 5, 1, 0, 3, 6]
self.assertListEqual( self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -18,12 +18,13 @@ import os ...@@ -18,12 +18,13 @@ import os
import unittest import unittest
from io import open from io import open
from transformers.tokenization_distilbert import (DistilBertTokenizer) from transformers.tokenization_distilbert import DistilBertTokenizer
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
from .tokenization_bert_test import BertTokenizationTest from .tokenization_bert_test import BertTokenizationTest
from .utils import slow from .utils import slow
class DistilBertTokenizationTest(BertTokenizationTest): class DistilBertTokenizationTest(BertTokenizationTest):
tokenizer_class = DistilBertTokenizer tokenizer_class = DistilBertTokenizer
...@@ -42,9 +43,10 @@ class DistilBertTokenizationTest(BertTokenizationTest): ...@@ -42,9 +43,10 @@ class DistilBertTokenizationTest(BertTokenizationTest):
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id]
assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + \ assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + text_2 + [
text_2 + [tokenizer.sep_token_id] tokenizer.sep_token_id
]
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -23,6 +23,7 @@ from transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES ...@@ -23,6 +23,7 @@ from transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = GPT2Tokenizer tokenizer_class = GPT2Tokenizer
...@@ -31,16 +32,34 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -31,16 +32,34 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
super(GPT2TokenizationTest, self).setUp() super(GPT2TokenizationTest, self).setUp()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", vocab = [
"\u0120", "\u0120l", "\u0120n", "l",
"\u0120lo", "\u0120low", "er", "o",
"\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"] "w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab)))) vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"} self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp: with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n") fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp: with open(self.merges_file, "w", encoding="utf-8") as fp:
...@@ -51,8 +70,8 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -51,8 +70,8 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs) return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"lower newer" input_text = "lower newer"
output_text = u"lower newer" output_text = "lower newer"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
...@@ -64,8 +83,8 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -64,8 +83,8 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
input_tokens = tokens + [tokenizer.unk_token] input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual( self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -31,15 +31,34 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -31,15 +31,34 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
super(OpenAIGPTTokenizationTest, self).setUp() super(OpenAIGPTTokenizationTest, self).setUp()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", vocab = [
"w</w>", "r</w>", "t</w>", "l",
"lo", "low", "er</w>", "o",
"low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"] "w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"w</w>",
"r</w>",
"t</w>",
"lo",
"low",
"er</w>",
"low</w>",
"lowest</w>",
"newer</w>",
"wider</w>",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab)))) vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""] merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w") as fp: with open(self.vocab_file, "w") as fp:
fp.write(json.dumps(vocab_tokens)) fp.write(json.dumps(vocab_tokens))
with open(self.merges_file, "w") as fp: with open(self.merges_file, "w") as fp:
...@@ -49,11 +68,10 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -49,11 +68,10 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs) return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"lower newer" input_text = "lower newer"
output_text = u"lower newer" output_text = "lower newer"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file) tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
...@@ -64,9 +82,8 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -64,9 +82,8 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
input_tokens = tokens + ["<unk>"] input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [14, 15, 20] input_bpe_tokens = [14, 15, 20]
self.assertListEqual( self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -31,16 +31,34 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -31,16 +31,34 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
super(RobertaTokenizationTest, self).setUp() super(RobertaTokenizationTest, self).setUp()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", vocab = [
"\u0120", "\u0120l", "\u0120n", "l",
"\u0120lo", "\u0120low", "er", "o",
"\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"] "w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"<unk>",
]
vocab_tokens = dict(zip(vocab, range(len(vocab)))) vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"} self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file']) self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp: with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n") fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp: with open(self.merges_file, "w", encoding="utf-8") as fp:
...@@ -51,8 +69,8 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -51,8 +69,8 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs) return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"lower newer" input_text = "lower newer"
output_text = u"lower newer" output_text = "lower newer"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
...@@ -64,19 +82,15 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -64,19 +82,15 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
input_tokens = tokens + [tokenizer.unk_token] input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual( self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def roberta_dict_integration_testing(self): def roberta_dict_integration_testing(self):
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
self.assertListEqual(tokenizer.encode("Hello world!", add_special_tokens=False), [0, 31414, 232, 328, 2])
self.assertListEqual( self.assertListEqual(
tokenizer.encode('Hello world!', add_special_tokens=False), tokenizer.encode("Hello world! cécé herlolip 418", add_special_tokens=False),
[0, 31414, 232, 328, 2] [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2],
)
self.assertListEqual(
tokenizer.encode('Hello world! cécé herlolip 418', add_special_tokens=False),
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
) )
@slow @slow
...@@ -87,7 +101,9 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -87,7 +101,9 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) encoded_pair_from_decode = tokenizer.encode(
"sequence builders", "multi-sequence build", add_special_tokens=True
)
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
...@@ -96,5 +112,5 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -96,5 +112,5 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
assert encoded_pair == encoded_pair_from_decode assert encoded_pair == encoded_pair_from_decode
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
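For context on the id lists used in roberta_dict_integration_testing, the sketch below encodes the same sentence with special tokens enabled. Loading the published "roberta-base" vocabulary is an assumption here; the toy vocabulary built in setUp would produce different ids.

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")  # assumed checkpoint
ids = tokenizer.encode("Hello world!", add_special_tokens=True)
# The sequence is bracketed by 0 (<s>) and 2 (</s>).
print(ids)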
...@@ -17,13 +17,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -17,13 +17,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os import os
import unittest import unittest
from transformers.tokenization_t5 import (T5Tokenizer) from transformers.tokenization_t5 import T5Tokenizer
from transformers.tokenization_xlnet import SPIECE_UNDERLINE from transformers.tokenization_xlnet import SPIECE_UNDERLINE
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
'fixtures/test_sentencepiece.model')
class T5TokenizationTest(CommonTestCases.CommonTokenizerTester): class T5TokenizationTest(CommonTestCases.CommonTokenizerTester):
...@@ -40,38 +40,76 @@ class T5TokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -40,38 +40,76 @@ class T5TokenizationTest(CommonTestCases.CommonTokenizerTester):
return T5Tokenizer.from_pretrained(self.tmpdirname, **kwargs) return T5Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"This is a test" input_text = "This is a test"
output_text = u"This is a test" output_text = "This is a test"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = T5Tokenizer(SAMPLE_VOCAB) tokenizer = T5Tokenizer(SAMPLE_VOCAB)
tokens = tokenizer.tokenize(u'This is a test') tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
self.assertListEqual( self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual( self.assertListEqual(
ids, [8, 21, 84, 55, 24, 19, 7, 0, tokens,
602, 347, 347, 347, 3, 12, 66, [
46, 72, 80, 6, 0, 4]) SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"é",
".",
],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4])
back_tokens = tokenizer.convert_ids_to_tokens(ids) back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', self.assertListEqual(
u'or', u'n', SPIECE_UNDERLINE + u'in', back_tokens,
SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',', [
SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', SPIECE_UNDERLINE + "was",
u'<unk>', u'.']) SPIECE_UNDERLINE + "b",
"or",
"n",
if __name__ == '__main__': SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"<unk>",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"<unk>",
".",
],
)
if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -26,19 +26,23 @@ if sys.version_info[0] == 2: ...@@ -26,19 +26,23 @@ if sys.version_info[0] == 2:
class TemporaryDirectory(object): class TemporaryDirectory(object):
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement.""" """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
def __enter__(self): def __enter__(self):
self.name = tempfile.mkdtemp() self.name = tempfile.mkdtemp()
return self.name return self.name
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self.name) shutil.rmtree(self.name)
else: else:
import pickle import pickle
TemporaryDirectory = tempfile.TemporaryDirectory TemporaryDirectory = tempfile.TemporaryDirectory
unicode = str unicode = str
class CommonTestCases: class CommonTestCases:
class CommonTokenizerTester(unittest.TestCase): class CommonTokenizerTester(unittest.TestCase):
tokenizer_class = None tokenizer_class = None
...@@ -57,17 +61,23 @@ class CommonTestCases: ...@@ -57,17 +61,23 @@ class CommonTestCases:
def test_tokenizers_common_properties(self): def test_tokenizers_common_properties(self):
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
attributes_list = ["bos_token", "eos_token", "unk_token", "sep_token", attributes_list = [
"pad_token", "cls_token", "mask_token"] "bos_token",
"eos_token",
"unk_token",
"sep_token",
"pad_token",
"cls_token",
"mask_token",
]
for attr in attributes_list: for attr in attributes_list:
self.assertTrue(hasattr(tokenizer, attr)) self.assertTrue(hasattr(tokenizer, attr))
self.assertTrue(hasattr(tokenizer, attr + "_id")) self.assertTrue(hasattr(tokenizer, attr + "_id"))
self.assertTrue(hasattr(tokenizer, "additional_special_tokens")) self.assertTrue(hasattr(tokenizer, "additional_special_tokens"))
self.assertTrue(hasattr(tokenizer, 'additional_special_tokens_ids')) self.assertTrue(hasattr(tokenizer, "additional_special_tokens_ids"))
attributes_list = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder", attributes_list = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder", "added_tokens_decoder"]
"added_tokens_decoder"]
for attr in attributes_list: for attr in attributes_list:
self.assertTrue(hasattr(tokenizer, attr)) self.assertTrue(hasattr(tokenizer, attr))
...@@ -79,13 +89,13 @@ class CommonTestCases: ...@@ -79,13 +89,13 @@ class CommonTestCases:
# Now let's start the test # Now let's start the test
tokenizer = self.get_tokenizer(max_len=42) tokenizer = self.get_tokenizer(max_len=42)
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running", add_special_tokens=False) before_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
with TemporaryDirectory() as tmpdirname: with TemporaryDirectory() as tmpdirname:
tokenizer.save_pretrained(tmpdirname) tokenizer.save_pretrained(tmpdirname)
tokenizer = self.tokenizer_class.from_pretrained(tmpdirname) tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running", add_special_tokens=False) after_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
self.assertListEqual(before_tokens, after_tokens) self.assertListEqual(before_tokens, after_tokens)
self.assertEqual(tokenizer.max_len, 42) self.assertEqual(tokenizer.max_len, 42)
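The save/reload round trip asserted above is the pattern every tokenizer in the library is expected to satisfy. A standalone sketch follows, substituting BertTokenizer and the "bert-base-uncased" checkpoint (both assumptions) for the abstract self.get_tokenizer().

import tempfile

from transformers import BertTokenizer

sentence = "He is very happy, UNwant\u00E9d,running"
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
before_tokens = tokenizer.encode(sentence, add_special_tokens=False)

with tempfile.TemporaryDirectory() as tmpdirname:
    tokenizer.save_pretrained(tmpdirname)
    reloaded = BertTokenizer.from_pretrained(tmpdirname)

after_tokens = reloaded.encode(sentence, add_special_tokens=False)
assert before_tokens == after_tokens  # encoding must survive the round trip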
...@@ -96,12 +106,12 @@ class CommonTestCases: ...@@ -96,12 +106,12 @@ class CommonTestCases:
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
self.assertIsNotNone(tokenizer) self.assertIsNotNone(tokenizer)
text = u"Munich and Berlin are nice cities" text = "Munich and Berlin are nice cities"
subwords = tokenizer.tokenize(text) subwords = tokenizer.tokenize(text)
with TemporaryDirectory() as tmpdirname: with TemporaryDirectory() as tmpdirname:
filename = os.path.join(tmpdirname, u"tokenizer.bin") filename = os.path.join(tmpdirname, "tokenizer.bin")
with open(filename, "wb") as handle: with open(filename, "wb") as handle:
pickle.dump(tokenizer, handle) pickle.dump(tokenizer, handle)
...@@ -122,7 +132,7 @@ class CommonTestCases: ...@@ -122,7 +132,7 @@ class CommonTestCases:
toks0 = tokenizer.tokenize(text) # toks before adding new_toks toks0 = tokenizer.tokenize(text) # toks before adding new_toks
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", 'AAAAA BBBBBB', 'CCCCCCCCCDDDDDDDD'] new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", "AAAAA BBBBBB", "CCCCCCCCCDDDDDDDD"]
added = tokenizer.add_tokens(new_toks) added = tokenizer.add_tokens(new_toks)
self.assertEqual(added, 2) self.assertEqual(added, 2)
...@@ -178,8 +188,7 @@ class CommonTestCases: ...@@ -178,8 +188,7 @@ class CommonTestCases:
self.assertGreater(tokens[0], tokenizer.vocab_size - 1) self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
new_toks_2 = {'eos_token': ">>>>|||<||<<|<<", new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
'pad_token': "<<<<<|||>|>>>>|>"}
added_toks_2 = tokenizer.add_special_tokens(new_toks_2) added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
vocab_size_3 = tokenizer.vocab_size vocab_size_3 = tokenizer.vocab_size
all_size_3 = len(tokenizer) all_size_3 = len(tokenizer)
...@@ -189,8 +198,9 @@ class CommonTestCases: ...@@ -189,8 +198,9 @@ class CommonTestCases:
self.assertEqual(added_toks_2, len(new_toks_2)) self.assertEqual(added_toks_2, len(new_toks_2))
self.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", tokens = tokenizer.encode(
add_special_tokens=False) ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
)
out_string = tokenizer.decode(tokens) out_string = tokenizer.decode(tokens)
self.assertGreaterEqual(len(tokens), 6) self.assertGreaterEqual(len(tokens), 6)
...@@ -242,7 +252,7 @@ class CommonTestCases: ...@@ -242,7 +252,7 @@ class CommonTestCases:
def test_encode_decode_with_spaces(self): def test_encode_decode_with_spaces(self):
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
new_toks = ['[ABC]', '[DEF]', 'GHI IHG'] new_toks = ["[ABC]", "[DEF]", "GHI IHG"]
tokenizer.add_tokens(new_toks) tokenizer.add_tokens(new_toks)
input = "[ABC] [DEF] [ABC] GHI IHG [DEF]" input = "[ABC] [DEF] [ABC] GHI IHG [DEF]"
encoded = tokenizer.encode(input, add_special_tokens=False) encoded = tokenizer.encode(input, add_special_tokens=False)
...@@ -264,7 +274,7 @@ class CommonTestCases: ...@@ -264,7 +274,7 @@ class CommonTestCases:
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
if tokenizer.build_inputs_with_special_tokens.__qualname__.split('.')[0] != "PreTrainedTokenizer": if tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer":
seq_0 = "Test this method." seq_0 = "Test this method."
seq_1 = "With these inputs." seq_1 = "With these inputs."
information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True) information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True)
...@@ -293,17 +303,19 @@ class CommonTestCases: ...@@ -293,17 +303,19 @@ class CommonTestCases:
sequence = tokenizer.encode(seq_0, add_special_tokens=False) sequence = tokenizer.encode(seq_0, add_special_tokens=False)
num_added_tokens = tokenizer.num_added_tokens() num_added_tokens = tokenizer.num_added_tokens()
total_length = len(sequence) + num_added_tokens total_length = len(sequence) + num_added_tokens
information = tokenizer.encode_plus(seq_0, information = tokenizer.encode_plus(
max_length=total_length - 2, seq_0,
add_special_tokens=True, max_length=total_length - 2,
stride=stride, add_special_tokens=True,
return_overflowing_tokens=True) stride=stride,
return_overflowing_tokens=True,
)
truncated_sequence = information["input_ids"] truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"] overflowing_tokens = information["overflowing_tokens"]
self.assertEqual(len(overflowing_tokens), 2 + stride) self.assertEqual(len(overflowing_tokens), 2 + stride)
self.assertEqual(overflowing_tokens, sequence[-(2 + stride):]) self.assertEqual(overflowing_tokens, sequence[-(2 + stride) :])
self.assertEqual(len(truncated_sequence), total_length - 2) self.assertEqual(len(truncated_sequence), total_length - 2)
self.assertEqual(truncated_sequence, tokenizer.build_inputs_with_special_tokens(sequence[:-2])) self.assertEqual(truncated_sequence, tokenizer.build_inputs_with_special_tokens(sequence[:-2]))
@@ -320,24 +332,35 @@ class CommonTestCases:
            sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
            truncated_second_sequence = tokenizer.build_inputs_with_special_tokens(
                tokenizer.encode(seq_0, add_special_tokens=False),
                tokenizer.encode(seq_1, add_special_tokens=False)[:-2],
            )
            information = tokenizer.encode_plus(
                seq_0,
                seq_1,
                max_length=len(sequence) - 2,
                add_special_tokens=True,
                stride=stride,
                truncation_strategy="only_second",
                return_overflowing_tokens=True,
            )
            information_first_truncated = tokenizer.encode_plus(
                seq_0,
                seq_1,
                max_length=len(sequence) - 2,
                add_special_tokens=True,
                stride=stride,
                truncation_strategy="only_first",
                return_overflowing_tokens=True,
            )
            truncated_sequence = information["input_ids"]
            overflowing_tokens = information["overflowing_tokens"]
            overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]
            self.assertEqual(len(overflowing_tokens), 2 + stride)
            self.assertEqual(overflowing_tokens, sequence_1_no_special_tokens[-(2 + stride) :])
            self.assertEqual(overflowing_tokens_first_truncated, sequence_0_no_special_tokens[-(2 + stride) :])
            self.assertEqual(len(truncated_sequence), len(sequence) - 2)
            self.assertEqual(truncated_sequence, truncated_second_sequence)
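
# Illustrative sketch, not part of the commit diff: the pair-truncation hunk above pins
# down truncation_strategy. With "only_second" the overflow comes from the second
# segment, with "only_first" from the first. The checkpoint name is an assumption.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
seq_a, seq_b = "Test this method thoroughly.", "With these other inputs here."
full = tokenizer.encode(seq_a, seq_b, add_special_tokens=True)
ids_b = tokenizer.encode(seq_b, add_special_tokens=False)
info = tokenizer.encode_plus(
    seq_a,
    seq_b,
    max_length=len(full) - 2,
    add_special_tokens=True,
    stride=1,
    truncation_strategy="only_second",
    return_overflowing_tokens=True,
)
# Two truncated tokens plus one stride token, all taken from the end of the second segment.
assert info["overflowing_tokens"] == ids_b[-(2 + 1) :]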
@@ -361,37 +384,47 @@ class CommonTestCases:
            # Testing single inputs
            encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
            encoded_sequence_dict = tokenizer.encode_plus(
                sequence_0, add_special_tokens=True, return_special_tokens_mask=True
            )
            encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
            special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
            self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))

            filtered_sequence = [
                (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
            ]
            filtered_sequence = [x for x in filtered_sequence if x is not None]
            self.assertEqual(encoded_sequence, filtered_sequence)

            # Testing inputs pairs
            encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(
                sequence_1, add_special_tokens=False
            )
            encoded_sequence_dict = tokenizer.encode_plus(
                sequence_0, sequence_1, add_special_tokens=True, return_special_tokens_mask=True
            )
            encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
            special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
            self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))

            filtered_sequence = [
                (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
            ]
            filtered_sequence = [x for x in filtered_sequence if x is not None]
            self.assertEqual(encoded_sequence, filtered_sequence)

            # Testing with already existing special tokens
            if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.sep_token_id == tokenizer.unk_token_id:
                tokenizer.add_special_tokens({"cls_token": "</s>", "sep_token": "<s>"})
                encoded_sequence_dict = tokenizer.encode_plus(
                    sequence_0, add_special_tokens=True, return_special_tokens_mask=True
                )
                encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
                special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
                special_tokens_mask = tokenizer.get_special_tokens_mask(
                    encoded_sequence_w_special, already_has_special_tokens=True
                )
                self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
                self.assertEqual(special_tokens_mask_orig, special_tokens_mask)
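
# Illustrative sketch, not part of the commit diff: the special-tokens-mask hunk above
# filters special positions back out of an encoding. Standalone version; the checkpoint
# name is an assumption.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
enc = tokenizer.encode_plus("A short sentence.", add_special_tokens=True, return_special_tokens_mask=True)
# Dropping every position flagged by the mask recovers the plain encoding.
plain = [tok for tok, special in zip(enc["input_ids"], enc["special_tokens_mask"]) if not special]
assert plain == tokenizer.encode("A short sentence.", add_special_tokens=False)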
@@ -406,7 +439,9 @@ class CommonTestCases:
            tokenizer.padding_side = "right"
            encoded_sequence = tokenizer.encode(sequence)
            sequence_length = len(encoded_sequence)
            padded_sequence = tokenizer.encode(
                sequence, max_length=sequence_length + padding_size, pad_to_max_length=True
            )
            padded_sequence_length = len(padded_sequence)
            assert sequence_length + padding_size == padded_sequence_length
            assert encoded_sequence + [padding_idx] * padding_size == padded_sequence

@@ -415,7 +450,9 @@ class CommonTestCases:
            tokenizer.padding_side = "left"
            encoded_sequence = tokenizer.encode(sequence)
            sequence_length = len(encoded_sequence)
            padded_sequence = tokenizer.encode(
                sequence, max_length=sequence_length + padding_size, pad_to_max_length=True
            )
            padded_sequence_length = len(padded_sequence)
            assert sequence_length + padding_size == padded_sequence_length
            assert [padding_idx] * padding_size + encoded_sequence == padded_sequence

@@ -446,38 +483,48 @@ class CommonTestCases:
            token_type_padding_idx = tokenizer.pad_token_type_id

            encoded_sequence = tokenizer.encode_plus(sequence, return_special_tokens_mask=True)
            input_ids = encoded_sequence["input_ids"]
            token_type_ids = encoded_sequence["token_type_ids"]
            attention_mask = encoded_sequence["attention_mask"]
            special_tokens_mask = encoded_sequence["special_tokens_mask"]
            sequence_length = len(input_ids)

            # Test right padding
            tokenizer.padding_side = "right"
            padded_sequence = tokenizer.encode_plus(
                sequence,
                max_length=sequence_length + padding_size,
                pad_to_max_length=True,
                return_special_tokens_mask=True,
            )
            padded_input_ids = padded_sequence["input_ids"]
            padded_token_type_ids = padded_sequence["token_type_ids"]
            padded_attention_mask = padded_sequence["attention_mask"]
            padded_special_tokens_mask = padded_sequence["special_tokens_mask"]
            padded_sequence_length = len(padded_input_ids)
            assert sequence_length + padding_size == padded_sequence_length
            assert input_ids + [padding_idx] * padding_size == padded_input_ids
            assert token_type_ids + [token_type_padding_idx] * padding_size == padded_token_type_ids
            assert attention_mask + [0] * padding_size == padded_attention_mask
            assert special_tokens_mask + [1] * padding_size == padded_special_tokens_mask

            # Test left padding
            tokenizer.padding_side = "left"
            padded_sequence = tokenizer.encode_plus(
                sequence,
                max_length=sequence_length + padding_size,
                pad_to_max_length=True,
                return_special_tokens_mask=True,
            )
            padded_input_ids = padded_sequence["input_ids"]
            padded_token_type_ids = padded_sequence["token_type_ids"]
            padded_attention_mask = padded_sequence["attention_mask"]
            padded_special_tokens_mask = padded_sequence["special_tokens_mask"]
            padded_sequence_length = len(padded_input_ids)
            assert sequence_length + padding_size == padded_sequence_length
            assert [padding_idx] * padding_size + input_ids == padded_input_ids
            assert [token_type_padding_idx] * padding_size + token_type_ids == padded_token_type_ids
            assert [0] * padding_size + attention_mask == padded_attention_mask
            assert [1] * padding_size + special_tokens_mask == padded_special_tokens_mask
\ No newline at end of file
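
# Illustrative sketch, not part of the commit diff: the padding hunks above cover
# pad_to_max_length together with padding_side. The checkpoint name is an assumption.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer.encode("pad me, please")
tokenizer.padding_side = "left"
padded = tokenizer.encode("pad me, please", max_length=len(ids) + 3, pad_to_max_length=True)
# With left padding the pad ids are prepended; with padding_side = "right" they
# would be appended instead, exactly as the asserts above check.
assert padded == [tokenizer.pad_token_id] * 3 + ids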
@@ -37,45 +37,53 @@ class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
        super(TransfoXLTokenizationTest, self).setUp()

        vocab_tokens = [
            "<unk>",
            "[CLS]",
            "[SEP]",
            "want",
            "unwanted",
            "wa",
            "un",
            "running",
            ",",
            "low",
            "l",
        ]
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_tokenizer(self, **kwargs):
        kwargs["lower_case"] = True
        return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self):
        input_text = "<unk> UNwanted , running"
        output_text = "<unk> unwanted, running"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True)

        tokens = tokenizer.tokenize("<unk> UNwanted , running")
        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

    def test_full_tokenizer_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=True)

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["hello", "!", "how", "are", "you", "?"]
        )

    def test_full_tokenizer_no_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=False)

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
        )


if __name__ == "__main__":
    unittest.main()
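
# Illustrative sketch, not part of the commit diff: TransfoXLTokenizer can be
# instantiated without a vocab file purely to check its whitespace splitting and
# lower-casing, which is what the two tests above assert.
from transformers import TransfoXLTokenizer

tok = TransfoXLTokenizer(lower_case=True)
assert tok.tokenize(" \tHeLLo ! how \n Are yoU ? ") == ["hello", "!", "how", "are", "you", "?"]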
@@ -24,8 +24,8 @@ from transformers.tokenization_gpt2 import GPT2Tokenizer
from .utils import slow


class TokenizerUtilsTest(unittest.TestCase):
    def check_tokenizer_from_pretrained(self, tokenizer_class):
        s3_models = list(tokenizer_class.max_model_input_sizes.keys())
        for model_name in s3_models[:1]:

@@ -46,5 +46,6 @@ class TokenizerUtilsTest(unittest.TestCase):
    def test_pretrained_tokenizers(self):
        self.check_tokenizer_from_pretrained(GPT2Tokenizer)


if __name__ == "__main__":
    unittest.main()
@@ -23,6 +23,7 @@ from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases
from .utils import slow


class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):

    tokenizer_class = XLMTokenizer

@@ -31,15 +32,34 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
        super(XLMTokenizationTest, self).setUp()

        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "w</w>",
            "r</w>",
            "t</w>",
            "lo",
            "low",
            "er</w>",
            "low</w>",
            "lowest</w>",
            "newer</w>",
            "wider</w>",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w") as fp:
            fp.write(json.dumps(vocab_tokens))
        with open(self.merges_file, "w") as fp:

@@ -49,8 +69,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
        return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self):
        input_text = "lower newer"
        output_text = "lower newer"
        return input_text, output_text

    def test_full_tokenizer(self):

@@ -64,8 +84,7 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
        input_tokens = tokens + ["<unk>"]
        input_bpe_tokens = [14, 15, 20]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    @slow
    def test_sequence_builders(self):

@@ -80,5 +99,6 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
        assert encoded_sentence == [1] + text + [1]
        assert encoded_pair == [1] + text + [1] + text_2 + [1]


if __name__ == "__main__":
    unittest.main()
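
# Illustrative sketch, not part of the commit diff: the @slow sequence-builder hunk
# above wraps single sequences and pairs in XLM's boundary tokens. The checkpoint
# name below is an assumption.
from transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
text = tokenizer.encode("sequence builders", add_special_tokens=False)
encoded = tokenizer.build_inputs_with_special_tokens(text)
# One boundary token on each side (id 1 in the asserts above).
assert encoded[0] == encoded[-1] and encoded[1:-1] == text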