"vscode:/vscode.git/clone" did not exist on "e7ed7ffdcb66c78d3437ed4c3a63c3640f50f436"
Unverified Commit 3c1b6f59 authored by Julien Chaumond's avatar Julien Chaumond Committed by GitHub
Browse files

Merge branch 'master' into fix_top_k_top_p_filtering

parents a9f24a16 fa735208
...@@ -20,8 +20,14 @@ import unittest ...@@ -20,8 +20,14 @@ import unittest
import shutil import shutil
import pytest import pytest
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification) from transformers import is_torch_available
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
if is_torch_available():
from transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
XLMForSequenceClassification, XLMForQuestionAnsweringSimple)
from transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
...@@ -29,9 +35,9 @@ from .configuration_common_test import ConfigTester ...@@ -29,9 +35,9 @@ from .configuration_common_test import ConfigTester
class XLMModelTest(CommonTestCases.CommonModelTester): class XLMModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (XLMModel, XLMWithLMHeadModel, all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
XLMForQuestionAnswering, XLMForSequenceClassification) XLMForSequenceClassification, XLMForQuestionAnsweringSimple) if is_torch_available() else ()
# , XLMForSequenceClassification, XLMForTokenClassification),
class XLMModelTester(object): class XLMModelTester(object):
...@@ -174,12 +180,36 @@ class XLMModelTest(CommonTestCases.CommonModelTester): ...@@ -174,12 +180,36 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
[self.batch_size, self.seq_length, self.vocab_size]) [self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_xlm_simple_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMForQuestionAnsweringSimple(config)
model.eval()
outputs = model(input_ids)
outputs = model(input_ids, start_positions=sequence_labels,
end_positions=sequence_labels)
loss, start_logits, end_logits = outputs
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
model = XLMForQuestionAnswering(config) model = XLMForQuestionAnswering(config)
model.eval() model.eval()
outputs = model(input_ids) outputs = model(input_ids)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs
outputs = model(input_ids, start_positions=sequence_labels, outputs = model(input_ids, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
...@@ -266,24 +296,25 @@ class XLMModelTest(CommonTestCases.CommonModelTester): ...@@ -266,24 +296,25 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlm_model(*config_and_inputs) self.model_tester.create_and_check_xlm_model(*config_and_inputs)
# config_and_inputs = tester.prepare_config_and_inputs() def test_xlm_lm_head(self):
# tester.create_and_check_xlm_for_masked_lm(*config_and_inputs) config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs)
# config_and_inputs = tester.prepare_config_and_inputs()
# tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
# config_and_inputs = tester.prepare_config_and_inputs() def test_xlm_simple_qa(self):
# tester.create_and_check_xlm_for_question_answering(*config_and_inputs) config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlm_simple_qa(*config_and_inputs)
# config_and_inputs = tester.prepare_config_and_inputs() def test_xlm_qa(self):
# tester.create_and_check_xlm_for_sequence_classification(*config_and_inputs) config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlm_qa(*config_and_inputs)
# config_and_inputs = tester.prepare_config_and_inputs() def test_xlm_sequence_classif(self):
# tester.create_and_check_xlm_for_token_classification(*config_and_inputs) config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlm_sequence_classif(*config_and_inputs)
@pytest.mark.slow @pytest.mark.slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir) model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir) shutil.rmtree(cache_dir)
......
...@@ -23,10 +23,15 @@ import random ...@@ -23,10 +23,15 @@ import random
import shutil import shutil
import pytest import pytest
import torch from transformers import is_torch_available
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering) if is_torch_available():
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP import torch
from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
...@@ -34,7 +39,7 @@ from .configuration_common_test import ConfigTester ...@@ -34,7 +39,7 @@ from .configuration_common_test import ConfigTester
class XLNetModelTest(CommonTestCases.CommonModelTester): class XLNetModelTest(CommonTestCases.CommonModelTester):
all_model_classes=(XLNetModel, XLNetLMHeadModel, all_model_classes=(XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering) XLNetForSequenceClassification, XLNetForQuestionAnswering) if is_torch_available() else ()
test_pruning = False test_pruning = False
class XLNetModelTester(object): class XLNetModelTester(object):
...@@ -145,6 +150,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -145,6 +150,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"outputs": outputs, "outputs": outputs,
} }
config.mem_len = 0
model = XLNetModel(config)
model.eval()
no_mems_outputs = model(input_ids_1)
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["outputs"].size()), list(result["outputs"].size()),
[self.batch_size, self.seq_length, self.hidden_size]) [self.batch_size, self.seq_length, self.hidden_size])
...@@ -312,7 +323,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -312,7 +323,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
@pytest.mark.slow @pytest.mark.slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/" cache_dir = "/tmp/transformers_test/"
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir) model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir) shutil.rmtree(cache_dir)
......
...@@ -18,11 +18,17 @@ from __future__ import print_function ...@@ -18,11 +18,17 @@ from __future__ import print_function
import unittest import unittest
import os import os
import pytest
import torch from transformers import is_torch_available
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, if is_torch_available():
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) import torch
from transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
else:
pytestmark = pytest.mark.skip("Require Torch")
from .tokenization_tests_commons import TemporaryDirectory from .tokenization_tests_commons import TemporaryDirectory
...@@ -71,8 +77,8 @@ class OptimizationTest(unittest.TestCase): ...@@ -71,8 +77,8 @@ class OptimizationTest(unittest.TestCase):
class ScheduleInitTest(unittest.TestCase): class ScheduleInitTest(unittest.TestCase):
m = torch.nn.Linear(50, 50) m = torch.nn.Linear(50, 50) if is_torch_available() else None
optimizer = AdamW(m.parameters(), lr=10.) optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None
num_steps = 10 num_steps = 10
def assertListAlmostEqual(self, list1, list2, tol): def assertListAlmostEqual(self, list1, list2, tol):
......
...@@ -21,21 +21,20 @@ import shutil ...@@ -21,21 +21,20 @@ import shutil
import pytest import pytest
import logging import logging
from pytorch_transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from pytorch_transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
class AutoTokenizerTest(unittest.TestCase): class AutoTokenizerTest(unittest.TestCase):
def test_tokenizer_from_pretrained(self): def test_tokenizer_from_pretrained(self):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
self.assertIsNotNone(tokenizer) self.assertIsNotNone(tokenizer)
self.assertIsInstance(tokenizer, BertTokenizer) self.assertIsInstance(tokenizer, BertTokenizer)
self.assertGreater(len(tokenizer), 0) self.assertGreater(len(tokenizer), 0)
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
self.assertIsNotNone(tokenizer) self.assertIsNotNone(tokenizer)
self.assertIsInstance(tokenizer, GPT2Tokenizer) self.assertIsInstance(tokenizer, GPT2Tokenizer)
......
...@@ -18,7 +18,7 @@ import os ...@@ -18,7 +18,7 @@ import os
import unittest import unittest
from io import open from io import open
from pytorch_transformers.tokenization_bert import (BasicTokenizer, from transformers.tokenization_bert import (BasicTokenizer,
BertTokenizer, BertTokenizer,
WordpieceTokenizer, WordpieceTokenizer,
_is_control, _is_punctuation, _is_control, _is_punctuation,
...@@ -128,11 +128,11 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -128,11 +128,11 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
def test_sequence_builders(self): def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
text = tokenizer.encode("sequence builders") text = tokenizer.encode("sequence builders", add_special_tokens=False)
text_2 = tokenizer.encode("multi-sequence build") text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
assert encoded_sentence == [101] + text + [102] assert encoded_sentence == [101] + text + [102]
assert encoded_pair == [101] + text + [102] + text_2 + [102] assert encoded_pair == [101] + text + [102] + text_2 + [102]
......
# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
import json
from io import open
from transformers.tokenization_ctrl import CTRLTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases
class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = CTRLTokenizer
def setUp(self):
super(CTRLTokenizationTest, self).setUp()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab = ['adapt', 're@@', 'a@@', 'apt', 'c@@', 't', '<unk>']
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", 'a p', 'ap t</w>', 'r e', 'a d', 'ad apt</w>', '']
self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = u"adapt react readapt apt"
output_text = u"adapt react readapt apt"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = CTRLTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "adapt react readapt apt"
bpe_tokens = 'adapt re@@ a@@ c@@ t re@@ adapt apt'.split()
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [0, 1, 2, 4, 5, 1, 0, 3, 6]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__':
unittest.main()
...@@ -18,7 +18,7 @@ import os ...@@ -18,7 +18,7 @@ import os
import unittest import unittest
from io import open from io import open
from pytorch_transformers.tokenization_distilbert import (DistilBertTokenizer) from transformers.tokenization_distilbert import (DistilBertTokenizer)
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
from .tokenization_bert_test import BertTokenizationTest from .tokenization_bert_test import BertTokenizationTest
...@@ -33,14 +33,16 @@ class DistilBertTokenizationTest(BertTokenizationTest): ...@@ -33,14 +33,16 @@ class DistilBertTokenizationTest(BertTokenizationTest):
def test_sequence_builders(self): def test_sequence_builders(self):
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
text = tokenizer.encode("sequence builders") text = tokenizer.encode("sequence builders", add_special_tokens=False)
text_2 = tokenizer.encode("multi-sequence build") text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id]
assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + \
text_2 + [tokenizer.sep_token_id]
assert encoded_sentence == [101] + text + [102]
assert encoded_pair == [101] + text + [102] + text_2 + [102]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
import json import json
from io import open from io import open
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES from transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
...@@ -52,14 +52,14 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -52,14 +52,14 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"lower newer" input_text = u"lower newer"
output_text = u" lower newer" output_text = u"lower newer"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer" text = "lower newer"
bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text) tokens = tokenizer.tokenize(text, add_prefix_space=True)
self.assertListEqual(tokens, bpe_tokens) self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token] input_tokens = tokens + [tokenizer.unk_token]
......
...@@ -18,7 +18,7 @@ import os ...@@ -18,7 +18,7 @@ import os
import unittest import unittest
import json import json
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES from transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
......
...@@ -19,7 +19,7 @@ import json ...@@ -19,7 +19,7 @@ import json
import unittest import unittest
from io import open from io import open
from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
...@@ -51,14 +51,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -51,14 +51,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"lower newer" input_text = u"lower newer"
output_text = u" lower newer" output_text = u"lower newer"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer" text = "lower newer"
bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"] bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text) tokens = tokenizer.tokenize(text, add_prefix_space=True)
self.assertListEqual(tokens, bpe_tokens) self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token] input_tokens = tokens + [tokenizer.unk_token]
...@@ -70,25 +70,25 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -70,25 +70,25 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
self.assertListEqual( self.assertListEqual(
tokenizer.encode('Hello world!'), tokenizer.encode('Hello world!', add_special_tokens=False),
[0, 31414, 232, 328, 2] [0, 31414, 232, 328, 2]
) )
self.assertListEqual( self.assertListEqual(
tokenizer.encode('Hello world! cécé herlolip 418'), tokenizer.encode('Hello world! cécé herlolip 418', add_special_tokens=False),
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
) )
def test_sequence_builders(self): def test_sequence_builders(self):
tokenizer = RobertaTokenizer.from_pretrained("roberta-base") tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
text = tokenizer.encode("sequence builders") text = tokenizer.encode("sequence builders", add_special_tokens=False)
text_2 = tokenizer.encode("multi-sequence build") text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
assert encoded_sentence == encoded_text_from_decode assert encoded_sentence == encoded_text_from_decode
assert encoded_pair == encoded_pair_from_decode assert encoded_pair == encoded_pair_from_decode
......
...@@ -79,13 +79,13 @@ class CommonTestCases: ...@@ -79,13 +79,13 @@ class CommonTestCases:
# Now let's start the test # Now let's start the test
tokenizer = self.get_tokenizer(max_len=42) tokenizer = self.get_tokenizer(max_len=42)
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
with TemporaryDirectory() as tmpdirname: with TemporaryDirectory() as tmpdirname:
tokenizer.save_pretrained(tmpdirname) tokenizer.save_pretrained(tmpdirname)
tokenizer = self.tokenizer_class.from_pretrained(tmpdirname) tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
self.assertListEqual(before_tokens, after_tokens) self.assertListEqual(before_tokens, after_tokens)
self.assertEqual(tokenizer.max_len, 42) self.assertEqual(tokenizer.max_len, 42)
...@@ -130,7 +130,7 @@ class CommonTestCases: ...@@ -130,7 +130,7 @@ class CommonTestCases:
self.assertEqual(added_toks, len(new_toks)) self.assertEqual(added_toks, len(new_toks))
self.assertEqual(all_size_2, all_size + len(new_toks)) self.assertEqual(all_size_2, all_size + len(new_toks))
tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l") tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)
out_string = tokenizer.decode(tokens) out_string = tokenizer.decode(tokens)
self.assertGreaterEqual(len(tokens), 4) self.assertGreaterEqual(len(tokens), 4)
...@@ -148,7 +148,8 @@ class CommonTestCases: ...@@ -148,7 +148,8 @@ class CommonTestCases:
self.assertEqual(added_toks_2, len(new_toks_2)) self.assertEqual(added_toks_2, len(new_toks_2))
self.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l") tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l",
add_special_tokens=False)
out_string = tokenizer.decode(tokens) out_string = tokenizer.decode(tokens)
self.assertGreaterEqual(len(tokens), 6) self.assertGreaterEqual(len(tokens), 6)
...@@ -166,7 +167,7 @@ class CommonTestCases: ...@@ -166,7 +167,7 @@ class CommonTestCases:
tokens = tokenizer.tokenize(input_text) tokens = tokenizer.tokenize(input_text)
ids = tokenizer.convert_tokens_to_ids(tokens) ids = tokenizer.convert_tokens_to_ids(tokens)
ids_2 = tokenizer.encode(input_text) ids_2 = tokenizer.encode(input_text, add_special_tokens=False)
self.assertListEqual(ids, ids_2) self.assertListEqual(ids, ids_2)
tokens_2 = tokenizer.convert_ids_to_tokens(ids) tokens_2 = tokenizer.convert_ids_to_tokens(ids)
...@@ -186,3 +187,131 @@ class CommonTestCases: ...@@ -186,3 +187,131 @@ class CommonTestCases:
for weights_list_2 in weights_lists_2: for weights_list_2 in weights_lists_2:
self.assertListEqual(weights_list, weights_list_2) self.assertListEqual(weights_list, weights_list_2)
def test_mask_output(self):
if sys.version_info <= (3, 0):
return
tokenizer = self.get_tokenizer()
if tokenizer.build_inputs_with_special_tokens.__qualname__.split('.')[0] != "PreTrainedTokenizer":
seq_0 = "Test this method."
seq_1 = "With these inputs."
information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True)
sequences, mask = information["input_ids"], information["token_type_ids"]
self.assertEqual(len(sequences), len(mask))
def test_number_of_added_tokens(self):
tokenizer = self.get_tokenizer()
seq_0 = "Test this method."
seq_1 = "With these inputs."
sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=False)
attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
# Method is implemented (e.g. not GPT-2)
if len(attached_sequences) != 2:
self.assertEqual(tokenizer.num_added_tokens(pair=True), len(attached_sequences) - len(sequences))
def test_maximum_encoding_length_single_input(self):
tokenizer = self.get_tokenizer()
seq_0 = "This is a sentence to be encoded."
stride = 2
sequence = tokenizer.encode(seq_0, add_special_tokens=False)
num_added_tokens = tokenizer.num_added_tokens()
total_length = len(sequence) + num_added_tokens
information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride)
truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"]
self.assertEqual(len(overflowing_tokens), 2 + stride)
self.assertEqual(overflowing_tokens, sequence[-(2 + stride):])
self.assertEqual(len(truncated_sequence), total_length - 2)
self.assertEqual(truncated_sequence, tokenizer.build_inputs_with_special_tokens(sequence[:-2]))
def test_maximum_encoding_length_pair_input(self):
tokenizer = self.get_tokenizer()
seq_0 = "This is a sentence to be encoded."
seq_1 = "This is another sentence to be encoded."
stride = 2
sequence_0_no_special_tokens = tokenizer.encode(seq_0, add_special_tokens=False)
sequence_1_no_special_tokens = tokenizer.encode(seq_1, add_special_tokens=False)
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
truncated_second_sequence = tokenizer.build_inputs_with_special_tokens(
tokenizer.encode(seq_0, add_special_tokens=False),
tokenizer.encode(seq_1, add_special_tokens=False)[:-2]
)
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
stride=stride, truncation_strategy='only_second')
information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
add_special_tokens=True, stride=stride,
truncation_strategy='only_first')
truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"]
overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]
self.assertEqual(len(overflowing_tokens), 2 + stride)
self.assertEqual(overflowing_tokens, sequence_1_no_special_tokens[-(2 + stride):])
self.assertEqual(overflowing_tokens_first_truncated, sequence_0_no_special_tokens[-(2 + stride):])
self.assertEqual(len(truncated_sequence), len(sequence) - 2)
self.assertEqual(truncated_sequence, truncated_second_sequence)
def test_encode_input_type(self):
tokenizer = self.get_tokenizer()
sequence = "Let's encode this sequence"
tokens = tokenizer.tokenize(sequence)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
formatted_input = tokenizer.encode(sequence, add_special_tokens=True)
self.assertEqual(tokenizer.encode(tokens, add_special_tokens=True), formatted_input)
self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input)
def test_special_tokens_mask(self):
tokenizer = self.get_tokenizer()
sequence_0 = "Encode this."
sequence_1 = "This one too please."
# Testing single inputs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
filtered_sequence = [(x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
filtered_sequence = [x for x in filtered_sequence if x is not None]
self.assertEqual(encoded_sequence, filtered_sequence)
# Testing inputs pairs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(sequence_1,
add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
filtered_sequence = [(x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
filtered_sequence = [x for x in filtered_sequence if x is not None]
self.assertEqual(encoded_sequence, filtered_sequence)
# Testing with already existing special tokens
if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
tokenizer.add_special_tokens({'cls_token': '</s>', 'sep_token': '<s>'})
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, already_has_special_tokens=True)
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
self.assertEqual(special_tokens_mask_orig, special_tokens_mask)
...@@ -16,15 +16,22 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -16,15 +16,22 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os import os
import unittest import unittest
import pytest
from io import open from io import open
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES from transformers import is_torch_available
from.tokenization_tests_commons import CommonTestCases if is_torch_available():
import torch
from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
else:
pytestmark = pytest.mark.skip("Require Torch") # TODO: untangle Transfo-XL tokenizer from torch.load and torch.save
from .tokenization_tests_commons import CommonTestCases
class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester): class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = TransfoXLTokenizer tokenizer_class = TransfoXLTokenizer if is_torch_available() else None
def setUp(self): def setUp(self):
super(TransfoXLTokenizationTest, self).setUp() super(TransfoXLTokenizationTest, self).setUp()
......
...@@ -19,8 +19,8 @@ from __future__ import print_function ...@@ -19,8 +19,8 @@ from __future__ import print_function
import unittest import unittest
import six import six
from pytorch_transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer from transformers.tokenization_gpt2 import GPT2Tokenizer
class TokenizerUtilsTest(unittest.TestCase): class TokenizerUtilsTest(unittest.TestCase):
def check_tokenizer_from_pretrained(self, tokenizer_class): def check_tokenizer_from_pretrained(self, tokenizer_class):
......
...@@ -18,7 +18,7 @@ import os ...@@ -18,7 +18,7 @@ import os
import unittest import unittest
import json import json
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
...@@ -69,11 +69,11 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -69,11 +69,11 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
def test_sequence_builders(self): def test_sequence_builders(self):
tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048") tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
text = tokenizer.encode("sequence builders") text = tokenizer.encode("sequence builders", add_special_tokens=False)
text_2 = tokenizer.encode("multi-sequence build") text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
assert encoded_sentence == [1] + text + [1] assert encoded_sentence == [1] + text + [1]
assert encoded_pair == [1] + text + [1] + text_2 + [1] assert encoded_pair == [1] + text + [1] + text_2 + [1]
......
...@@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os import os
import unittest import unittest
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
...@@ -92,11 +92,11 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): ...@@ -92,11 +92,11 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
def test_sequence_builders(self): def test_sequence_builders(self):
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased") tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
text = tokenizer.encode("sequence builders") text = tokenizer.encode("sequence builders", add_special_tokens=False)
text_2 = tokenizer.encode("multi-sequence build") text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
assert encoded_sentence == text + [4, 3] assert encoded_sentence == text + [4, 3]
assert encoded_pair == text + [4] + text_2 + [4, 3] assert encoded_pair == text + [4] + text_2 + [4, 3]
......
...@@ -21,6 +21,7 @@ import logging ...@@ -21,6 +21,7 @@ import logging
from .tokenization_bert import BertTokenizer from .tokenization_bert import BertTokenizer
from .tokenization_openai import OpenAIGPTTokenizer from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_gpt2 import GPT2Tokenizer from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer from .tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_xlnet import XLNetTokenizer from .tokenization_xlnet import XLNetTokenizer
from .tokenization_xlm import XLMTokenizer from .tokenization_xlm import XLMTokenizer
...@@ -30,7 +31,7 @@ from .tokenization_distilbert import DistilBertTokenizer ...@@ -30,7 +31,7 @@ from .tokenization_distilbert import DistilBertTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AutoTokenizer(object): class AutoTokenizer(object):
r""":class:`~pytorch_transformers.AutoTokenizer` is a generic tokenizer class r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
that will be instantiated as one of the tokenizer classes of the library that will be instantiated as one of the tokenizer classes of the library
when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
class method. class method.
...@@ -45,6 +46,7 @@ class AutoTokenizer(object): ...@@ -45,6 +46,7 @@ class AutoTokenizer(object):
- contains `bert`: BertTokenizer (Bert model) - contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model) - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model) - contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model) - contains `xlm`: XLMTokenizer (XLM model)
...@@ -67,6 +69,7 @@ class AutoTokenizer(object): ...@@ -67,6 +69,7 @@ class AutoTokenizer(object):
- contains `bert`: BertTokenizer (Bert model) - contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model) - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model) - contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model) - contains `xlm`: XLMTokenizer (XLM model)
...@@ -75,7 +78,7 @@ class AutoTokenizer(object): ...@@ -75,7 +78,7 @@ class AutoTokenizer(object):
pretrained_model_name_or_path: either: pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``. - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``. - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
- (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``. - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
cache_dir: (`optional`) string: cache_dir: (`optional`) string:
...@@ -90,7 +93,7 @@ class AutoTokenizer(object): ...@@ -90,7 +93,7 @@ class AutoTokenizer(object):
inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method. inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details. kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.
Examples:: Examples::
...@@ -114,7 +117,8 @@ class AutoTokenizer(object): ...@@ -114,7 +117,8 @@ class AutoTokenizer(object):
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'xlm' in pretrained_model_name_or_path: elif 'xlm' in pretrained_model_name_or_path:
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
elif 'ctrl' in pretrained_model_name_or_path:
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of " raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(pretrained_model_name_or_path)) "'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path))
...@@ -44,6 +44,8 @@ PRETRAINED_VOCAB_FILES_MAP = { ...@@ -44,6 +44,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
} }
} }
...@@ -61,6 +63,8 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { ...@@ -61,6 +63,8 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'bert-large-uncased-whole-word-masking-finetuned-squad': 512, 'bert-large-uncased-whole-word-masking-finetuned-squad': 512,
'bert-large-cased-whole-word-masking-finetuned-squad': 512, 'bert-large-cased-whole-word-masking-finetuned-squad': 512,
'bert-base-cased-finetuned-mrpc': 512, 'bert-base-cased-finetuned-mrpc': 512,
'bert-base-german-dbmdz-cased': 512,
'bert-base-german-dbmdz-uncased': 512,
} }
PRETRAINED_INIT_CONFIGURATION = { PRETRAINED_INIT_CONFIGURATION = {
...@@ -77,6 +81,8 @@ PRETRAINED_INIT_CONFIGURATION = { ...@@ -77,6 +81,8 @@ PRETRAINED_INIT_CONFIGURATION = {
'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True}, 'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True},
'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False}, 'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False},
'bert-base-cased-finetuned-mrpc': {'do_lower_case': False}, 'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
'bert-base-german-dbmdz-cased': {'do_lower_case': False},
'bert-base-german-dbmdz-uncased': {'do_lower_case': True},
} }
...@@ -103,7 +109,7 @@ def whitespace_tokenize(text): ...@@ -103,7 +109,7 @@ def whitespace_tokenize(text):
class BertTokenizer(PreTrainedTokenizer): class BertTokenizer(PreTrainedTokenizer):
r""" r"""
Constructs a BertTokenizer. Constructs a BertTokenizer.
:class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece :class:`~transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
Args: Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file vocab_file: Path to a one-wordpiece-per-line vocabulary file
...@@ -187,21 +193,60 @@ class BertTokenizer(PreTrainedTokenizer): ...@@ -187,21 +193,60 @@ class BertTokenizer(PreTrainedTokenizer):
out_string = ' '.join(tokens).replace(' ##', '').strip() out_string = ' '.join(tokens).replace(' ##', '').strip()
return out_string return out_string
def add_special_tokens_single_sentence(self, token_ids): def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
""" """
Adds special tokens to the a sequence for sequence classification tasks. Build model inputs from a sequence or a pair of sequence for sequence classification tasks
A BERT sequence has the following format: [CLS] X [SEP] by concatenating and adding special tokens.
A BERT sequence has the following format:
single sequence: [CLS] X [SEP]
pair of sequences: [CLS] A [SEP] B [SEP]
""" """
return [self.cls_token_id] + token_ids + [self.sep_token_id] if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
Args:
token_ids_0: list of ids (must not contain special tokens)
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
for sequence pairs
already_has_special_tokens: (default False) Set to True if the token list is already formated with
special tokens for the model
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): Returns:
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
""" """
Adds special tokens to a sequence pair for sequence classification tasks.
A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP] if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
A BERT sequence pair mask has the following format:
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
| first sequence | second sequence
if token_ids_1 is None, only returns the first portion of the mask (0's).
""" """
sep = [self.sep_token_id] sep = [self.sep_token_id]
cls = [self.cls_token_id] cls = [self.cls_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, vocab_path): def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a directory or file.""" """Save the tokenizer vocabulary to a directory or file."""
......
# coding=utf-8
# Copyright 2018 Salesforce and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Salesforce CTRL."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import json
import logging
import os
import regex as re
from io import open
from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'ctrl': "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json",
},
'merges_file':
{
'ctrl': "https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'ctrl': 256,
}
CONTROL_CODES = {
"Pregnancy": 168629,
"Christianity": 7675,
"Explain": 106423,
"Fitness": 63440,
"Saving": 63163,
"Ask": 27171,
"Ass": 95985,
"Joke": 163509,
"Questions": 45622,
"Thoughts": 49605,
"Retail": 52342,
"Feminism": 164338,
"Writing": 11992,
"Atheism": 192263,
"Netflix": 48616,
"Computing": 39639,
"Opinion": 43213,
"Alone": 44967,
"Funny": 58917,
"Gaming": 40358,
"Human": 4088,
"India": 1331,
"Joker": 77138,
"Diet": 36206,
"Legal": 11859,
"Norman": 4939,
"Tip": 72689,
"Weight": 52343,
"Movies": 46273,
"Running": 23425,
"Science": 2090,
"Horror": 37793,
"Confession": 60572,
"Finance": 12250,
"Politics": 16360,
"Scary": 191985,
"Support": 12654,
"Technologies": 32516,
"Teenage": 66160,
"Event": 32769,
"Learned": 67460,
"Notion": 182770,
"Wikipedia": 37583,
"Books": 6665,
"Extract": 76050,
"Confessions": 102701,
"Conspiracy": 75932,
"Links": 63674,
"Narcissus": 150425,
"Relationship": 54766,
"Relationships": 134796,
"Reviews": 41671,
"News": 4256,
"Translation": 26820,
"multilingual": 128406,
}
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
pairs = set(pairs)
return pairs
class CTRLTokenizer(PreTrainedTokenizer):
"""
CTRL BPE tokenizer. Peculiarities:
- Byte-Pair-Encoding
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
control_codes = CONTROL_CODES
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs)
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v:k for k,v in self.encoder.items()}
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
@property
def vocab_size(self):
return len(self.encoder)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
word = tuple(list(word[:-1]) + [word[-1]+'</w>'])
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = '@@ '.join(word)
word = word[:-4]
self.cache[token] = word
return word
def _tokenize(self, text):
""" Tokenize a string.
"""
split_tokens = []
text = text.split(' ')
for token in text:
split_tokens.extend([t for t in self.bpe(token).split(' ')])
return split_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
return self.decoder.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = ' '.join(tokens).replace('@@ ', '').strip()
return out_string
def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
return vocab_file, merge_file
# def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
# filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens))
# tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens)
# tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far)
# return ''.join(tokens_generated_so_far)
...@@ -45,7 +45,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { ...@@ -45,7 +45,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
class DistilBertTokenizer(BertTokenizer): class DistilBertTokenizer(BertTokenizer):
r""" r"""
Constructs a DistilBertTokenizer. Constructs a DistilBertTokenizer.
:class:`~pytorch_transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece :class:`~transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece
Args: Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file vocab_file: Path to a one-wordpiece-per-line vocabulary file
......
...@@ -46,12 +46,14 @@ PRETRAINED_VOCAB_FILES_MAP = { ...@@ -46,12 +46,14 @@ PRETRAINED_VOCAB_FILES_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json", 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
}, },
'merges_file': 'merges_file':
{ {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt", 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
}, },
} }
...@@ -59,6 +61,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { ...@@ -59,6 +61,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'gpt2': 1024, 'gpt2': 1024,
'gpt2-medium': 1024, 'gpt2-medium': 1024,
'gpt2-large': 1024, 'gpt2-large': 1024,
'distilgpt2': 1024,
} }
@lru_cache() @lru_cache()
...@@ -101,9 +104,10 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -101,9 +104,10 @@ class GPT2Tokenizer(PreTrainedTokenizer):
""" """
GPT-2 BPE tokenizer. Peculiarities: GPT-2 BPE tokenizer. Peculiarities:
- Byte-level Byte-Pair-Encoding - Byte-level Byte-Pair-Encoding
- Requires a space to start the input string => will add a space is there isn't. - Requires a space to start the input string => the encoding methods should be called with the
As a consequence, this tokenizer `encode` and `decode` method will not conserve ``add_prefix_space`` flag set to ``True``.
the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello" Otherwise, this tokenizer ``encode`` and ``decode`` method will not conserve
the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
...@@ -173,9 +177,15 @@ class GPT2Tokenizer(PreTrainedTokenizer): ...@@ -173,9 +177,15 @@ class GPT2Tokenizer(PreTrainedTokenizer):
self.cache[token] = word self.cache[token] = word
return word return word
def _tokenize(self, text): def _tokenize(self, text, add_prefix_space=False):
""" Tokenize a string. """ """ Tokenize a string.
text = ' ' + text # GPT-2 (and RoBERTa) tokenizers need at least one space to begin the sentence with. Args:
- add_prefix_space (boolean, default False):
Begin the sentence with at least one space toto get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
"""
if add_prefix_space:
text = ' ' + text
bpe_tokens = [] bpe_tokens = []
for token in re.findall(self.pat, text): for token in re.findall(self.pat, text):
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment