Commit f31154cb authored by thomwolf

Merge branch 'xlnet'

parents 78462aad 1b35d05d
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import pytest
from pytorch_transformers import (GPT2Config, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel)
from .modeling_common_test import CommonTestCases, ConfigTester
class GPT2ModelTest(unittest.TestCase):
def test_config(self):
config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37)
config_tester.run_common_tests()
def test_model(self):
model_tester = CommonTestCases.GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
lm_head_model_class=GPT2LMHeadModel,
double_head_model_class=GPT2DoubleHeadsModel)
model_tester.run_common_tests(test_presents=True)
@pytest.mark.slow
def test_pretrained(self):
model_tester = CommonTestCases.GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
lm_head_model_class=GPT2LMHeadModel,
double_head_model_class=GPT2DoubleHeadsModel)
model_tester.run_slow_tests()
if __name__ == "__main__":
unittest.main()
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import pytest
from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
from .modeling_common_test import CommonTestCases, ConfigTester
class OpenAIModelTest(unittest.TestCase):
def test_config(self):
config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37)
config_tester.run_common_tests()
def test_model(self):
model_tester = CommonTestCases.GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
lm_head_model_class=OpenAIGPTLMHeadModel,
double_head_model_class=OpenAIGPTDoubleHeadsModel)
model_tester.run_common_tests(test_presents=False)
@pytest.mark.slow
def test_pretrained(self):
model_tester = CommonTestCases.GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
lm_head_model_class=OpenAIGPTLMHeadModel,
double_head_model_class=OpenAIGPTDoubleHeadsModel)
model_tester.run_slow_tests()
if __name__ == "__main__":
unittest.main()
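Both files above drive the same CommonTestCases.GPTModelTester against the base, LM-head and double-heads classes of each model family. For orientation, a minimal sketch of the double-heads forward pass those testers exercise; the config values are toy sizes, and the return layout (logits first when no labels are passed) assumes the pytorch-transformers 1.0 signatures, so treat this as illustrative rather than canonical.

import torch
from pytorch_transformers import GPT2Config, GPT2DoubleHeadsModel

# Tiny toy configuration, roughly the sizes the testers use.
config = GPT2Config(vocab_size_or_config_json_file=99, n_positions=64,
                    n_ctx=64, n_embd=32, n_layer=2, n_head=4)
model = GPT2DoubleHeadsModel(config)
model.eval()

batch_size, num_choices, seq_length = 2, 3, 7
input_ids = torch.randint(0, 99, (batch_size, num_choices, seq_length))
# Index of the token whose hidden state feeds the multiple-choice head.
mc_token_ids = torch.full((batch_size, num_choices), seq_length - 1, dtype=torch.long)

with torch.no_grad():
    lm_logits, mc_logits, presents = model(input_ids, mc_token_ids=mc_token_ids)
print(lm_logits.shape)  # (batch, num_choices, seq_length, vocab_size)
print(mc_logits.shape)  # (batch, num_choices)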
@@ -25,10 +25,18 @@ import pytest
 import torch

-from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
-from pytorch_pretrained_bert.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
+from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
+from .modeling_common_test import ConfigTester, CommonTestCases, ids_tensor

-class TransfoXLModelTest(unittest.TestCase):
+class TransfoXLModelTest(CommonTestCases.CommonModelTester):
+
+    all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel)
+    test_pruning = False
+    test_torchscript = False
+    test_resize_embeddings = False

     class TransfoXLModelTester(object):

         def __init__(self,
@@ -41,54 +49,56 @@ class TransfoXLModelTest(unittest.TestCase):
                      use_labels=True,
                      vocab_size=99,
                      cutoffs=[10, 50, 80],
-                     d_model=32,
+                     hidden_size=32,
                      d_embed=32,
-                     n_head=4,
+                     num_attention_heads=4,
                      d_head=8,
                      d_inner=128,
                      div_val=2,
-                     n_layer=5,
+                     num_hidden_layers=5,
                      scope=None,
-                     seed=1):
+                     seed=1,
+                     ):
             self.parent = parent
             self.batch_size = batch_size
             self.seq_length = seq_length
             self.mem_len = mem_len
-            self.key_len = seq_length + mem_len
             self.clamp_len = clamp_len
             self.is_training = is_training
             self.use_labels = use_labels
             self.vocab_size = vocab_size
             self.cutoffs = cutoffs
-            self.d_model = d_model
+            self.hidden_size = hidden_size
             self.d_embed = d_embed
-            self.n_head = n_head
+            self.num_attention_heads = num_attention_heads
             self.d_head = d_head
             self.d_inner = d_inner
             self.div_val = div_val
-            self.n_layer = n_layer
+            self.num_hidden_layers = num_hidden_layers
             self.scope = scope
             self.seed = seed

         def prepare_config_and_inputs(self):
-            input_ids_1 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-            input_ids_2 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

             lm_labels = None
             if self.use_labels:
-                lm_labels = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

             config = TransfoXLConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
                 mem_len=self.mem_len,
                 clamp_len=self.clamp_len,
                 cutoffs=self.cutoffs,
-                d_model=self.d_model,
+                d_model=self.hidden_size,
                 d_embed=self.d_embed,
-                n_head=self.n_head,
+                n_head=self.num_attention_heads,
                 d_head=self.d_head,
                 d_inner=self.d_inner,
                 div_val=self.div_val,
-                n_layer=self.n_layer)
+                n_layer=self.num_hidden_layers)

             return (config, input_ids_1, input_ids_2, lm_labels)

@@ -113,37 +123,34 @@ class TransfoXLModelTest(unittest.TestCase):
         def check_transfo_xl_model_output(self, result):
             self.parent.assertListEqual(
                 list(result["hidden_states_1"].size()),
-                [self.batch_size, self.seq_length, self.d_model])
+                [self.batch_size, self.seq_length, self.hidden_size])
             self.parent.assertListEqual(
                 list(result["hidden_states_2"].size()),
-                [self.batch_size, self.seq_length, self.d_model])
+                [self.batch_size, self.seq_length, self.hidden_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_1"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_2"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

         def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
             model = TransfoXLLMHeadModel(config)
             model.eval()

-            loss_1, mems_1a = model(input_ids_1, target=lm_labels)
-            lm_logits_1, mems_1b = model(input_ids_1)
-            loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a)
-            lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)
+            lm_logits_1, mems_1 = model(input_ids_1)
+            loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)
+            lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
+            loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)

             outputs = {
                 "loss_1": loss_1,
-                "mems_1a": mems_1a,
+                "mems_1": mems_1,
                 "lm_logits_1": lm_logits_1,
-                "mems_1b": mems_1b,
                 "loss_2": loss_2,
-                "mems_2a": mems_2a,
+                "mems_2": mems_2,
                 "lm_logits_2": lm_logits_2,
-                "mems_2b": mems_2b,
             }
             return outputs

@@ -155,14 +162,8 @@ class TransfoXLModelTest(unittest.TestCase):
                 list(result["lm_logits_1"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_1a"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_1b"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))
+                list(list(mem.size()) for mem in result["mems_1"]),
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

             self.parent.assertListEqual(
                 list(result["loss_2"].size()),
@@ -171,66 +172,42 @@ class TransfoXLModelTest(unittest.TestCase):
                 list(result["lm_logits_2"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_2a"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_2b"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))
+                list(list(mem.size()) for mem in result["mems_2"]),
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

-    def test_default(self):
-        self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
-
-    def test_config_to_json_string(self):
-        config = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
-        obj = json.loads(config.to_json_string())
-        self.assertEqual(obj["n_token"], 96)
-        self.assertEqual(obj["d_embed"], 37)
-
-    def test_config_to_json_file(self):
-        config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
-        json_file_path = "/tmp/config.json"
-        config_first.to_json_file(json_file_path)
-        config_second = TransfoXLConfig.from_json_file(json_file_path)
-        os.remove(json_file_path)
-        self.assertEqual(config_second.to_dict(), config_first.to_dict())
-
-    @pytest.mark.slow
-    def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(model)
-
-    def run_tester(self, tester):
-        config_and_inputs = tester.prepare_config_and_inputs()
-
-        tester.set_seed()
-        output_result = tester.create_transfo_xl_model(*config_and_inputs)
-        tester.check_transfo_xl_model_output(output_result)
-
-        tester.set_seed()
-        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
-        tester.check_transfo_xl_lm_head_output(output_result)
-
-    @classmethod
-    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
-        """Creates a random int32 tensor of the shape within the vocab size."""
-        if rng is None:
-            rng = random.Random()
-
-        total_dims = 1
-        for dim in shape:
-            total_dims *= dim
-
-        values = []
-        for _ in range(total_dims):
-            values.append(rng.randint(0, vocab_size - 1))
-
-        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
+        def prepare_config_and_inputs_for_common(self):
+            config_and_inputs = self.prepare_config_and_inputs()
+            (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
+            inputs_dict = {'input_ids': input_ids_1}
+            return config, inputs_dict
+
+    def setUp(self):
+        self.model_tester = TransfoXLModelTest.TransfoXLModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_transfo_xl_model(self):
+        self.model_tester.set_seed()
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        output_result = self.model_tester.create_transfo_xl_model(*config_and_inputs)
+        self.model_tester.check_transfo_xl_model_output(output_result)
+
+    def test_transfo_xl_lm_head(self):
+        self.model_tester.set_seed()
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
+        self.model_tester.check_transfo_xl_lm_head_output(output_result)
+
+    @pytest.mark.slow
+    def test_model_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_transformers_test/"
+        for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+            model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(model)

 if __name__ == "__main__":
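For orientation, a minimal sketch of the memory-passing API that the rewritten tester above exercises: each forward returns updated mems, which are fed back in to extend the effective context across segments. Sizes are toy values, and the no-labels return order (logits first) follows the calls shown in the diff.

import torch
from pytorch_transformers import TransfoXLConfig, TransfoXLLMHeadModel

config = TransfoXLConfig(vocab_size_or_config_json_file=99, mem_len=10,
                         clamp_len=-1, cutoffs=[10, 50, 80], d_model=32,
                         d_embed=32, n_head=4, d_head=8, d_inner=128,
                         div_val=2, n_layer=2)
model = TransfoXLLMHeadModel(config)
model.eval()

chunk_1 = torch.randint(0, 99, (2, 7))
chunk_2 = torch.randint(0, 99, (2, 7))
with torch.no_grad():
    lm_logits_1, mems = model(chunk_1)             # no labels: logits come first
    lm_logits_2, mems = model(chunk_2, mems=mems)  # reuse cached hidden states
print([list(m.size()) for m in mems])  # n_layer tensors of [mem_len, batch, d_model]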
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pytest
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
class XLMModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (XLMModel, XLMWithLMHeadModel,
XLMForQuestionAnswering, XLMForSequenceClassification)
# , XLMForSequenceClassification, XLMForTokenClassification),
class XLMModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_lengths=True,
use_token_type_ids=True,
use_labels=True,
gelu_activation=True,
sinusoidal_embeddings=False,
causal=False,
asm=False,
n_langs=2,
vocab_size=99,
n_special=0,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
summary_type="last",
use_proj=True,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_lengths = use_input_lengths
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.gelu_activation = gelu_activation
self.sinusoidal_embeddings = sinusoidal_embeddings
self.asm = asm
self.n_langs = n_langs
self.vocab_size = vocab_size
self.n_special = n_special
self.summary_type = summary_type
self.causal = causal
self.use_proj = use_proj
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.n_langs = n_langs
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.summary_type = summary_type
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
input_lengths = None
if self.use_input_lengths:
input_lengths = ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2 # small variation of seq_length
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.n_langs)
sequence_labels = None
token_labels = None
is_impossible_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
config = XLMConfig(
vocab_size_or_config_json_file=self.vocab_size,
n_special=self.n_special,
emb_dim=self.hidden_size,
n_layers=self.num_hidden_layers,
n_heads=self.num_attention_heads,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
gelu_activation=self.gelu_activation,
sinusoidal_embeddings=self.sinusoidal_embeddings,
asm=self.asm,
causal=self.causal,
n_langs=self.n_langs,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
summary_type=self.summary_type,
use_proj=self.use_proj)
return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMModel(config=config)
model.eval()
outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
outputs = model(input_ids, langs=token_type_ids)
outputs = model(input_ids)
sequence_output = outputs[0]
result = {
"sequence_output": sequence_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
def create_and_check_xlm_lm_head(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMWithLMHeadModel(config)
model.eval()
loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMForQuestionAnswering(config)
model.eval()
outputs = model(input_ids)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
outputs = model(input_ids, start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels,
p_mask=input_mask)
outputs = model(input_ids, start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels)
(total_loss,) = outputs
outputs = model(input_ids, start_positions=sequence_labels,
end_positions=sequence_labels)
(total_loss,) = outputs
result = {
"loss": total_loss,
"start_top_log_probs": start_top_log_probs,
"start_top_index": start_top_index,
"end_top_log_probs": end_top_log_probs,
"end_top_index": end_top_index,
"cls_logits": cls_logits,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["start_top_log_probs"].size()),
[self.batch_size, model.config.start_n_top])
self.parent.assertListEqual(
list(result["start_top_index"].size()),
[self.batch_size, model.config.start_n_top])
self.parent.assertListEqual(
list(result["end_top_log_probs"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
self.parent.assertListEqual(
list(result["end_top_index"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
self.parent.assertListEqual(
list(result["cls_logits"].size()),
[self.batch_size])
def create_and_check_xlm_sequence_classif(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMForSequenceClassification(config)
model.eval()
(logits,) = model(input_ids)
loss, logits = model(input_ids, labels=sequence_labels)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.type_sequence_label_size])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_lengths,
sequence_labels, token_labels, is_impossible_labels, input_mask) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
return config, inputs_dict
def setUp(self):
self.model_tester = XLMModelTest.XLMModelTester(self)
self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_xlm_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlm_model(*config_and_inputs)
# config_and_inputs = tester.prepare_config_and_inputs()
# tester.create_and_check_xlm_for_masked_lm(*config_and_inputs)
# config_and_inputs = tester.prepare_config_and_inputs()
# tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
# config_and_inputs = tester.prepare_config_and_inputs()
# tester.create_and_check_xlm_for_question_answering(*config_and_inputs)
# config_and_inputs = tester.prepare_config_and_inputs()
# tester.create_and_check_xlm_for_sequence_classification(*config_and_inputs)
# config_and_inputs = tester.prepare_config_and_inputs()
# tester.create_and_check_xlm_for_token_classification(*config_and_inputs)
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
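The XLM tester above keeps the BERT-style attribute names used by the common tests (hidden_size, num_hidden_layers, and so on) and translates them into XLM's native config arguments when it builds XLMConfig. A condensed, hedged view of that mapping with the tester's own toy values:

from pytorch_transformers import XLMConfig

config = XLMConfig(
    vocab_size_or_config_json_file=99,
    emb_dim=32,             # tester attribute: hidden_size
    n_layers=5,             # tester attribute: num_hidden_layers
    n_heads=4,              # tester attribute: num_attention_heads
    dropout=0.1,            # tester attribute: hidden_dropout_prob
    attention_dropout=0.1,  # tester attribute: attention_probs_dropout_prob
    n_langs=2)              # language ids double as token_type_ids in these tests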
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import json
import random
import shutil
import pytest
import torch
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import ConfigTester, CommonTestCases, ids_tensor
class XLNetModelTest(CommonTestCases.CommonModelTester):
all_model_classes=(XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering)
test_pruning = False
class XLNetModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
mem_len=10,
clamp_len=-1,
reuse_len=15,
is_training=True,
use_labels=True,
vocab_size=99,
cutoffs=[10, 50, 80],
hidden_size=32,
num_attention_heads=4,
d_inner=128,
num_hidden_layers=5,
max_position_embeddings=10,
type_sequence_label_size=2,
untie_r=True,
bi_data=False,
same_length=False,
initializer_range=0.05,
seed=1,
type_vocab_size=2,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.mem_len = mem_len
# self.key_len = seq_length + mem_len
self.clamp_len = clamp_len
self.reuse_len = reuse_len
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.cutoffs = cutoffs
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.d_inner = d_inner
self.num_hidden_layers = num_hidden_layers
self.max_position_embeddings = max_position_embeddings
self.bi_data = bi_data
self.untie_r = untie_r
self.same_length = same_length
self.initializer_range = initializer_range
self.seed = seed
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
def prepare_config_and_inputs(self):
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
target_mapping[:, 0, -1] = 1.0 # predict last token
sequence_labels = None
lm_labels = None
is_impossible_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
config = XLNetConfig(
vocab_size_or_config_json_file=self.vocab_size,
d_model=self.hidden_size,
n_head=self.num_attention_heads,
d_inner=self.d_inner,
n_layer=self.num_hidden_layers,
untie_r=self.untie_r,
max_position_embeddings=self.max_position_embeddings,
mem_len=self.mem_len,
clamp_len=self.clamp_len,
same_length=self.same_length,
reuse_len=self.reuse_len,
bi_data=self.bi_data,
initializer_range=self.initializer_range,
num_labels=self.type_sequence_label_size)
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
def set_seed(self):
random.seed(self.seed)
torch.manual_seed(self.seed)
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetModel(config)
model.eval()
_, _ = model(input_ids_1, input_mask=input_mask)
_, _ = model(input_ids_1, attention_mask=input_mask)
_, _ = model(input_ids_1, token_type_ids=segment_ids)
outputs, mems_1 = model(input_ids_1)
result = {
"mems_1": mems_1,
"outputs": outputs,
}
self.parent.assertListEqual(
list(result["outputs"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetLMHeadModel(config)
model.eval()
loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
result = {
"loss_1": loss_1,
"mems_1": mems_1,
"all_logits_1": all_logits_1,
"loss_2": loss_2,
"mems_2": mems_2,
"all_logits_2": all_logits_2,
}
self.parent.assertListEqual(
list(result["loss_1"].size()),
[])
self.parent.assertListEqual(
list(result["all_logits_1"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
self.parent.assertListEqual(
list(result["loss_2"].size()),
[])
self.parent.assertListEqual(
list(result["all_logits_2"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetForQuestionAnswering(config)
model.eval()
outputs = model(input_ids_1)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels,
p_mask=input_mask)
outputs = model(input_ids_1, start_positions=sequence_labels,
end_positions=sequence_labels,
cls_index=sequence_labels,
is_impossible=is_impossible_labels)
total_loss, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels,
end_positions=sequence_labels)
total_loss, mems = outputs
result = {
"loss": total_loss,
"start_top_log_probs": start_top_log_probs,
"start_top_index": start_top_index,
"end_top_log_probs": end_top_log_probs,
"end_top_index": end_top_index,
"cls_logits": cls_logits,
"mems": mems,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["start_top_log_probs"].size()),
[self.batch_size, model.config.start_n_top])
self.parent.assertListEqual(
list(result["start_top_index"].size()),
[self.batch_size, model.config.start_n_top])
self.parent.assertListEqual(
list(result["end_top_log_probs"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
self.parent.assertListEqual(
list(result["end_top_index"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
self.parent.assertListEqual(
list(result["cls_logits"].size()),
[self.batch_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetForSequenceClassification(config)
model.eval()
logits, mems_1 = model(input_ids_1)
loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)
result = {
"loss": loss,
"mems_1": mems_1,
"logits": logits,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.type_sequence_label_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels,
sequence_labels, is_impossible_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids_1}
return config, inputs_dict
def setUp(self):
self.model_tester = XLNetModelTest.XLNetModelTester(self)
self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_xlnet_base_model(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
def test_xlnet_lm_head(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_lm_head(*config_and_inputs)
def test_xlnet_sequence_classif(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
def test_xlnet_qa(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()
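A small standalone sketch of the permutation-LM inputs built in prepare_config_and_inputs above: perm_mask[b, i, j] == 1 means position i may not attend to position j, and target_mapping selects which positions are predicted. Here every token is masked from the last position, which is the single position being predicted; the model call in the comment mirrors the one in create_and_check_xlnet_lm_head.

import torch

batch_size, seq_length = 2, 8
perm_mask = torch.zeros(batch_size, seq_length, seq_length)
perm_mask[:, :, -1] = 1.0        # no position may see the last token

target_mapping = torch.zeros(batch_size, 1, seq_length)
target_mapping[:, 0, -1] = 1.0   # predict only the last token

# used as: logits, mems = model(input_ids, perm_mask=perm_mask,
#                               target_mapping=target_mapping)
# logits then has shape (batch_size, 1, vocab_size): one prediction per target row.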
@@ -20,13 +20,19 @@ import unittest
 import torch

-from pytorch_pretrained_bert import BertAdam
-from pytorch_pretrained_bert import OpenAIAdam
-from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule, WarmupConstantSchedule, \
-    WarmupCosineWithWarmupRestartsSchedule, WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule
+from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
+                                  WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)

 import numpy as np

+def unwrap_schedule(scheduler, num_steps=10):
+    lrs = []
+    for _ in range(num_steps):
+        scheduler.step()
+        lrs.append(scheduler.get_lr())
+    return lrs
+
 class OptimizationTest(unittest.TestCase):

     def assertListAlmostEqual(self, list1, list2, tol):
@@ -34,14 +40,12 @@ class OptimizationTest(unittest.TestCase):
         for a, b in zip(list1, list2):
             self.assertAlmostEqual(a, b, delta=tol)

-    def test_adam(self):
+    def test_adam_w(self):
         w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
         target = torch.tensor([0.4, 0.2, -0.5])
         criterion = torch.nn.MSELoss()
         # No warmup, constant schedule, no gradient clipping
-        optimizer = BertAdam(params=[w], lr=2e-1,
-                             weight_decay=0.0,
-                             max_grad_norm=-1)
+        optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
         for _ in range(100):
             loss = criterion(w, target)
             loss.backward()
@@ -52,39 +56,49 @@ class OptimizationTest(unittest.TestCase):
 class ScheduleInitTest(unittest.TestCase):

-    def test_bert_sched_init(self):
-        m = torch.nn.Linear(50, 50)
-        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
-        # shouldn't fail
-
-    def test_openai_sched_init(self):
-        m = torch.nn.Linear(50, 50)
-        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
-        # shouldn't fail
-
-class WarmupCosineWithRestartsTest(unittest.TestCase):
-    def test_it(self):
-        m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
-        x = np.arange(0, 1000)
-        y = [m.get_lr(xe) for xe in x]
-        y = np.asarray(y)
-        expected_zeros = y[[0, 200, 400, 600, 800]]
-        print(expected_zeros)
-        expected_ones = y[[50, 250, 450, 650, 850]]
-        print(expected_ones)
-        self.assertTrue(np.allclose(expected_ones, 1))
-        self.assertTrue(np.allclose(expected_zeros, 0))
+    m = torch.nn.Linear(50, 50)
+    optimizer = AdamW(m.parameters(), lr=10.)
+    num_steps = 10
+
+    def assertListAlmostEqual(self, list1, list2, tol):
+        self.assertEqual(len(list1), len(list2))
+        for a, b in zip(list1, list2):
+            self.assertAlmostEqual(a, b, delta=tol)
+
+    def test_constant_scheduler(self):
+        scheduler = ConstantLRSchedule(self.optimizer)
+        lrs = unwrap_schedule(scheduler, self.num_steps)
+        expected_learning_rates = [10.] * self.num_steps
+        self.assertEqual(len(lrs[0]), 1)
+        self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
+
+    def test_warmup_constant_scheduler(self):
+        scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
+        lrs = unwrap_schedule(scheduler, self.num_steps)
+        expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
+        self.assertEqual(len(lrs[0]), 1)
+        self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
+
+    def test_warmup_linear_scheduler(self):
+        scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
+        lrs = unwrap_schedule(scheduler, self.num_steps)
+        expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
+        self.assertEqual(len(lrs[0]), 1)
+        self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
+
+    def test_warmup_cosine_scheduler(self):
+        scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
+        lrs = unwrap_schedule(scheduler, self.num_steps)
+        expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
+        self.assertEqual(len(lrs[0]), 1)
+        self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
+
+    def test_warmup_cosine_hard_restart_scheduler(self):
+        scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
+        lrs = unwrap_schedule(scheduler, self.num_steps)
+        expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
+        self.assertEqual(len(lrs[0]), 1)
+        self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)

 if __name__ == "__main__":
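The expected_learning_rates lists above can be reproduced by hand. A standalone sketch of the linear-warmup multiplier, assuming the formula these numbers imply (linear ramp over warmup_steps, then linear decay to zero at t_total; step counting starts at 1 because unwrap_schedule calls scheduler.step() before reading the lr):

def warmup_linear(step, warmup_steps=2, t_total=10):
    # Ramp up linearly, then decay linearly to zero.
    if step < warmup_steps:
        return step / warmup_steps
    return max(0.0, (t_total - step) / (t_total - warmup_steps))

lrs = [10.0 * warmup_linear(step) for step in range(1, 11)]
print(lrs)  # [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]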
@@ -17,54 +17,37 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 from io import open
-import shutil
-import pytest

-from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
-                                                  BertTokenizer,
-                                                  WordpieceTokenizer,
-                                                  _is_control, _is_punctuation,
-                                                  _is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)
+from pytorch_transformers.tokenization_bert import (BasicTokenizer,
+                                                    BertTokenizer,
+                                                    WordpieceTokenizer,
+                                                    _is_control, _is_punctuation,
+                                                    _is_whitespace, VOCAB_FILES_NAMES)
+
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory

 class TokenizationTest(unittest.TestCase):

     def test_full_tokenizer(self):
         vocab_tokens = [
             "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
-            "##ing", ","
+            "##ing", ",", "low", "lowest",
         ]
-        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
-            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
-
-            vocab_file = vocab_writer.name
-
-        tokenizer = BertTokenizer(vocab_file)
-        os.remove(vocab_file)
-
-        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
-        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
-
-        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
-        tokenizer.from_pretrained(vocab_file)
-        os.remove(vocab_file)
-
-        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
-        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
-
-    @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
+        with TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+            input_text = u"UNwant\u00E9d,running"
+            output_text = u"unwanted, running"
+
+            create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)
+
+            tokenizer = BertTokenizer(vocab_file)
+
+            tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+            self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+            self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

     def test_chinese(self):
         tokenizer = BasicTokenizer()
@@ -97,7 +80,7 @@ class TokenizationTest(unittest.TestCase):
         vocab = {}
         for (i, token) in enumerate(vocab_tokens):
             vocab[token] = i
-        tokenizer = WordpieceTokenizer(vocab=vocab)
+        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

         self.assertListEqual(tokenizer.tokenize(""), [])
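The expected WordPiece outputs above come from greedy longest-match-first matching within each whitespace-split word, with "##" marking continuation pieces. A toy re-implementation for illustration only; the library version also caps word length and buffers partial failures before falling back to the unknown token.

def wordpiece(word, vocab, unk_token="[UNK]"):
    tokens, start = [], 0
    while start < len(word):
        end = len(word)
        cur = None
        while start < end:                 # try the longest substring first
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece       # continuation-piece marker
            if piece in vocab:
                cur = piece
                break
            end -= 1
        if cur is None:
            return [unk_token]             # no piece matches: whole word is UNK
        tokens.append(cur)
        start = end
    return tokens

vocab = {"[UNK]", "un", "##want", "##ed", "runn", "##ing"}
print(wordpiece("unwanted", vocab))  # ['un', '##want', '##ed']
print(wordpiece("running", vocab))   # ['runn', '##ing']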
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import shutil
-import pytest

-from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
+
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory

 class GPT2TokenizationTest(unittest.TestCase):
@@ -29,49 +28,35 @@ class GPT2TokenizationTest(unittest.TestCase):
         """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                  "lo", "low", "er",
-                 "low", "lowest", "newer", "wider"]
+                 "low", "lowest", "newer", "wider", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
-        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            fp.write(json.dumps(vocab_tokens))
-            vocab_file = fp.name
-        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
-            fp.write("\n".join(merges))
-            merges_file = fp.name
-
-        tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-        os.remove(vocab_file)
-        os.remove(merges_file)
-
-        text = "lower"
-        bpe_tokens = ["low", "er"]
-        tokens = tokenizer.tokenize(text)
-        self.assertListEqual(tokens, bpe_tokens)
-
-        input_tokens = tokens + ["<unk>"]
-        input_bpe_tokens = [13, 12, 16]
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
-
-        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
-        tokenizer_2 = GPT2Tokenizer.from_pretrained("/tmp/")
-        os.remove(vocab_file)
-        os.remove(merges_file)
-        os.remove(special_tokens_file)
-
-        self.assertListEqual(
-            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
-             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
-            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
-             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
-
-    # @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
+        special_tokens_map = {"unk_token": "<unk>"}
+
+        with TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+            with open(vocab_file, "w") as fp:
+                fp.write(json.dumps(vocab_tokens))
+            with open(merges_file, "w") as fp:
+                fp.write("\n".join(merges))
+
+            input_text = u"lower newer"
+            output_text = u"lower<unk>newer"
+
+            create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
+
+            tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
+            text = "lower"
+            bpe_tokens = ["low", "er"]
+            tokens = tokenizer.tokenize(text)
+            self.assertListEqual(tokens, bpe_tokens)
+
+            input_tokens = tokens + [tokenizer.unk_token]
+            input_bpe_tokens = [13, 12, 17]
+            self.assertListEqual(
+                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

 if __name__ == '__main__':
     unittest.main()
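The bpe_tokens expectation above follows from applying the merge list in rank order until no ranked pair remains. A simplified sketch of that loop (not the library's implementation, which additionally byte-encodes the input and caches results):

def bpe(word, merges):
    ranks = {tuple(m.split()): i for i, m in enumerate(merges)}
    symbols = list(word)
    while len(symbols) > 1:
        pairs = [(symbols[i], symbols[i + 1]) for i in range(len(symbols) - 1)]
        best = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break                          # no applicable merge left
        merged, i = [], 0
        while i < len(symbols):            # merge every occurrence of the best pair
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

print(bpe("lower", ["l o", "lo w", "e r"]))  # ['low', 'er']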
@@ -17,10 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import shutil
-import pytest

-from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
+
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory

 class OpenAIGPTTokenizationTest(unittest.TestCase):
@@ -30,49 +30,34 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                  "w</w>", "r</w>", "t</w>",
                  "lo", "low", "er</w>",
-                 "low</w>", "lowest</w>", "newer</w>", "wider</w>"]
+                 "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
-        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            fp.write(json.dumps(vocab_tokens))
-            vocab_file = fp.name
-        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
-            fp.write("\n".join(merges))
-            merges_file = fp.name
-
-        tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-        os.remove(vocab_file)
-        os.remove(merges_file)
-
-        text = "lower"
-        bpe_tokens = ["low", "er</w>"]
-        tokens = tokenizer.tokenize(text)
-        self.assertListEqual(tokens, bpe_tokens)
-
-        input_tokens = tokens + ["<unk>"]
-        input_bpe_tokens = [14, 15, 20]
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
-
-        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
-        tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/")
-        os.remove(vocab_file)
-        os.remove(merges_file)
-        os.remove(special_tokens_file)
-
-        self.assertListEqual(
-            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
-             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
-            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
-             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
-
-    @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
+        with TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+            with open(vocab_file, "w") as fp:
+                fp.write(json.dumps(vocab_tokens))
+            with open(merges_file, "w") as fp:
+                fp.write("\n".join(merges))
+
+            input_text = u"lower newer"
+            output_text = u"lower newer"
+
+            create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
+
+            tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
+
+            text = "lower"
+            bpe_tokens = ["low", "er</w>"]
+            tokens = tokenizer.tokenize(text)
+            self.assertListEqual(tokens, bpe_tokens)
+
+            input_tokens = tokens + ["<unk>"]
+            input_bpe_tokens = [14, 15, 20]
+            self.assertListEqual(
                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

 if __name__ == '__main__':
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import sys
from io import open
import tempfile
import shutil
if sys.version_info[0] == 2:
import cPickle as pickle
class TemporaryDirectory(object):
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
def __enter__(self):
self.name = tempfile.mkdtemp()
return self.name
def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self.name)
else:
import pickle
TemporaryDirectory = tempfile.TemporaryDirectory
unicode = str
def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
with TemporaryDirectory() as tmpdirname:
tokenizer.save_pretrained(tmpdirname)
tokenizer = tokenizer.from_pretrained(tmpdirname)
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
tester.assertListEqual(before_tokens, after_tokens)
def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
tester.assertIsNotNone(tokenizer)
text = u"Munich and Berlin are nice cities"
subwords = tokenizer.tokenize(text)
with TemporaryDirectory() as tmpdirname:
filename = os.path.join(tmpdirname, u"tokenizer.bin")
pickle.dump(tokenizer, open(filename, "wb"))
tokenizer_new = pickle.load(open(filename, "rb"))
subwords_loaded = tokenizer_new.tokenize(text)
tester.assertListEqual(subwords, subwords_loaded)
def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
vocab_size = tokenizer.vocab_size
all_size = len(tokenizer)
tester.assertNotEqual(vocab_size, 0)
tester.assertEqual(vocab_size, all_size)
new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
added_toks = tokenizer.add_tokens(new_toks)
vocab_size_2 = tokenizer.vocab_size
all_size_2 = len(tokenizer)
tester.assertNotEqual(vocab_size_2, 0)
tester.assertEqual(vocab_size, vocab_size_2)
tester.assertEqual(added_toks, len(new_toks))
tester.assertEqual(all_size_2, all_size + len(new_toks))
tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
tester.assertGreaterEqual(len(tokens), 4)
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
'pad_token': "<<<<<|||>|>>>>|>"}
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
vocab_size_3 = tokenizer.vocab_size
all_size_3 = len(tokenizer)
tester.assertNotEqual(vocab_size_3, 0)
tester.assertEqual(vocab_size, vocab_size_3)
tester.assertEqual(added_toks_2, len(new_toks_2))
tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
tester.assertGreaterEqual(len(tokens), 6)
tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
tester.assertGreater(tokens[0], tokens[1])
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
tester.assertGreater(tokens[-2], tokens[-3])
tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
tokens = tokenizer.tokenize(input_text)
ids = tokenizer.convert_tokens_to_ids(tokens)
ids_2 = tokenizer.encode(input_text)
tester.assertListEqual(ids, ids_2)
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
text_2 = tokenizer.decode(ids)
tester.assertEqual(text_2, output_text)
tester.assertNotEqual(len(tokens_2), 0)
tester.assertIsInstance(text_2, (str, unicode))
def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
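How a concrete test file plugs into these helpers, following the pattern of the updated BERT/GPT-2 tests elsewhere in this commit. MyTokenizer is a hypothetical stand-in for the tokenizer class under test; the vocab-writing step depends on that tokenizer's file format.

import unittest

class MyTokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
        with TemporaryDirectory() as tmpdirname:
            # ... write the vocab/merges files MyTokenizer expects into tmpdirname ...
            input_text = u"lower newer"
            output_text = u"lower newer"
            create_and_check_tokenizer_commons(
                self, input_text, output_text,
                MyTokenizer,   # hypothetical tokenizer class under test
                tmpdirname)    # forwarded to MyTokenizer.from_pretrained(...)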
@@ -17,42 +17,35 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 from io import open
-import shutil
-import pytest

-from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
+
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory

 class TransfoXLTokenizationTest(unittest.TestCase):

     def test_full_tokenizer(self):
         vocab_tokens = [
-            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ","
+            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
+            "running", ",", "low", "l",
         ]
-        with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
-            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
-            vocab_file = vocab_writer.name
-
-        tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
-        tokenizer.build_vocab()
-        os.remove(vocab_file)
-
-        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
-        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
-
-        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
-        tokenizer.from_pretrained(vocab_file)
-        os.remove(vocab_file)
-
-        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
-        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
+        with TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+            input_text = u"<unk> UNwanted , running"
+            output_text = u"<unk> unwanted, running"
+
+            create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
+
+            tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
+
+            tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
+            self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
+            self.assertListEqual(
+                tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

     def test_full_tokenizer_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=True)
@@ -68,13 +61,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
             tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
             ["HeLLo", "!", "how", "Are", "yoU", "?"])

-    @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = TransfoXLTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)

 if __name__ == '__main__':
     unittest.main()
# coding=utf-8
# Copyright 2018 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import six
from pytorch_transformers import PreTrainedTokenizer
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
class TokenizerUtilsTest(unittest.TestCase):
def check_tokenizer_from_pretrained(self, tokenizer_class):
s3_models = list(tokenizer_class.max_model_input_sizes.keys())
for model_name in s3_models[:1]:
tokenizer = tokenizer_class.from_pretrained(model_name)
self.assertIsNotNone(tokenizer)
self.assertIsInstance(tokenizer, tokenizer_class)
self.assertIsInstance(tokenizer, PreTrainedTokenizer)
for special_tok in tokenizer.all_special_tokens:
if six.PY2:
self.assertIsInstance(special_tok, unicode)
else:
self.assertIsInstance(special_tok, str)
special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
self.assertIsInstance(special_tok_id, int)
def test_pretrained_tokenizers(self):
self.check_tokenizer_from_pretrained(GPT2Tokenizer)
if __name__ == "__main__":
unittest.main()
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
import json
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
class XLMTokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"w</w>", "r</w>", "t</w>",
"lo", "low", "er</w>",
"low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
with TemporaryDirectory() as tmpdirname:
vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, "w") as fp:
fp.write(json.dumps(vocab_tokens))
with open(merges_file, "w") as fp:
fp.write("\n".join(merges))
input_text = u"lower newer"
output_text = u"lower newer"
create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
tokenizer = XLMTokenizer(vocab_file, merges_file)
text = "lower"
bpe_tokens = ["low", "er</w>"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + ["<unk>"]
input_bpe_tokens = [14, 15, 20]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__':
unittest.main()
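
# A self-contained sketch of the greedy BPE merge loop that the toy fixture
# above exercises (an illustrative re-implementation, not the library's exact
# `bpe` method). With the merge ranks from the test, ('l','o') before
# ('lo','w') before ('e','r</w>'), the word "lower" collapses to
# ("low", "er</w>"), which is exactly what the test asserts.

bpe_ranks = {('l', 'o'): 0, ('lo', 'w'): 1, ('e', 'r</w>'): 2}

def toy_bpe(token):
    # Start from single characters, with the end-of-word marker on the last.
    word = tuple(token[:-1]) + (token[-1] + '</w>',)
    while len(word) > 1:
        pairs = set(zip(word[:-1], word[1:]))
        # Apply the lowest-ranked merge; stop when no known merge applies.
        bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float('inf')))
        if bigram not in bpe_ranks:
            break
        first, second = bigram
        new_word, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
    return list(word)

assert toy_bpe("lower") == ["low", "er</w>"]
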
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'fixtures/test_sentencepiece.model')
class XLNetTokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
with TemporaryDirectory() as tmpdirname:
tokenizer.save_pretrained(tmpdirname)
input_text = u"This is a test"
output_text = u"This is a test"
create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
tokens = tokenizer.tokenize(u'This is a test')
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(
ids, [8, 21, 84, 55, 24, 19, 7, 0,
602, 347, 347, 347, 3, 12, 66,
46, 72, 80, 6, 0, 4])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
u'or', u'n', SPIECE_UNDERLINE + u'in',
SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
u'<unk>', u'.'])
def test_tokenizer_lower(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"])
def test_tokenizer_no_lower(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or',
u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
if __name__ == '__main__':
unittest.main()
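
# Hedged reading of the round-trip assertions above: with this fixture
# SentencePiece model, pieces missing from the vocabulary (the digit '9' and
# the accented 'é') resolve to id 0, so converting the ids back yields
# '<unk>' in exactly those positions; the id sequence is lossy precisely
# where the vocabulary has no entry. A minimal sketch of the same check:
#
#     tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
#     ids = tokenizer.convert_tokens_to_ids([u'9', u'\u00e9'])   # -> [0, 0]
#     tokenizer.convert_ids_to_tokens(ids)            # -> [u'<unk>', u'<unk>']
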
@@ -22,26 +22,32 @@ import os
 import unicodedata
 from io import open

-from .file_utils import cached_path
+from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization

 logger = logging.getLogger(__name__)

-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
-    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
-    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
-    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
-    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
-    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
-    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
-    'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
-    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
-    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
-    'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
-    'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
-    'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
-}
+VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
+    {
+        'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
+        'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
+        'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
+        'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
+        'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
+        'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
+        'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
+        'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
+        'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
+        'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
+        'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
+        'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
+        'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
+    }
+}

-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'bert-base-uncased': 512,
     'bert-large-uncased': 512,
     'bert-base-cased': 512,
@@ -56,21 +62,15 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'bert-large-cased-whole-word-masking-finetuned-squad': 512,
     'bert-base-cased-finetuned-mrpc': 512,
 }
-VOCAB_NAME = 'vocab.txt'

 def load_vocab(vocab_file):
     """Loads a vocabulary file into a dictionary."""
     vocab = collections.OrderedDict()
-    index = 0
     with open(vocab_file, "r", encoding="utf-8") as reader:
-        while True:
-            token = reader.readline()
-            if not token:
-                break
-            token = token.strip()
-            vocab[token] = index
-            index += 1
+        tokens = reader.read().splitlines()
+    for index, token in enumerate(tokens):
+        vocab[token] = index
     return vocab

@@ -83,25 +83,48 @@ def whitespace_tokenize(text):
     return tokens

-class BertTokenizer(object):
-    """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
+class BertTokenizer(PreTrainedTokenizer):
+    r"""
+    Constructs a BertTokenizer.
+    :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece

-    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
-                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+    Args:
+        vocab_file: Path to a one-wordpiece-per-line vocabulary file
+        do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
+        do_basic_tokenize: Whether to do basic tokenization before wordpiece.
+        max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the
+            minimum of this value (if specified) and the underlying BERT model's sequence length.
+        never_split: List of tokens which will never be split during tokenization. Only has an effect when
+            do_wordpiece_only=False
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+    def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
+                 unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
+                 mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs):
         """Constructs a BertTokenizer.

         Args:
-            vocab_file: Path to a one-wordpiece-per-line vocabulary file
-            do_lower_case: Whether to lower case the input
-                Only has an effect when do_wordpiece_only=False
-            do_basic_tokenize: Whether to do basic tokenization before wordpiece.
-            max_len: An artificial maximum length to truncate tokenized sequences to;
-                Effective maximum length is always the minimum of this
-                value (if specified) and the underlying BERT model's
-                sequence length.
-            never_split: List of tokens which will never be split during tokenization.
-                Only has an effect when do_wordpiece_only=False
+            **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
+            **do_lower_case**: (`optional`) boolean (default True)
+                Whether to lower case the input
+                Only has an effect when do_basic_tokenize=True
+            **do_basic_tokenize**: (`optional`) boolean (default True)
+                Whether to do basic tokenization before wordpiece.
+            **never_split**: (`optional`) list of string
+                List of tokens which will never be split during tokenization.
+                Only has an effect when do_basic_tokenize=True
+            **tokenize_chinese_chars**: (`optional`) boolean (default True)
+                Whether to tokenize Chinese characters.
+                This should likely be deactivated for Japanese,
+                see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
         """
+        super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
+                                            pad_token=pad_token, cls_token=cls_token,
+                                            mask_token=mask_token, **kwargs)
         if not os.path.isfile(vocab_file):
             raise ValueError(
                 "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
@@ -111,46 +134,43 @@ class BertTokenizer(object):
             [(ids, tok) for tok, ids in self.vocab.items()])
         self.do_basic_tokenize = do_basic_tokenize
         if do_basic_tokenize:
             self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
-                                                  never_split=never_split)
-        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
-        self.max_len = max_len if max_len is not None else int(1e12)
+                                                  never_split=never_split,
+                                                  tokenize_chinese_chars=tokenize_chinese_chars)
+        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

-    def tokenize(self, text):
+    @property
+    def vocab_size(self):
+        return len(self.vocab)
+
+    def _tokenize(self, text):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text):
+            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                 for sub_token in self.wordpiece_tokenizer.tokenize(token):
                     split_tokens.append(sub_token)
         else:
             split_tokens = self.wordpiece_tokenizer.tokenize(text)
         return split_tokens

-    def convert_tokens_to_ids(self, tokens):
-        """Converts a sequence of tokens into ids using the vocab."""
-        ids = []
-        for token in tokens:
-            ids.append(self.vocab[token])
-        if len(ids) > self.max_len:
-            logger.warning(
-                "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this BERT model ({} > {}). Running this"
-                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
-            )
-        return ids
-
-    def convert_ids_to_tokens(self, ids):
-        """Converts a sequence of ids in wordpiece tokens using the vocab."""
-        tokens = []
-        for i in ids:
-            tokens.append(self.ids_to_tokens[i])
-        return tokens
+    def _convert_token_to_id(self, token):
+        """ Converts a token (str/unicode) to an id using the vocab. """
+        return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (string/unicode) using the vocab."""
+        return self.ids_to_tokens.get(index, self.unk_token)
+
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) to a single string. """
+        out_string = ' '.join(tokens).replace(' ##', '').strip()
+        return out_string

     def save_vocabulary(self, vocab_path):
         """Save the tokenizer vocabulary to a directory or file."""
         index = 0
         if os.path.isdir(vocab_path):
-            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
         with open(vocab_file, "w", encoding="utf-8") as writer:
             for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                 if index != token_index:
@@ -159,16 +179,13 @@ class BertTokenizer(object):
                     index = token_index
                 writer.write(token + u'\n')
                 index += 1
-        return vocab_file
+        return (vocab_file,)

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a PreTrainedBertModel from a pre-trained model file.
-        Download and cache the pre-trained model file if needed.
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+        """ Instantiate a BertTokenizer from pre-trained vocabulary files.
         """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+        if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
             if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
                 logger.warning("The pre-trained model you are loading is a cased model but you have not set "
                                "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
@@ -179,58 +196,44 @@ class BertTokenizer(object):
                                "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
                                "but you may want to check this behavior.")
                 kwargs['do_lower_case'] = True
-        else:
-            vocab_file = pretrained_model_name_or_path
-        if os.path.isdir(vocab_file):
-            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download vocabulary.".format(
-                        vocab_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find any file "
-                    "associated to this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                        vocab_file))
-            return None
-        if resolved_vocab_file == vocab_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
-        # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
-        return tokenizer
+
+        return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

 class BasicTokenizer(object):
     """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

-    def __init__(self,
-                 do_lower_case=True,
-                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
-        """Constructs a BasicTokenizer.
+    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
+        """ Constructs a BasicTokenizer.

         Args:
-            do_lower_case: Whether to lower case the input.
+            **do_lower_case**: Whether to lower case the input.
+            **never_split**: (`optional`) list of str
+                Kept for backward compatibility purposes.
+                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
+                List of tokens not to split.
+            **tokenize_chinese_chars**: (`optional`) boolean (default True)
+                Whether to tokenize Chinese characters.
+                This should likely be deactivated for Japanese,
+                see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
         """
+        if never_split is None:
+            never_split = []
         self.do_lower_case = do_lower_case
         self.never_split = never_split
+        self.tokenize_chinese_chars = tokenize_chinese_chars

-    def tokenize(self, text):
-        """Tokenizes a piece of text."""
+    def tokenize(self, text, never_split=None):
+        """ Basic Tokenization of a piece of text.
+            Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.

+        Args:
+            **never_split**: (`optional`) list of str
+                Kept for backward compatibility purposes.
+                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
+                List of tokens not to split.
+        """
+        never_split = self.never_split + (never_split if never_split is not None else [])
         text = self._clean_text(text)
         # This was added on November 1st, 2018 for the multilingual and Chinese
         # models. This is also applied to the English models now, but it doesn't
@@ -238,11 +241,12 @@ class BasicTokenizer(object):
         # and generally don't have any Chinese data in them (there are Chinese
         # characters in the vocabulary because Wikipedia does have some Chinese
         # words in the English Wikipedia.).
-        text = self._tokenize_chinese_chars(text)
+        if self.tokenize_chinese_chars:
+            text = self._tokenize_chinese_chars(text)
         orig_tokens = whitespace_tokenize(text)
         split_tokens = []
         for token in orig_tokens:
-            if self.do_lower_case and token not in self.never_split:
+            if self.do_lower_case and token not in never_split:
                 token = token.lower()
                 token = self._run_strip_accents(token)
             split_tokens.extend(self._run_split_on_punc(token))
@@ -261,9 +265,9 @@ class BasicTokenizer(object):
             output.append(char)
         return "".join(output)

-    def _run_split_on_punc(self, text):
+    def _run_split_on_punc(self, text, never_split=None):
         """Splits punctuation on a piece of text."""
-        if text in self.never_split:
+        if never_split is not None and text in never_split:
             return [text]
         chars = list(text)
         i = 0
@@ -335,7 +339,7 @@ class BasicTokenizer(object):
 class WordpieceTokenizer(object):
     """Runs WordPiece tokenization."""

-    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
+    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
         self.vocab = vocab
         self.unk_token = unk_token
         self.max_input_chars_per_word = max_input_chars_per_word
...
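
# A minimal usage sketch for the refactored BertTokenizer above (hypothetical
# local run; assumes the 'bert-base-uncased' vocabulary is already cached).
# Tokenization now flows through the shared PreTrainedTokenizer entry point,
# which shields special tokens from the basic/wordpiece split, and
# `convert_tokens_to_string` undoes the '##' continuation markers:
#
#     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#     tokens = tokenizer.tokenize(u"unwanted running")
#     # -> ['un', '##want', '##ed', 'running']
#     tokenizer.convert_tokens_to_string(tokens)
#     # -> 'unwanted running'
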
@@ -31,24 +31,32 @@ except ImportError:
     def lru_cache():
         return lambda func: func

-from .file_utils import cached_path
+from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization

 logger = logging.getLogger(__name__)

-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
-    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
-}
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
-    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
-}
+VOCAB_FILES_NAMES = {
+    'vocab_file': 'vocab.json',
+    'merges_file': 'merges.txt',
+}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
+    {
+        'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+        'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
+    },
+    'merges_file':
+    {
+        'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+        'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
+    },
+}

-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'gpt2': 1024,
+    'gpt2-medium': 1024,
 }
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
-SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 @lru_cache()
 def bytes_to_unicode():
@@ -85,71 +93,19 @@ def get_pairs(word):
         prev_char = char
     return pairs

-class GPT2Tokenizer(object):
+class GPT2Tokenizer(PreTrainedTokenizer):
     """
     GPT-2 BPE tokenizer. Peculiarities:
         - Byte-level BPE
     """
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a GPT2Tokenizer from a pre-trained model file.
-        Download and cache the pre-trained model file if needed.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
-            special_tokens_file = None
-        else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
-            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
-            if not os.path.exists(special_tokens_file):
-                special_tokens_file = None
-            else:
-                logger.info("loading special tokens file {}".format(special_tokens_file))
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download vocabulary.".format(
-                        vocab_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path,
-                        vocab_file, merges_file))
-            return None
-        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-            logger.info("loading merges file {}".format(merges_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-            logger.info("loading merges file {} from cache at {}".format(
-                merges_file, resolved_merges_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
-        # Instantiate tokenizer.
-        if special_tokens_file and 'special_tokens' not in kwargs:
-            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
-        else:
-            special_tokens = kwargs.pop('special_tokens', [])
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
-        return tokenizer
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
-        self.max_len = max_len if max_len is not None else int(1e12)
+    def __init__(self, vocab_file, merges_file, errors='replace',
+                 bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
+        super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs)
+
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v:k for k,v in self.encoder.items()}
         self.errors = errors # how to handle errors in decoding
@@ -163,25 +119,9 @@ class GPT2Tokenizer(object):
         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

-        self.special_tokens = {}
-        self.special_tokens_decoder = {}
-        self.set_special_tokens(special_tokens)
-
-    def __len__(self):
-        return len(self.encoder) + len(self.special_tokens)
-
-    def set_special_tokens(self, special_tokens):
-        """ Add a list of additional tokens to the encoder.
-            The additional tokens are indexed starting from the last index of the
-            current vocabulary in the order of the `special_tokens` list.
-        """
-        if not special_tokens:
-            self.special_tokens = {}
-            self.special_tokens_decoder = {}
-            return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
-        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
-        logger.info("Special tokens {}".format(self.special_tokens))
+    @property
+    def vocab_size(self):
+        return len(self.encoder)

     def bpe(self, token):
         if token in self.cache:
@@ -224,7 +164,7 @@ class GPT2Tokenizer(object):
         self.cache[token] = word
         return word

-    def tokenize(self, text):
+    def _tokenize(self, text):
         """ Tokenize a string. """
         bpe_tokens = []
         for token in re.findall(self.pat, text):
@@ -235,59 +175,29 @@ class GPT2Tokenizer(object):
             bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
         return bpe_tokens

-    def convert_tokens_to_ids(self, tokens):
-        """ Converts a sequence of tokens into ids using the vocab. """
-        ids = []
-        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
-            if tokens in self.special_tokens:
-                return self.special_tokens[tokens]
-            else:
-                return self.encoder.get(tokens, 0)
-        for token in tokens:
-            if token in self.special_tokens:
-                ids.append(self.special_tokens[token])
-            else:
-                ids.append(self.encoder.get(token, 0))
-        if len(ids) > self.max_len:
-            logger.warning(
-                "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this OpenAI GPT model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
-            )
-        return ids
-
-    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
-        """Converts a sequence of ids in BPE tokens using the vocab."""
-        tokens = []
-        for i in ids:
-            if i in self.special_tokens_decoder:
-                if not skip_special_tokens:
-                    tokens.append(self.special_tokens_decoder[i])
-            else:
-                tokens.append(self.decoder[i])
-        return tokens
-
-    def encode(self, text):
-        return self.convert_tokens_to_ids(self.tokenize(text))
-
-    def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
-        text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
-        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
-        if clean_up_tokenization_spaces:
-            text = text.replace('<unk>', '')
-            text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
-                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
-                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
-        return text
+    def _convert_token_to_id(self, token):
+        """ Converts a token (str/unicode) to an id using the vocab. """
+        if token in self.encoder:
+            return self.encoder.get(token)
+        return self.encoder.get(self.unk_token)
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (string/unicode) using the vocab."""
+        return self.decoder.get(index)
+
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) to a single string. """
+        text = ''.join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        return text

-    def save_vocabulary(self, vocab_path):
+    def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
-        if not os.path.isdir(vocab_path):
-            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+        if not os.path.isdir(save_directory):
+            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
             return
-        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
-        merge_file = os.path.join(vocab_path, MERGES_NAME)
-        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
+        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])

         with open(vocab_file, 'w', encoding='utf-8') as f:
             f.write(json.dumps(self.encoder, ensure_ascii=False))
@@ -303,14 +213,4 @@ class GPT2Tokenizer(object):
             writer.write(' '.join(bpe_tokens) + u'\n')
             index += 1

-        index = len(self.encoder)
-        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
-            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
-                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
-                    index = token_index
-                writer.write(token + u'\n')
-                index += 1
-
-        return vocab_file, merge_file, special_tokens_file
+        return vocab_file, merge_file
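
# A self-contained sketch (Python 3) of the byte-level round trip that
# `_tokenize` and `convert_tokens_to_string` above rely on; an illustrative
# re-implementation of the bytes_to_unicode idea rather than the library's
# exact table. Raw UTF-8 bytes are remapped to printable unicode characters
# before BPE ever sees them, and decoding simply inverts the map, so any
# input survives the round trip.

def toy_byte_map():
    # Printable bytes keep their own codepoint; the rest are shifted into the
    # 256+ range, mirroring the bytes_to_unicode construction above.
    bs = (list(range(ord("!"), ord("~") + 1)) +
          list(range(ord("\xa1"), ord("\xac") + 1)) +
          list(range(ord("\xae"), ord("\xff") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))

byte_encoder = toy_byte_map()
byte_decoder = {v: k for k, v in byte_encoder.items()}

text = u"café au lait"
mapped = ''.join(byte_encoder[b] for b in bytearray(text.encode('utf-8')))
restored = bytearray(byte_decoder[c] for c in mapped).decode('utf-8')
assert restored == text
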
@@ -20,28 +20,32 @@ import json
 import logging
 import os
 import re
-import sys
 from io import open

-from tqdm import tqdm
-
-from .file_utils import cached_path
-from .tokenization import BasicTokenizer
+from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_bert import BasicTokenizer

 logger = logging.getLogger(__name__)

-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
-}
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
-}
+VOCAB_FILES_NAMES = {
+    'vocab_file': 'vocab.json',
+    'merges_file': 'merges.txt',
+}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
+    {
+        'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
+    },
+    'merges_file':
+    {
+        'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
+    },
+}

-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'openai-gpt': 512,
 }
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
-SPECIAL_TOKENS_NAME = 'special_tokens.txt'

 def get_pairs(word):
     """
@@ -70,73 +74,19 @@ def text_standardize(text):
     text = re.sub(r'[^\S\n]+', ' ', text)
     return text.strip()

-class OpenAIGPTTokenizer(object):
+class OpenAIGPTTokenizer(PreTrainedTokenizer):
     """
     BPE tokenizer. Peculiarities:
         - lower case all inputs
         - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
-        - argument special_tokens and function set_special_tokens:
-            can be used to add additional symbols (ex: "__classify__") to a vocabulary.
     """
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a PreTrainedBertModel from a pre-trained model file.
-        Download and cache the pre-trained model file if needed.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
-            special_tokens_file = None
-        else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
-            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
-            if not os.path.exists(special_tokens_file):
-                special_tokens_file = None
-            else:
-                logger.info("loading special tokens file {}".format(special_tokens_file))
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download vocabulary.".format(
-                        vocab_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path,
-                        vocab_file, merges_file))
-            return None
-        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-            logger.info("loading merges file {}".format(merges_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-            logger.info("loading merges file {} from cache at {}".format(
-                merges_file, resolved_merges_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
-        # Instantiate tokenizer.
-        if special_tokens_file and 'special_tokens' not in kwargs:
-            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
-        else:
-            special_tokens = kwargs.pop('special_tokens', [])
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
-        return tokenizer
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
+    def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
+        super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
         try:
             import ftfy
             import spacy
@@ -144,39 +94,19 @@ class OpenAIGPTTokenizer(object):
             self.fix_text = ftfy.fix_text
         except ImportError:
             logger.warning("ftfy or spacy is not installed, using BERT BasicTokenizer instead of SpaCy & ftfy.")
-            self.nlp = BasicTokenizer(do_lower_case=True,
-                                      never_split=special_tokens if special_tokens is not None else [])
+            self.nlp = BasicTokenizer(do_lower_case=True)
             self.fix_text = None

-        self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
         self.decoder = {v:k for k,v in self.encoder.items()}
         merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
         merges = [tuple(merge.split()) for merge in merges]
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}
-        self.special_tokens = {}
-        self.special_tokens_decoder = {}
-        self.set_special_tokens(special_tokens)

-    def __len__(self):
-        return len(self.encoder) + len(self.special_tokens)
-
-    def set_special_tokens(self, special_tokens):
-        """ Add a list of additional tokens to the encoder.
-            The additional tokens are indexed starting from the last index of the
-            current vocabulary in the order of the `special_tokens` list.
-        """
-        if not special_tokens:
-            self.special_tokens = {}
-            self.special_tokens_decoder = {}
-            return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
-        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
-        if self.fix_text is None:
-            # Using BERT's BasicTokenizer: we can update the tokenizer
-            self.nlp.never_split = special_tokens
-        logger.info("Special tokens {}".format(self.special_tokens))
+    @property
+    def vocab_size(self):
+        return len(self.encoder)

     def bpe(self, token):
         word = tuple(token[:-1]) + (token[-1] + '</w>',)
@@ -221,7 +151,7 @@ class OpenAIGPTTokenizer(object):
         self.cache[token] = word
         return word

-    def tokenize(self, text):
+    def _tokenize(self, text):
         """ Tokenize a string. """
         split_tokens = []
         if self.fix_text is None:
@@ -236,60 +166,26 @@ class OpenAIGPTTokenizer(object):
             split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
         return split_tokens

-    def convert_tokens_to_ids(self, tokens):
-        """ Converts a sequence of tokens into ids using the vocab. """
-        ids = []
-        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
-            if tokens in self.special_tokens:
-                return self.special_tokens[tokens]
-            else:
-                return self.encoder.get(tokens, 0)
-        for token in tokens:
-            if token in self.special_tokens:
-                ids.append(self.special_tokens[token])
-            else:
-                ids.append(self.encoder.get(token, 0))
-        if len(ids) > self.max_len:
-            logger.warning(
-                "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this OpenAI GPT model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
-            )
-        return ids
-
-    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
-        """Converts a sequence of ids in BPE tokens using the vocab."""
-        tokens = []
-        for i in ids:
-            if i in self.special_tokens_decoder:
-                if not skip_special_tokens:
-                    tokens.append(self.special_tokens_decoder[i])
-            else:
-                tokens.append(self.decoder[i])
-        return tokens
+    def _convert_token_to_id(self, token):
+        """ Converts a token (str/unicode) to an id using the vocab. """
+        return self.encoder.get(token, self.encoder.get(self.unk_token))

-    def encode(self, text):
-        return self.convert_tokens_to_ids(self.tokenize(text))
+    def _convert_id_to_token(self, index):
+        """Converts an id to a token (BPE) using the vocab."""
+        return self.decoder.get(index, self.unk_token)

-    def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
-        """Converts a sequence of ids in a string."""
-        tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) to a single string. """
         out_string = ''.join(tokens).replace('</w>', ' ').strip()
-        if clean_up_tokenization_spaces:
-            out_string = out_string.replace('<unk>', '')
-            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
-                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
-                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
         return out_string

-    def save_vocabulary(self, vocab_path):
+    def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
-        if not os.path.isdir(vocab_path):
-            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+        if not os.path.isdir(save_directory):
+            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
             return
-        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
-        merge_file = os.path.join(vocab_path, MERGES_NAME)
-        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
+        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])

         with open(vocab_file, 'w', encoding='utf-8') as f:
             f.write(json.dumps(self.encoder, ensure_ascii=False))
@@ -305,14 +201,4 @@ class OpenAIGPTTokenizer(object):
             writer.write(' '.join(bpe_tokens) + u'\n')
             index += 1

-        index = len(self.encoder)
-        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
-            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
-                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
-                    index = token_index
-                writer.write(token + u'\n')
-                index += 1
-
-        return vocab_file, merge_file, special_tokens_file
+        return vocab_file, merge_file
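
# Worked example of the '</w>' end-of-word convention that `bpe` and
# `convert_tokens_to_string` above share: the marker is appended to a word's
# last character before merging and turned back into a space when joining,
# so detokenization needs no extra whitespace bookkeeping.

tokens = ['low', 'er</w>', 'new', 'er</w>']
assert ''.join(tokens).replace('</w>', ' ').strip() == 'lower newer'
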
...@@ -25,12 +25,12 @@ import os ...@@ -25,12 +25,12 @@ import os
import sys import sys
from collections import Counter, OrderedDict from collections import Counter, OrderedDict
from io import open from io import open
import unicodedata
import torch import torch
import numpy as np import numpy as np
from .file_utils import cached_path from .file_utils import cached_path
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
import cPickle as pickle import cPickle as pickle
@@ -40,66 +40,43 @@ else:
 logger = logging.getLogger(__name__)

-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
-}
-VOCAB_NAME = 'vocab.bin'
+VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin', 'vocab_file': 'vocab.txt'}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'pretrained_vocab_file':
+    {
+        'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
+    }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    'transfo-xl-wt103': None,
+}

 PRETRAINED_CORPUS_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
 }
 CORPUS_NAME = 'corpus.bin'

-class TransfoXLTokenizer(object):
+class TransfoXLTokenizer(PreTrainedTokenizer):
     """
     Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
     """
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a TransfoXLTokenizer.
-        The TransfoXLTokenizer.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-        else:
-            if os.path.isdir(pretrained_model_name_or_path):
-                vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            else:
-                vocab_file = pretrained_model_name_or_path
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download vocabulary.".format(
-                        vocab_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path,
-                        vocab_file))
-            return None
-        if resolved_vocab_file == vocab_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-        # Instantiate tokenizer.
-        tokenizer = cls(*inputs, **kwargs)
-        vocab_dict = torch.load(resolved_vocab_file)
-        for key, value in vocab_dict.items():
-            tokenizer.__dict__[key] = value
-        return tokenizer
-
-    def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
-                 delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+    def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False,
+                 delimiter=None, vocab_file=None, pretrained_vocab_file=None,
+                 never_split=None, unk_token="<unk>", eos_token="<eos>",
+                 additional_special_tokens=["<formula>"], **kwargs):
+        super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token,
+                                                 additional_special_tokens=additional_special_tokens,
+                                                 **kwargs)
+        if never_split is None:
+            never_split = self.all_special_tokens
+        if special is None:
+            special = []
         self.counter = Counter()
         self.special = special
         self.min_freq = min_freq
@@ -109,15 +86,25 @@ class TransfoXLTokenizer(object):
         self.vocab_file = vocab_file
         self.never_split = never_split

+        if pretrained_vocab_file is not None:
+            # Hack because, honestly this tokenizer was not made to be used
+            # in a library like ours, at all.
+            vocab_dict = torch.load(pretrained_vocab_file)
+            for key, value in vocab_dict.items():
+                self.__dict__[key] = value
+
+        if vocab_file is not None:
+            self.build_vocab()
+
     def count_file(self, path, verbose=False, add_eos=False):
-        if verbose: print('counting file {} ...'.format(path))
+        if verbose: logger.info('counting file {} ...'.format(path))
         assert os.path.exists(path)

         sents = []
         with open(path, 'r', encoding='utf-8') as f:
             for idx, line in enumerate(f):
                 if verbose and idx > 0 and idx % 500000 == 0:
-                    print('    line {}'.format(idx))
+                    logger.info('    line {}'.format(idx))
                 symbols = self.tokenize(line, add_eos=add_eos)
                 self.counter.update(symbols)
                 sents.append(symbols)
@@ -128,10 +115,10 @@ class TransfoXLTokenizer(object):
         """
             sents : a list of sentences, each a list of tokenized symbols
         """
-        if verbose: print('counting {} sents ...'.format(len(sents)))
+        if verbose: logger.info('counting {} sents ...'.format(len(sents)))
         for idx, symbols in enumerate(sents):
             if verbose and idx > 0 and idx % 500000 == 0:
-                print('    line {}'.format(idx))
+                logger.info('    line {}'.format(idx))
             self.counter.update(symbols)

     def _build_from_file(self, vocab_file):
@@ -153,17 +140,17 @@ class TransfoXLTokenizer(object):
         """Save the tokenizer vocabulary to a directory or file."""
         index = 0
         if os.path.isdir(vocab_path):
-            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['pretrained_vocab_file'])
         torch.save(self.__dict__, vocab_file)
-        return vocab_file
+        return (vocab_file,)

     def build_vocab(self):
         if self.vocab_file:
-            print('building vocab from {}'.format(self.vocab_file))
+            logger.info('building vocab from {}'.format(self.vocab_file))
             self._build_from_file(self.vocab_file)
-            print('final vocab size {}'.format(len(self)))
+            logger.info('final vocab size {}'.format(len(self)))
         else:
-            print('building vocab with min_freq={}, max_size={}'.format(
+            logger.info('building vocab with min_freq={}, max_size={}'.format(
                 self.min_freq, self.max_size))
             self.idx2sym = []
             self.sym2idx = OrderedDict()
@@ -175,18 +162,18 @@ class TransfoXLTokenizer(object):
             if cnt < self.min_freq: break
             self.add_symbol(sym)

-        print('final vocab size {} from {} unique tokens'.format(
+        logger.info('final vocab size {} from {} unique tokens'.format(
             len(self), len(self.counter)))

     def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
             add_double_eos=False):
-        if verbose: print('encoding file {} ...'.format(path))
+        if verbose: logger.info('encoding file {} ...'.format(path))
         assert os.path.exists(path)
         encoded = []
         with open(path, 'r', encoding='utf-8') as f:
             for idx, line in enumerate(f):
                 if verbose and idx > 0 and idx % 500000 == 0:
-                    print('    line {}'.format(idx))
+                    logger.info('    line {}'.format(idx))
                 symbols = self.tokenize(line, add_eos=add_eos,
                     add_double_eos=add_double_eos)
                 encoded.append(self.convert_to_tensor(symbols))
@@ -197,11 +184,11 @@ class TransfoXLTokenizer(object):
         return encoded

     def encode_sents(self, sents, ordered=False, verbose=False):
-        if verbose: print('encoding {} sents ...'.format(len(sents)))
+        if verbose: logger.info('encoding {} sents ...'.format(len(sents)))
         encoded = []
         for idx, symbols in enumerate(sents):
             if verbose and idx > 0 and idx % 500000 == 0:
-                print('    line {}'.format(idx))
+                logger.info('    line {}'.format(idx))
             encoded.append(self.convert_to_tensor(symbols))

         if ordered:
@@ -220,15 +207,17 @@ class TransfoXLTokenizer(object):
         self.idx2sym.append(sym)
         self.sym2idx[sym] = len(self.idx2sym) - 1

-    def get_sym(self, idx):
+    def _convert_id_to_token(self, idx):
+        """Converts an id in a token (BPE) using the vocab."""
         assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
         return self.idx2sym[idx]

-    def get_idx(self, sym):
+    def _convert_token_to_id(self, sym):
+        """ Converts a token (str/unicode) in an id using the vocab. """
         if sym in self.sym2idx:
             return self.sym2idx[sym]
         else:
-            # print('encounter unk {}'.format(sym))
+            # logger.info('encounter unk {}'.format(sym))
             # assert '<eos>' not in sym
             if hasattr(self, 'unk_idx'):
                 return self.sym2idx.get(sym, self.unk_idx)
@@ -240,28 +229,19 @@ class TransfoXLTokenizer(object):
         else:
             raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')

-    def convert_ids_to_tokens(self, indices):
-        """Converts a sequence of indices in symbols using the vocab."""
-        return [self.get_sym(idx) for idx in indices]
-
-    def convert_tokens_to_ids(self, symbols):
-        """Converts a sequence of symbols into ids using the vocab."""
-        return [self.get_idx(sym) for sym in symbols]
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        out_string = ' '.join(tokens).strip()
+        return out_string

     def convert_to_tensor(self, symbols):
         return torch.LongTensor(self.convert_tokens_to_ids(symbols))

-    def decode(self, indices, exclude=None):
-        """Converts a sequence of indices in a string."""
-        if exclude is None:
-            return ' '.join([self.get_sym(idx) for idx in indices])
-        else:
-            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
-
-    def __len__(self):
+    @property
+    def vocab_size(self):
         return len(self.idx2sym)

-    def tokenize(self, line, add_eos=False, add_double_eos=False):
+    def _tokenize(self, line, add_eos=False, add_double_eos=False):
         line = line.strip()
         # convert to lower case
         if self.lower_case:
@@ -472,7 +452,7 @@ class TransfoXLCorpus(object):
                     "We assumed '{}' was a path or url but couldn't find files {} "
                     "at this path or url.".format(
                         pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                        ', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
                         pretrained_model_name_or_path,
                         corpus_file))
             return None
@@ -563,14 +543,14 @@ def get_lm_corpus(datadir, dataset):
     fn = os.path.join(datadir, 'cache.pt')
     fn_pickle = os.path.join(datadir, 'cache.pkl')
     if os.path.exists(fn):
-        print('Loading cached dataset...')
+        logger.info('Loading cached dataset...')
         corpus = torch.load(fn_pickle)
     elif os.path.exists(fn):
-        print('Loading cached dataset from pickle...')
+        logger.info('Loading cached dataset from pickle...')
         with open(fn, "rb") as fp:
             corpus = pickle.load(fp)
     else:
-        print('Producing dataset {}...'.format(dataset))
+        logger.info('Producing dataset {}...'.format(dataset))
         kwargs = {}
         if dataset in ['wt103', 'wt2']:
             kwargs['special'] = ['<eos>']
...
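# Illustrative sketch (not part of the diff above): with this refactor,
# TransfoXLTokenizer inherits loading/saving from PreTrainedTokenizer instead
# of shipping its own from_pretrained, so typical usage becomes:
#
#     tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
#     ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("hello world"))
#     tokenizer.save_pretrained('./my-tokenizer')   # calls save_vocabulary() above
#
# The 'transfo-xl-wt103' shortcut resolves through PRETRAINED_VOCAB_FILES_MAP.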
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import logging
import os
import json
import six
from io import open
from .file_utils import cached_path
logger = logging.getLogger(__name__)
SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
ADDED_TOKENS_FILE = 'added_tokens.json'
class PreTrainedTokenizer(object):
""" An abstract class to handle dowloading and loading pretrained tokenizers and adding tokens to the vocabulary.
Derived class can set up a few special tokens to be used in common scripts and internals:
bos_token, eos_token, EOP_TOKEN, EOD_TOKEN, unk_token, sep_token, pad_token, cls_token, mask_token
additional_special_tokens = []
We defined an added_tokens_encoder to add new tokens to the vocabulary without having to handle the
specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
"""
vocab_files_names = {}
pretrained_vocab_files_map = {}
max_model_input_sizes = {}
SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
"pad_token", "cls_token", "mask_token",
"additional_special_tokens"]
@property
def bos_token(self):
if self._bos_token is None:
logger.error("Using bos_token, but it is not set yet.")
return self._bos_token
@property
def eos_token(self):
if self._eos_token is None:
logger.error("Using eos_token, but it is not set yet.")
return self._eos_token
@property
def unk_token(self):
if self._unk_token is None:
logger.error("Using unk_token, but it is not set yet.")
return self._unk_token
@property
def sep_token(self):
if self._sep_token is None:
logger.error("Using sep_token, but it is not set yet.")
return self._sep_token
@property
def pad_token(self):
if self._pad_token is None:
logger.error("Using pad_token, but it is not set yet.")
return self._pad_token
@property
def cls_token(self):
if self._cls_token is None:
logger.error("Using cls_token, but it is not set yet.")
return self._cls_token
@property
def mask_token(self):
if self._mask_token is None:
logger.error("Using mask_token, but it is not set yet.")
return self._mask_token
@property
def additional_special_tokens(self):
if self._additional_special_tokens is None:
logger.error("Using additional_special_tokens, but it is not set yet.")
return self._additional_special_tokens
@bos_token.setter
def bos_token(self, value):
self._bos_token = value
@eos_token.setter
def eos_token(self, value):
self._eos_token = value
@unk_token.setter
def unk_token(self, value):
self._unk_token = value
@sep_token.setter
def sep_token(self, value):
self._sep_token = value
@pad_token.setter
def pad_token(self, value):
self._pad_token = value
@cls_token.setter
def cls_token(self, value):
self._cls_token = value
@mask_token.setter
def mask_token(self, value):
self._mask_token = value
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
def __init__(self, max_len=None, **kwargs):
self._bos_token = None
self._eos_token = None
self._unk_token = None
self._sep_token = None
self._pad_token = None
self._cls_token = None
self._mask_token = None
self._additional_special_tokens = []
self.max_len = max_len if max_len is not None else int(1e12)
self.added_tokens_encoder = {}
self.added_tokens_decoder = {}
for key, value in kwargs.items():
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
setattr(self, key, value)
@classmethod
def from_pretrained(cls, *inputs, **kwargs):
return cls._from_pretrained(*inputs, **kwargs)
@classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedTokenizer from pre-trained vocabulary files.
Download and cache the vocabulary files if needed.
"""
s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
if pretrained_model_name_or_path in s3_models:
for file_id, map_list in cls.pretrained_vocab_files_map.items():
vocab_files[file_id] = map_list[pretrained_model_name_or_path]
else:
all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
all_vocab_files_names.update(cls.vocab_files_names)
for file_id, file_name in all_vocab_files_names.items():
if os.path.isdir(pretrained_model_name_or_path):
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
else:
full_file_name = pretrained_model_name_or_path
if not os.path.exists(full_file_name):
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
full_file_name = None
vocab_files[file_id] = full_file_name
# Get files from url, cache, or disk depending on the case
try:
resolved_vocab_files = {}
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
else:
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in s3_models:
logger.error("Couldn't reach server to download vocabulary.")
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path, ', '.join(s3_models),
pretrained_model_name_or_path, str(vocab_files.keys())))
return None
for file_id, file_path in vocab_files.items():
if file_path == resolved_vocab_files[file_id]:
logger.info("loading file {}".format(file_path))
else:
logger.info("loading file {} from cache at {}".format(
file_path, resolved_vocab_files[file_id]))
# Set max length if needed
if pretrained_model_name_or_path in cls.max_model_input_sizes:
# if we're using a pretrained model, ensure the tokenizer
# wont index sequences longer than the number of positional embeddings
max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
if max_len is not None and isinstance(max_len, (int, float)):
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Merge resolved_vocab_files arguments in kwargs.
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
for args_name, file_path in resolved_vocab_files.items():
if args_name not in kwargs:
kwargs[args_name] = file_path
if special_tokens_map_file is not None:
special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
for key, value in special_tokens_map.items():
if key not in kwargs:
kwargs[key] = value
# Instantiate tokenizer.
tokenizer = cls(*inputs, **kwargs)
# Add supplementary tokens.
if added_tokens_file is not None:
added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8"))
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
tokenizer.added_tokens_encoder.update(added_tok_encoder)
tokenizer.added_tokens_decoder.update(added_tok_decoder)
return tokenizer
def save_pretrained(self, save_directory):
""" Save the tokenizer vocabulary files (with added tokens) and the
special-tokens-to-class-attributes-mapping to a directory, so that it
can be re-loaded using the `from_pretrained(save_directory)` class method.
"""
if not os.path.isdir(save_directory):
logger.error("Saving directory ({}) should be a directory".format(save_directory))
return
special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
with open(added_tokens_file, 'w', encoding='utf-8') as f:
if self.added_tokens_encoder:
# _from_pretrained() reads this file back as a token -> index mapping
# (added_tok_encoder), so serialize the encoder here, not the decoder.
out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
else:
out_str = u"{}"
f.write(out_str)
vocab_files = self.save_vocabulary(save_directory)
return vocab_files + (special_tokens_map_file, added_tokens_file)
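# Round-trip sketch (assumes `tok` is an instance of a concrete subclass and
# './ckpt' is an existing directory; both names are illustrative):
#
#     files = tok.save_pretrained('./ckpt')
#     # -> vocabulary files + special_tokens_map.json + added_tokens.json
#     reloaded = tok.__class__.from_pretrained('./ckpt')
#
# from_pretrained('./ckpt') rediscovers the files by their canonical names via
# vocab_files_names, SPECIAL_TOKENS_MAP_FILE and ADDED_TOKENS_FILE.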
def save_vocabulary(self, save_directory):
""" Save the tokenizer vocabulary to a directory. This method doesn't save added tokens
and special token mappings.
Please use `save_pretrained()` to save the full Tokenizer state so that it can be
reloaded using the `from_pretrained(save_directory)` class method.
"""
raise NotImplementedError
# Declared as a property: subclasses override it with @property and
# __len__ below reads it as an attribute.
@property
def vocab_size(self):
raise NotImplementedError
def __len__(self):
return self.vocab_size + len(self.added_tokens_encoder)
def add_tokens(self, new_tokens):
""" Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to the added_tokens_encoder with indices starting from
the last index of the current vocabulary.
Returns:
Number of tokens added to the vocabulary which can be used to correspondingly
increase the size of the associated model embedding matrices.
"""
if not new_tokens:
return 0
to_add_tokens = []
for token in new_tokens:
if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
to_add_tokens.append(token)
logger.info("Adding %s to the vocabulary", token)
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder)
return len(to_add_tokens)
def add_special_tokens(self, special_tokens_dict):
""" Add a dictionnary of special tokens (eos, pad, cls...) to the encoder and link them
to class attributes. If the special tokens are not in the vocabulary, they are added
to it and indexed starting from the last index of the current vocabulary.
Returns:
Number of tokens added to the vocabulary which can be used to correspondingly
increase the size of the associated model embedding matrices.
"""
if not special_tokens_dict:
return 0
added_special_tokens = self.add_tokens(special_tokens_dict.values())
for key, value in special_tokens_dict.items():
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
setattr(self, key, value)
return added_special_tokens
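# Usage sketch (the token string is illustrative): registering a special token
# adds it to the vocab if needed and binds the matching class attribute.
#
#     num_added = tokenizer.add_special_tokens({'cls_token': '<CLS>'})
#     assert tokenizer.cls_token == '<CLS>'
#
# The return value tells the caller how many new embedding rows the associated
# model needs; resizing the embedding matrix is left to the caller.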
def tokenize(self, text, **kwargs):
""" Converts a string in a sequence of tokens (string), using the tokenizer.
Split in words for word-based vocabulary or sub-words for sub-word-based
vocabularies (BPE/SentencePieces/WordPieces).
Take care of added tokens.
"""
def split_on_tokens(tok_list, text):
if not text:
return []
if not tok_list:
return self._tokenize(text, **kwargs)
tok = tok_list[0]
split_text = text.split(tok)
return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
for sub_text in split_text), [])[:-1]
added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
tokenized_text = split_on_tokens(added_tokens, text)
return tokenized_text
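# Sketch of how split_on_tokens protects added/special tokens, assuming a
# hypothetical subclass whose _tokenize lower-cases and splits on whitespace
# and whose added tokens are ['<new>']:
#
#     tokenize("Hello <new> world")
#     # -> _tokenize("Hello") + ['<new>'] + _tokenize("world")
#     # -> ['hello', '<new>', 'world']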
def _tokenize(self, text, **kwargs):
""" Converts a string in a sequence of tokens (string), using the tokenizer.
Split in words for word-based vocabulary or sub-words for sub-word-based
vocabularies (BPE/SentencePieces/WordPieces).
Don't take care of added tokens.
"""
raise NotImplementedError
def convert_tokens_to_ids(self, tokens):
""" Converts a single token or a sequence of tokens (str/unicode) in a integer id
(resp.) a sequence of ids, using the vocabulary.
"""
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
return self._convert_token_to_id_with_added_voc(tokens)
ids = []
for token in tokens:
ids.append(self._convert_token_to_id_with_added_voc(token))
if len(ids) > self.max_len:
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.max_len))
return ids
def _convert_token_to_id_with_added_voc(self, token):
if token in self.added_tokens_encoder:
return self.added_tokens_encoder[token]
return self._convert_token_to_id(token)
def _convert_token_to_id(self, token):
raise NotImplementedError
def encode(self, text):
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
same as self.convert_tokens_to_ids(self.tokenize(text)).
"""
return self.convert_tokens_to_ids(self.tokenize(text))
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
""" Converts a single index or a sequence of indices (integers) in a token "
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
Args:
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
"""
if isinstance(ids, int):
if ids in self.added_tokens_decoder:
return self.added_tokens_decoder[ids]
else:
return self._convert_id_to_token(ids)
tokens = []
for index in ids:
if index in self.all_special_ids and skip_special_tokens:
continue
if index in self.added_tokens_decoder:
tokens.append(self.added_tokens_decoder[index])
else:
tokens.append(self._convert_id_to_token(index))
return tokens
def _convert_id_to_token(self, index):
raise NotImplementedError
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (strings) into a single string.
The simplest way to do it is ' '.join(tokens),
but we often want to remove sub-word tokenization artifacts at the same time.
"""
return ' '.join(tokens)
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
text = self.convert_tokens_to_string(filtered_tokens)
if clean_up_tokenization_spaces:
text = clean_up_tokenization(text)
return text
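# Decode sketch (assumes `tokenizer` is a loaded concrete subclass): decode
# inverts encode up to tokenization artifacts.
#
#     ids = tokenizer.encode("don't stop")
#     tokenizer.decode(ids)
#     # -> "don't stop"  (clean_up_tokenization() re-attaches "n't", ",", etc.)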
@property
def special_tokens_map(self):
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
values ('<unk>', '<cls>'...)
"""
set_attr = {}
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
attr_value = getattr(self, "_" + attr)
if attr_value:
set_attr[attr] = attr_value
return set_attr
@property
def all_special_tokens(self):
""" List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
(cls_token, unk_token...).
"""
all_toks = []
set_attr = self.special_tokens_map
for attr_value in set_attr.values():
all_toks = all_toks + (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value])
all_toks = list(set(all_toks))
return all_toks
@property
def all_special_ids(self):
""" List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
class attributes (cls_token, unk_token...).
"""
all_toks = self.all_special_tokens
all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
return all_ids
def clean_up_tokenization(out_string):
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return out_string
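# Example of what clean_up_tokenization repairs:
#
#     clean_up_tokenization("do n't stop , please !")
#     # -> "don't stop, please!"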
# coding=utf-8
# Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import json
import logging
import os
import re
from io import open
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_bert import BasicTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.json",
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.json",
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.json",
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.json",
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
},
'merges_file':
{
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.txt",
'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.txt",
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlm-mlm-en-2048': 512,
}
def get_pairs(word):
"""
Return set of symbol pairs in a word.
word is represented as tuple of symbols (symbols being variable-length strings)
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
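# Example: for a word already split into symbols,
#
#     get_pairs(('h', 'e', 'l', 'l', 'o</w>'))
#     # -> {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o</w>')}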
def text_standardize(text):
"""
fixes some issues the spacy tokenizer had on books corpus
also does some whitespace standardization
"""
text = text.replace('—', '-')
text = text.replace('–', '-')
text = text.replace('―', '-')
text = text.replace('…', '...')
text = text.replace('´', "'")
text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
text = re.sub(r'\s*\n\s*', ' \n ', text)
text = re.sub(r'[^\S\n]+', ' ', text)
return text.strip()
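# Standardization example (a sketch):
#
#     text_standardize(u"wait\u2014what\u2026   ok")
#     # -> 'wait - what... ok'
#     # (em dash unified to '-' and spaced out, ellipsis expanded to '...',
#     #  runs of non-newline whitespace collapsed to single spaces)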
class XLMTokenizer(PreTrainedTokenizer):
"""
BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
- lower case all inputs
- uses `SpaCy tokenizer <https://spacy.io/api/tokenizer/>`_ and \
`ftfy <https://ftfy.readthedocs.io/en/latest/>`_ for pre-BPE tokenization if they are installed, \
fallback to BERT's BasicTokenizer if not.
- argument ``special_tokens`` and function ``set_special_tokens`` can be used to add additional symbols \
(ex: "__classify__") to a vocabulary.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
sep_token="</s>", pad_token="<pad>", cls_token="</s>",
mask_token="<special1>", additional_special_tokens=["<special0>",
"<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
"<special6>", "<special7>", "<special8>", "<special9>"], **kwargs):
super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
sep_token=sep_token, pad_token=pad_token,
cls_token=cls_token, mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
try:
import ftfy
import spacy
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
self.fix_text = ftfy.fix_text
except ImportError:
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
self.nlp = BasicTokenizer(do_lower_case=True)
self.fix_text = None
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v:k for k,v in self.encoder.items()}
with open(merges_file, encoding='utf-8') as merges_handle:
    merges = merges_handle.read().split('\n')[:-1]
merges = [tuple(merge.split()[:2]) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
@property
def vocab_size(self):
return len(self.encoder)
def bpe(self, token):
word = tuple(token[:-1]) + (token[-1] + '</w>',)
if token in self.cache:
return self.cache[token]
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
if word == '\n </w>':
word = '\n</w>'
self.cache[token] = word
return word
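# BPE walk-through (a sketch): if self.bpe_ranks ranks the pair ('l', 'o</w>')
# best for the token "lo", then:
#
#     bpe("lo")
#     # word starts as ('l', 'o</w>'); the best-ranked pair is merged first,
#     # giving ('lo</w>',), so the method returns 'lo</w>'
#
# Merges keep being applied best-rank-first until no pair is in bpe_ranks;
# results are cached per token in self.cache.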
def _tokenize(self, text):
""" Tokenize a string. """
split_tokens = []
if self.fix_text is None:
# Using BERT's BasicTokenizer
text = self.nlp.tokenize(text)
for token in text:
split_tokens.extend([t for t in self.bpe(token).split(' ')])
else:
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
text = self.nlp(text_standardize(self.fix_text(text)))
for token in text:
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
return split_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
return self.decoder.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = ''.join(tokens).replace('</w>', ' ').strip()
return out_string
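# Example: '</w>' marks end of word in this BPE vocabulary, so
#
#     convert_tokens_to_string(['hel', 'lo</w>', 'world</w>'])
#     # -> 'hello world'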
def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
return vocab_file, merge_file
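# Usage sketch for XLMTokenizer (the shortcut name comes from
# PRETRAINED_VOCAB_FILES_MAP above; everything else is illustrative):
#
#     tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
#     tokens = tokenizer.tokenize("Hello world!")      # lower-cased BPE pieces
#     ids = tokenizer.convert_tokens_to_ids(tokens)
#     text = tokenizer.decode(ids)                     # re-joins on '</w>'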