Unverified commit d3f24dfa authored by Lysandre Debut, committed by GitHub

Merge branch 'master' into master

parents 4b543c30 ecc4f1bd
......@@ -23,10 +23,15 @@ import random
import shutil
import pytest
import torch
from transformers import is_torch_available
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
if is_torch_available():
import torch
from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
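The guarded import above is the pattern this commit applies to every torch-dependent test module: torch and the model classes are imported only when `is_torch_available()` returns True, otherwise the whole module is skipped through `pytestmark`. A minimal, self-contained sketch of the same pattern (the test class and its body are illustrative; only `is_torch_available`, the skip idiom, and the guarded class attribute come from the diff):

```python
import unittest
import pytest

from transformers import is_torch_available

if is_torch_available():
    import torch
    from transformers import XLNetModel
else:
    # Every test collected from this module is skipped when torch is missing.
    pytestmark = pytest.mark.skip("Require Torch")


class ExampleXLNetTest(unittest.TestCase):
    # Class attributes are evaluated at import time, so they need the same guard;
    # the conditional only touches XLNetModel when torch is actually available.
    all_model_classes = (XLNetModel,) if is_torch_available() else ()

    def test_classes_listed(self):
        self.assertIsInstance(self.all_model_classes, tuple)
```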
......@@ -34,7 +39,7 @@ from .configuration_common_test import ConfigTester
class XLNetModelTest(CommonTestCases.CommonModelTester):
all_model_classes=(XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering)
XLNetForSequenceClassification, XLNetForQuestionAnswering) if is_torch_available() else ()
test_pruning = False
class XLNetModelTester(object):
......@@ -312,7 +317,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
cache_dir = "/tmp/transformers_test/"
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
......
......@@ -18,11 +18,17 @@ from __future__ import print_function
import unittest
import os
import pytest
import torch
from transformers import is_torch_available
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
if is_torch_available():
import torch
from transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
else:
pytestmark = pytest.mark.skip("Require Torch")
from .tokenization_tests_commons import TemporaryDirectory
......@@ -71,8 +77,8 @@ class OptimizationTest(unittest.TestCase):
class ScheduleInitTest(unittest.TestCase):
m = torch.nn.Linear(50, 50)
optimizer = AdamW(m.parameters(), lr=10.)
m = torch.nn.Linear(50, 50) if is_torch_available() else None
optimizer = AdamW(m.parameters(), lr=10.) if is_torch_available() else None
num_steps = 10
def assertListAlmostEqual(self, list1, list2, tol):
......
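`ScheduleInitTest` above builds an `AdamW` optimizer over a small linear layer only when torch is present. A short sketch of how the imported schedules are typically paired with the optimizer in a training loop (the layer size, learning rate, and step counts are illustrative assumptions):

```python
import torch
from transformers import AdamW, WarmupLinearSchedule

model = torch.nn.Linear(50, 50)
optimizer = AdamW(model.parameters(), lr=1e-3)
# Linear warmup for the first 10 steps, then linear decay to 0 over 100 total steps.
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=10, t_total=100)

for step in range(100):
    loss = model(torch.randn(8, 50)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()       # advance the learning-rate schedule once per optimizer step
    optimizer.zero_grad()
```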
......@@ -21,21 +21,20 @@ import shutil
import pytest
import logging
from pytorch_transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer
from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
class AutoTokenizerTest(unittest.TestCase):
def test_tokenizer_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.assertIsNotNone(tokenizer)
self.assertIsInstance(tokenizer, BertTokenizer)
self.assertGreater(len(tokenizer), 0)
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.assertIsNotNone(tokenizer)
self.assertIsInstance(tokenizer, GPT2Tokenizer)
......
......@@ -18,7 +18,7 @@ import os
import unittest
from io import open
from pytorch_transformers.tokenization_bert import (BasicTokenizer,
from transformers.tokenization_bert import (BasicTokenizer,
BertTokenizer,
WordpieceTokenizer,
_is_control, _is_punctuation,
......@@ -131,8 +131,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
assert encoded_sentence == [101] + text + [102]
assert encoded_pair == [101] + text + [102] + text_2 + [102]
......
......@@ -18,7 +18,7 @@ import os
import unittest
from io import open
from pytorch_transformers.tokenization_distilbert import (DistilBertTokenizer)
from transformers.tokenization_distilbert import (DistilBertTokenizer)
from .tokenization_tests_commons import CommonTestCases
from .tokenization_bert_test import BertTokenizationTest
......@@ -36,11 +36,13 @@ class DistilBertTokenizationTest(BertTokenizationTest):
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
assert encoded_sentence == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id]
assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + \
text_2 + [tokenizer.sep_token_id]
assert encoded_sentence == [101] + text + [102]
assert encoded_pair == [101] + text + [102] + text_2 + [102]
if __name__ == '__main__':
unittest.main()
......@@ -19,7 +19,7 @@ import unittest
import json
from io import open
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
from transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases
......@@ -52,14 +52,14 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
def get_input_output_texts(self):
input_text = u"lower newer"
output_text = u" lower newer"
output_text = u"lower newer"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer"
bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text)
tokens = tokenizer.tokenize(text, add_prefix_space=True)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token]
......
......@@ -18,7 +18,7 @@ import os
import unittest
import json
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
from transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases
......
......@@ -19,7 +19,7 @@ import json
import unittest
from io import open
from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases
......@@ -51,14 +51,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
def get_input_output_texts(self):
input_text = u"lower newer"
output_text = u" lower newer"
output_text = u"lower newer"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer"
bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text)
tokens = tokenizer.tokenize(text, add_prefix_space=True)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token]
......@@ -87,8 +87,8 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
assert encoded_sentence == encoded_text_from_decode
assert encoded_pair == encoded_pair_from_decode
......
......@@ -186,3 +186,92 @@ class CommonTestCases:
for weights_list_2 in weights_lists_2:
self.assertListEqual(weights_list, weights_list_2)
def test_mask_output(self):
if sys.version_info <= (3, 0):
return
tokenizer = self.get_tokenizer()
if tokenizer.add_special_tokens_sequence_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
seq_0 = "Test this method."
seq_1 = "With these inputs."
information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True)
sequences, mask = information["input_ids"], information["token_type_ids"]
assert len(sequences) == len(mask)
def test_number_of_added_tokens(self):
tokenizer = self.get_tokenizer()
seq_0 = "Test this method."
seq_1 = "With these inputs."
sequences = tokenizer.encode(seq_0, seq_1)
attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
# Method is implemented (e.g. not GPT-2)
if len(attached_sequences) != 2:
assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - len(sequences)
def test_maximum_encoding_length_single_input(self):
tokenizer = self.get_tokenizer()
seq_0 = "This is a sentence to be encoded."
stride = 2
sequence = tokenizer.encode(seq_0)
num_added_tokens = tokenizer.num_added_tokens()
total_length = len(sequence) + num_added_tokens
information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride)
truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"]
assert len(overflowing_tokens) == 2 + stride
assert overflowing_tokens == sequence[-(2 + stride):]
assert len(truncated_sequence) == total_length - 2
assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
def test_maximum_encoding_length_pair_input(self):
tokenizer = self.get_tokenizer()
seq_0 = "This is a sentence to be encoded."
seq_1 = "This is another sentence to be encoded."
stride = 2
sequence_0_no_special_tokens = tokenizer.encode(seq_0)
sequence_1_no_special_tokens = tokenizer.encode(seq_1)
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair(
tokenizer.encode(seq_0),
tokenizer.encode(seq_1)[:-2]
)
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
stride=stride, truncate_first_sequence=False)
information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
add_special_tokens=True, stride=stride,
truncate_first_sequence=True)
truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"]
overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]
assert len(overflowing_tokens) == 2 + stride
assert overflowing_tokens == sequence_1_no_special_tokens[-(2 + stride):]
assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):]
assert len(truncated_sequence) == len(sequence) - 2
assert truncated_sequence == truncated_second_sequence
def test_encode_input_type(self):
tokenizer = self.get_tokenizer()
sequence = "Let's encode this sequence"
tokens = tokenizer.tokenize(sequence)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
formatted_input = tokenizer.encode(sequence, add_special_tokens=True)
assert tokenizer.encode(tokens, add_special_tokens=True) == formatted_input
assert tokenizer.encode(input_ids, add_special_tokens=True) == formatted_input
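These new common tests pin down the `encode_plus` contract: the returned dict carries `input_ids`, `token_type_ids`, and, when `max_length` forces truncation, `overflowing_tokens` that include a `stride`-sized overlap with the kept ids. A minimal sketch of calling it directly, assuming a pretrained BERT tokenizer and illustrative sentences:

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

info = tokenizer.encode_plus(
    "Test this method.", "With these inputs.",
    add_special_tokens=True,   # wrap with [CLS]/[SEP] and build the segment mask
    max_length=8,              # cap the total number of ids, forcing truncation here
    stride=2,                  # overflowing tokens keep a 2-token overlap with the kept part
)

input_ids = info["input_ids"]                      # truncated ids with special tokens added
token_type_ids = info["token_type_ids"]            # 0s for the first sequence, 1s for the second
overflowing = info.get("overflowing_tokens", [])   # ids cut off by max_length, plus the stride overlap

assert len(input_ids) <= 8
assert len(input_ids) == len(token_type_ids)
```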
......@@ -16,15 +16,22 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
import pytest
from io import open
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
from transformers import is_torch_available
from .tokenization_tests_commons import CommonTestCases
if is_torch_available():
import torch
from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
else:
pytestmark = pytest.mark.skip("Require Torch") # TODO: untangle Transfo-XL tokenizer from torch.load and torch.save
from .tokenization_tests_commons import CommonTestCases
class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = TransfoXLTokenizer
tokenizer_class = TransfoXLTokenizer if is_torch_available() else None
def setUp(self):
super(TransfoXLTokenizationTest, self).setUp()
......
......@@ -19,8 +19,8 @@ from __future__ import print_function
import unittest
import six
from pytorch_transformers import PreTrainedTokenizer
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
from transformers import PreTrainedTokenizer
from transformers.tokenization_gpt2 import GPT2Tokenizer
class TokenizerUtilsTest(unittest.TestCase):
def check_tokenizer_from_pretrained(self, tokenizer_class):
......
......@@ -18,7 +18,7 @@ import os
import unittest
import json
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases
......@@ -72,8 +72,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
assert encoded_sentence == [1] + text + [1]
assert encoded_pair == [1] + text + [1] + text_2 + [1]
......
......@@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os
import unittest
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
from .tokenization_tests_commons import CommonTestCases
......@@ -95,8 +95,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
assert encoded_sentence == text + [4, 3]
assert encoded_pair == text + [4] + text_2 + [4, 3]
......
......@@ -30,7 +30,7 @@ from .tokenization_distilbert import DistilBertTokenizer
logger = logging.getLogger(__name__)
class AutoTokenizer(object):
r""":class:`~pytorch_transformers.AutoTokenizer` is a generic tokenizer class
r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
that will be instantiated as one of the tokenizer classes of the library
when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
class method.
......@@ -75,7 +75,7 @@ class AutoTokenizer(object):
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
- a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
- a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
- (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
cache_dir: (`optional`) string:
......@@ -90,7 +90,7 @@ class AutoTokenizer(object):
inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.
Examples::
......
......@@ -103,7 +103,7 @@ def whitespace_tokenize(text):
class BertTokenizer(PreTrainedTokenizer):
r"""
Constructs a BertTokenizer.
:class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
:class:`~transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
......@@ -187,22 +187,35 @@ class BertTokenizer(PreTrainedTokenizer):
out_string = ' '.join(tokens).replace(' ##', '').strip()
return out_string
def add_special_tokens_single_sentence(self, token_ids):
def add_special_tokens_single_sequence(self, token_ids):
"""
Adds special tokens to a sequence for sequence classification tasks.
A BERT sequence has the following format: [CLS] X [SEP]
"""
return [self.cls_token_id] + token_ids + [self.sep_token_id]
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
"""
Adds special tokens to a sequence pair for sequence classification tasks.
A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
A BERT sequence pair mask has the following format:
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
| first sequence | second sequence
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0
......
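Together, the renamed helpers and the new `create_token_type_ids_from_sequences` define how a BERT input pair is assembled. A small worked sketch of what they produce (the token ids are illustrative placeholders; only the [CLS]/[SEP] placement and the 0/1 segment pattern come from the docstrings above):

```python
# Placeholder ids standing in for tokenizer.encode(...) output; 101/102 are the
# usual bert-base-uncased [CLS]/[SEP] ids used elsewhere in these tests.
cls_id, sep_id = 101, 102
ids_a = [7592, 2088]        # first sequence
ids_b = [2129, 2024, 2017]  # second sequence

# add_special_tokens_single_sequence: [CLS] A [SEP]
single = [cls_id] + ids_a + [sep_id]

# add_special_tokens_sequence_pair: [CLS] A [SEP] B [SEP]
pair = [cls_id] + ids_a + [sep_id] + ids_b + [sep_id]

# create_token_type_ids_from_sequences: 0s over "[CLS] A [SEP]", 1s over "B [SEP]"
token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)

assert len(pair) == len(token_type_ids)  # 8 ids, mask 0 0 0 0 1 1 1 1
```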
......@@ -45,7 +45,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
class DistilBertTokenizer(BertTokenizer):
r"""
Constructs a DistilBertTokenizer.
:class:`~pytorch_transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece
:class:`~transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
......
......@@ -46,12 +46,14 @@ PRETRAINED_VOCAB_FILES_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-vocab.json",
},
'merges_file':
{
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
'distilgpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-merges.txt",
},
}
......@@ -59,6 +61,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'gpt2': 1024,
'gpt2-medium': 1024,
'gpt2-large': 1024,
'distilgpt2': 1024,
}
@lru_cache()
......@@ -101,9 +104,10 @@ class GPT2Tokenizer(PreTrainedTokenizer):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level Byte-Pair-Encoding
- Requires a space to start the input string => will add a space if there isn't one.
As a consequence, this tokenizer's `encode` and `decode` methods will not conserve
the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
- Requires a space to start the input string => the encoding methods should be called with the
``add_prefix_space`` flag set to ``True``.
Otherwise, this tokenizer's ``encode`` and ``decode`` methods will not conserve
the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
......@@ -173,9 +177,15 @@ class GPT2Tokenizer(PreTrainedTokenizer):
self.cache[token] = word
return word
def _tokenize(self, text):
""" Tokenize a string. """
text = ' ' + text # GPT-2 (and RoBERTa) tokenizers need at least one space to begin the sentence with.
def _tokenize(self, text, add_prefix_space=False):
""" Tokenize a string.
Args:
- add_prefix_space (boolean, default False):
Begin the sentence with at least one space so that the first word is tokenized the same way as it would be mid-sentence; GPT-2 (and RoBERTa) tokenizers are sensitive to a leading space.
"""
if add_prefix_space:
text = ' ' + text
bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
......
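With the unconditional leading space removed from `_tokenize`, callers opt in through `add_prefix_space`, as the updated GPT-2 and RoBERTa tests above do. A short sketch of the difference, assuming a pretrained GPT-2 tokenizer (the example string is illustrative):

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Without the flag, the leading word has no space-prefixed byte, so it maps to a
# different BPE token than the same word would mid-sentence.
tokens_plain = tokenizer.tokenize("lower newer")

# With the flag, one space is prepended before BPE, reproducing the behaviour the
# tokenizer previously applied unconditionally; decoding then yields " lower newer".
tokens_prefixed = tokenizer.tokenize("lower newer", add_prefix_space=True)

assert tokens_plain != tokens_prefixed
assert tokens_prefixed[0].startswith("\u0120")  # "\u0120" marks a leading space in byte-level BPE
```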
......@@ -66,9 +66,10 @@ class RobertaTokenizer(GPT2Tokenizer):
"""
RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. Peculiarities:
- Byte-level Byte-Pair-Encoding
- Requires a space to start the input string => will add a space if there isn't one.
As a consequence, this tokenizer's `encode` and `decode` methods will not conserve
the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
- Requires a space to start the input string => the encoding methods should be called with the
``add_prefix_space`` flag set to ``True``.
Otherwise, this tokenizer's ``encode`` and ``decode`` methods will not conserve
the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
......@@ -80,15 +81,17 @@ class RobertaTokenizer(GPT2Tokenizer):
bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
mask_token=mask_token, **kwargs)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
def add_special_tokens_single_sentence(self, token_ids):
def add_special_tokens_single_sequence(self, token_ids):
"""
Adds special tokens to a sequence for sequence classification tasks.
A RoBERTa sequence has the following format: <s> X </s>
"""
return [self.cls_token_id] + token_ids + [self.sep_token_id]
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
"""
Adds special tokens to a sequence pair for sequence classification tasks.
A RoBERTa sequence pair has the following format: <s> A </s></s> B </s>
......@@ -96,3 +99,15 @@ class RobertaTokenizer(GPT2Tokenizer):
sep = [self.sep_token_id]
cls = [self.cls_token_id]
return cls + token_ids_0 + sep + sep + token_ids_1 + sep
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
A RoBERTa sequence pair mask has the following format:
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
| first sequence | second sequence
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
return len(cls + token_ids_0 + sep + sep) * [0] + len(token_ids_1 + sep) * [1]
\ No newline at end of file
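The RoBERTa counterparts mirror the BERT helpers but place a doubled separator between the two sequences, and that doubled separator is counted in the first segment of the type-id mask. A compact sketch (placeholder ids; the `<s>`/`</s>` ids are an assumption about the pretrained vocabulary):

```python
# Assumed special-token ids for the pretrained RoBERTa vocabulary: <s> = 0, </s> = 2.
cls_id, sep_id = 0, 2
ids_a = [31414, 232]   # first sequence (placeholder ids)
ids_b = [9226, 16]     # second sequence (placeholder ids)

single = [cls_id] + ids_a + [sep_id]                               # <s> A </s>
pair = [cls_id] + ids_a + [sep_id, sep_id] + ids_b + [sep_id]      # <s> A </s></s> B </s>

# 0s cover "<s> A </s></s>", 1s cover "B </s>", mirroring create_token_type_ids_from_sequences.
token_type_ids = [0] * (len(ids_a) + 3) + [1] * (len(ids_b) + 1)
assert len(pair) == len(token_type_ids)
```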
......@@ -26,16 +26,20 @@ import sys
from collections import Counter, OrderedDict
from io import open
import torch
import numpy as np
from .file_utils import cached_path
from .tokenization_utils import PreTrainedTokenizer
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
try:
import torch
except ImportError:
pass
# if sys.version_info[0] == 2:
# import cPickle as pickle
# else:
# import pickle
logger = logging.getLogger(__name__)
......