Commit d7395789 authored by danai-antoniou


Merge branch 'master' of https://github.com/danai-antoniou/pytorch-transformers into add-duplicate-tokens-error
parents 2e6797cc 391db836
@@ -19,7 +19,7 @@ import unittest
 import json
 from io import open

-from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
+from transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES

 from .tokenization_tests_commons import CommonTestCases
@@ -52,14 +52,14 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
     def get_input_output_texts(self):
         input_text = u"lower newer"
-        output_text = u" lower newer"
+        output_text = u"lower newer"
         return input_text, output_text

     def test_full_tokenizer(self):
         tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
         text = "lower newer"
         bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
-        tokens = tokenizer.tokenize(text)
+        tokens = tokenizer.tokenize(text, add_prefix_space=True)
         self.assertListEqual(tokens, bpe_tokens)

         input_tokens = tokens + [tokenizer.unk_token]
...
@@ -18,7 +18,7 @@ import os
 import unittest
 import json

-from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
+from transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES

 from .tokenization_tests_commons import CommonTestCases
...
@@ -19,7 +19,7 @@ import json
 import unittest
 from io import open

-from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
+from transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES

 from .tokenization_tests_commons import CommonTestCases
@@ -51,14 +51,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
     def get_input_output_texts(self):
         input_text = u"lower newer"
-        output_text = u" lower newer"
+        output_text = u"lower newer"
         return input_text, output_text

     def test_full_tokenizer(self):
         tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
         text = "lower newer"
         bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
-        tokens = tokenizer.tokenize(text)
+        tokens = tokenizer.tokenize(text, add_prefix_space=True)
         self.assertListEqual(tokens, bpe_tokens)

         input_tokens = tokens + [tokenizer.unk_token]
@@ -87,8 +87,8 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
         encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
         encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)

-        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
-        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
+        encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
+        encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)

         assert encoded_sentence == encoded_text_from_decode
         assert encoded_pair == encoded_pair_from_decode
...
@@ -186,3 +186,92 @@ class CommonTestCases:
             for weights_list_2 in weights_lists_2:
                 self.assertListEqual(weights_list, weights_list_2)
+
+        def test_mask_output(self):
+            if sys.version_info <= (3, 0):
+                return
+
+            tokenizer = self.get_tokenizer()
+
+            if tokenizer.add_special_tokens_sequence_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
+                seq_0 = "Test this method."
+                seq_1 = "With these inputs."
+                information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True)
+                sequences, mask = information["input_ids"], information["token_type_ids"]
+                assert len(sequences) == len(mask)
+
+        def test_number_of_added_tokens(self):
+            tokenizer = self.get_tokenizer()
+
+            seq_0 = "Test this method."
+            seq_1 = "With these inputs."
+
+            sequences = tokenizer.encode(seq_0, seq_1)
+            attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
+
+            # Method is implemented (e.g. not GPT-2)
+            if len(attached_sequences) != 2:
+                assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - len(sequences)
+
+        def test_maximum_encoding_length_single_input(self):
+            tokenizer = self.get_tokenizer()
+
+            seq_0 = "This is a sentence to be encoded."
+            stride = 2
+
+            sequence = tokenizer.encode(seq_0)
+            num_added_tokens = tokenizer.num_added_tokens()
+            total_length = len(sequence) + num_added_tokens
+            information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride)
+
+            truncated_sequence = information["input_ids"]
+            overflowing_tokens = information["overflowing_tokens"]
+
+            assert len(overflowing_tokens) == 2 + stride
+            assert overflowing_tokens == sequence[-(2 + stride):]
+            assert len(truncated_sequence) == total_length - 2
+            assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
+
+        def test_maximum_encoding_length_pair_input(self):
+            tokenizer = self.get_tokenizer()
+
+            seq_0 = "This is a sentence to be encoded."
+            seq_1 = "This is another sentence to be encoded."
+            stride = 2
+
+            sequence_0_no_special_tokens = tokenizer.encode(seq_0)
+            sequence_1_no_special_tokens = tokenizer.encode(seq_1)
+
+            sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
+            truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair(
+                tokenizer.encode(seq_0),
+                tokenizer.encode(seq_1)[:-2]
+            )
+
+            information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
+                                                stride=stride, truncate_first_sequence=False)
+            information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
+                                                                add_special_tokens=True, stride=stride,
+                                                                truncate_first_sequence=True)
+
+            truncated_sequence = information["input_ids"]
+            overflowing_tokens = information["overflowing_tokens"]
+            overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]
+
+            assert len(overflowing_tokens) == 2 + stride
+            assert overflowing_tokens == sequence_1_no_special_tokens[-(2 + stride):]
+            assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):]
+
+            assert len(truncated_sequence) == len(sequence) - 2
+            assert truncated_sequence == truncated_second_sequence
+
+        def test_encode_input_type(self):
+            tokenizer = self.get_tokenizer()
+
+            sequence = "Let's encode this sequence"
+
+            tokens = tokenizer.tokenize(sequence)
+            input_ids = tokenizer.convert_tokens_to_ids(tokens)
+            formatted_input = tokenizer.encode(sequence, add_special_tokens=True)
+
+            assert tokenizer.encode(tokens, add_special_tokens=True) == formatted_input
+            assert tokenizer.encode(input_ids, add_special_tokens=True) == formatted_input
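The common tests above exercise the new `num_added_tokens`, `encode_plus` and the `max_length`/`stride` truncation path. A minimal sketch of the behaviour they assert, assuming a `bert-base-uncased` vocabulary can be loaded from cache or downloaded (the concrete lengths are illustrative only):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    sequence = tokenizer.encode("This is a sentence to be encoded.")
    total_length = len(sequence) + tokenizer.num_added_tokens()  # room for [CLS] and [SEP]

    # Request two tokens less than the full length: the tail overflows, and the
    # stride keeps two extra tokens of context from the truncated sequence.
    information = tokenizer.encode_plus("This is a sentence to be encoded.",
                                        max_length=total_length - 2,
                                        add_special_tokens=True,
                                        stride=2)
    assert len(information["input_ids"]) == total_length - 2
    assert len(information["overflowing_tokens"]) == 2 + 2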
@@ -16,15 +16,22 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
+import pytest
 from io import open

-from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
-from.tokenization_tests_commons import CommonTestCases
+from transformers import is_torch_available
+
+if is_torch_available():
+    import torch
+    from transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
+else:
+    pytestmark = pytest.mark.skip("Require Torch")  # TODO: untangle Transfo-XL tokenizer from torch.load and torch.save
+
+from .tokenization_tests_commons import CommonTestCases

 class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):

-    tokenizer_class = TransfoXLTokenizer
+    tokenizer_class = TransfoXLTokenizer if is_torch_available() else None

     def setUp(self):
         super(TransfoXLTokenizationTest, self).setUp()
...
@@ -19,8 +19,8 @@ from __future__ import print_function
 import unittest
 import six

-from pytorch_transformers import PreTrainedTokenizer
-from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
+from transformers import PreTrainedTokenizer
+from transformers.tokenization_gpt2 import GPT2Tokenizer

 class TokenizerUtilsTest(unittest.TestCase):
     def check_tokenizer_from_pretrained(self, tokenizer_class):
...
@@ -18,7 +18,7 @@ import os
 import unittest
 import json

-from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
+from transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES

 from .tokenization_tests_commons import CommonTestCases
@@ -72,8 +72,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
         text = tokenizer.encode("sequence builders")
         text_2 = tokenizer.encode("multi-sequence build")

-        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
-        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
+        encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
+        encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)

         assert encoded_sentence == [1] + text + [1]
         assert encoded_pair == [1] + text + [1] + text_2 + [1]
...
@@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest

-from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
+from transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)

 from .tokenization_tests_commons import CommonTestCases
@@ -95,8 +95,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
         text = tokenizer.encode("sequence builders")
         text_2 = tokenizer.encode("multi-sequence build")

-        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
-        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
+        encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
+        encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)

         assert encoded_sentence == text + [4, 3]
         assert encoded_pair == text + [4] + text_2 + [4, 3]
...
@@ -30,7 +30,7 @@ from .tokenization_distilbert import DistilBertTokenizer
 logger = logging.getLogger(__name__)

 class AutoTokenizer(object):
-    r""":class:`~pytorch_transformers.AutoTokenizer` is a generic tokenizer class
+    r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
        that will be instantiated as one of the tokenizer classes of the library
        when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
        class method.
@@ -75,7 +75,7 @@ class AutoTokenizer(object):
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
-               - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
+               - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.

            cache_dir: (`optional`) string:
@@ -90,7 +90,7 @@ class AutoTokenizer(object):
            inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.

-           kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
+           kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.

        Examples::
...
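The docstring above describes the dispatch done by `AutoTokenizer.from_pretrained`; in practice it is a one-liner, assuming the ``bert-base-uncased`` shortcut can be resolved from the cache or downloaded:

    from transformers import AutoTokenizer

    # The shortcut name is matched to the right tokenizer class (BertTokenizer here)
    # and the associated vocabulary files are fetched or read from the cache.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    tokens = tokenizer.tokenize("Hello world")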
@@ -103,7 +103,7 @@ def whitespace_tokenize(text):
 class BertTokenizer(PreTrainedTokenizer):
     r"""
     Constructs a BertTokenizer.
-    :class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
+    :class:`~transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece

     Args:
         vocab_file: Path to a one-wordpiece-per-line vocabulary file
@@ -187,22 +187,35 @@ class BertTokenizer(PreTrainedTokenizer):
         out_string = ' '.join(tokens).replace(' ##', '').strip()
         return out_string

-    def add_special_tokens_single_sentence(self, token_ids):
+    def add_special_tokens_single_sequence(self, token_ids):
         """
         Adds special tokens to the a sequence for sequence classification tasks.
         A BERT sequence has the following format: [CLS] X [SEP]
         """
         return [self.cls_token_id] + token_ids + [self.sep_token_id]

-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
+    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
         """
         Adds special tokens to a sequence pair for sequence classification tasks.
         A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
         return cls + token_ids_0 + sep + token_ids_1 + sep

+    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
+        """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
+        A BERT sequence pair mask has the following format:
+        0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
+        | first sequence | second sequence
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
     def save_vocabulary(self, vocab_path):
         """Save the tokenizer vocabulary to a directory or file."""
         index = 0
...
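The new `create_token_type_ids_from_sequences` mirrors `add_special_tokens_sequence_pair`: both follow the [CLS] A [SEP] B [SEP] layout, so the mask always lines up with the ids. A stand-alone sketch of that arithmetic with placeholder ids (101/102 stand in for the usual [CLS]/[SEP] ids; the word ids are arbitrary):

    cls_id, sep_id = 101, 102
    token_ids_0 = [7592, 2088]        # first sequence, e.g. "hello world"
    token_ids_1 = [2129, 2024, 2017]  # second sequence, e.g. "how are you"

    # Same construction as the two BERT methods above.
    input_ids = [cls_id] + token_ids_0 + [sep_id] + token_ids_1 + [sep_id]
    token_type_ids = len([cls_id] + token_ids_0 + [sep_id]) * [0] + len(token_ids_1 + [sep_id]) * [1]

    assert token_type_ids == [0, 0, 0, 0, 1, 1, 1, 1]
    assert len(input_ids) == len(token_type_ids)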
@@ -45,7 +45,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
 class DistilBertTokenizer(BertTokenizer):
     r"""
     Constructs a DistilBertTokenizer.
-    :class:`~pytorch_transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece
+    :class:`~transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece

     Args:
         vocab_file: Path to a one-wordpiece-per-line vocabulary file
...
@@ -101,9 +101,10 @@ class GPT2Tokenizer(PreTrainedTokenizer):
     """
     GPT-2 BPE tokenizer. Peculiarities:
         - Byte-level Byte-Pair-Encoding
-        - Requires a space to start the input string => will add a space is there isn't.
-          As a consequence, this tokenizer `encode` and `decode` method will not conserve
-          the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"
+        - Requires a space to start the input string => the encoding methods should be called with the
+          ``add_prefix_space`` flag set to ``True``.
+          Otherwise, this tokenizer ``encode`` and ``decode`` method will not conserve
+          the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
     """
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
@@ -173,9 +174,15 @@ class GPT2Tokenizer(PreTrainedTokenizer):
         self.cache[token] = word
         return word

-    def _tokenize(self, text):
-        """ Tokenize a string. """
-        text = ' ' + text  # GPT-2 (and RoBERTa) tokenizers need at least one space to begin the sentence with.
+    def _tokenize(self, text, add_prefix_space=False):
+        """ Tokenize a string.
+            Args:
+                - add_prefix_space (boolean, default False):
+                    Begin the sentence with at least one space toto get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
+        """
+        if add_prefix_space:
+            text = ' ' + text
         bpe_tokens = []
         for token in re.findall(self.pat, text):
             if sys.version_info[0] == 2:
...
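With this change the caller decides whether the leading space is added instead of the tokenizer always prepending one. A short sketch of the intended call pattern, assuming the pretrained ``gpt2`` vocabulary can be loaded:

    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Explicitly request the prefix space that the tokenizer used to add implicitly.
    tokens = tokenizer.tokenize("lower newer", add_prefix_space=True)

    # Without the flag the text is tokenized exactly as given, so
    # tokenizer.decode(tokenizer.encode("Hello")) no longer grows a leading space.
    tokens_no_prefix = tokenizer.tokenize("lower newer")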
@@ -66,9 +66,10 @@ class RobertaTokenizer(GPT2Tokenizer):
     """
     RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. Peculiarities:
         - Byte-level Byte-Pair-Encoding
-        - Requires a space to start the input string => will add a space is there isn't.
-          As a consequence, this tokenizer `encode` and `decode` method will not conserve
-          the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"
+        - Requires a space to start the input string => the encoding methods should be called with the
+          ``add_prefix_space`` flag set to ``True``.
+          Otherwise, this tokenizer ``encode`` and ``decode`` method will not conserve
+          the absence of a space at the beginning of a string: `tokenizer.decode(tokenizer.encode("Hello")) = " Hello"`
     """
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
@@ -80,15 +81,17 @@ class RobertaTokenizer(GPT2Tokenizer):
                 bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
                 sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
                 mask_token=mask_token, **kwargs)
+        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
+        self.max_len_sentences_pair = self.max_len - 4  # take into account special tokens

-    def add_special_tokens_single_sentence(self, token_ids):
+    def add_special_tokens_single_sequence(self, token_ids):
         """
         Adds special tokens to a sequence for sequence classification tasks.
         A RoBERTa sequence has the following format: <s> X </s>
         """
         return [self.cls_token_id] + token_ids + [self.sep_token_id]

-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
+    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
         """
         Adds special tokens to a sequence pair for sequence classification tasks.
         A RoBERTa sequence pair has the following format: <s> A </s></s> B </s>
@@ -96,3 +99,15 @@ class RobertaTokenizer(GPT2Tokenizer):
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
         return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
+        """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
+        A RoBERTa sequence pair mask has the following format:
+        0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
+        | first sequence | second sequence
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+        return len(cls + token_ids_0 + sep + sep) * [0] + len(token_ids_1 + sep) * [1]
\ No newline at end of file
@@ -26,16 +26,20 @@ import sys
 from collections import Counter, OrderedDict
 from io import open

-import torch
 import numpy as np

 from .file_utils import cached_path
 from .tokenization_utils import PreTrainedTokenizer

-if sys.version_info[0] == 2:
-    import cPickle as pickle
-else:
-    import pickle
+try:
+    import torch
+except ImportError:
+    pass
+
+# if sys.version_info[0] == 2:
+# import cPickle as pickle
+# else:
+# import pickle

 logger = logging.getLogger(__name__)
...
@@ -23,7 +23,12 @@ import six
 import copy
 from io import open

-from .file_utils import cached_path
+from .file_utils import cached_path, is_tf_available, is_torch_available
+
+if is_tf_available():
+    import tensorflow as tf
+if is_torch_available():
+    import torch

 logger = logging.getLogger(__name__)
@@ -231,13 +236,13 @@ class PreTrainedTokenizer(object):
     @classmethod
     def from_pretrained(cls, *inputs, **kwargs):
         r"""
-        Instantiate a :class:`~pytorch_transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.
+        Instantiate a :class:`~transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.

         Args:
             pretrained_model_name_or_path: either:

                 - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
-                - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
+                - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                 - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.

             cache_dir: (`optional`) string:
@@ -252,7 +257,7 @@ class PreTrainedTokenizer(object):
             inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.

-            kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
+            kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.

         Examples::
@@ -425,9 +430,9 @@ class PreTrainedTokenizer(object):
            - tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert).

            This won't save modifications other than (added tokens and special token mapping) you may have
-           applied to the tokenizer after the instantion (e.g. modifying tokenizer.do_lower_case after creation).
+           applied to the tokenizer after the instantiation (e.g. modifying tokenizer.do_lower_case after creation).

-           This method make sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
+           This method make sure the full tokenizer can then be re-loaded using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method.
        """
        if not os.path.isdir(save_directory):
            logger.error("Saving directory ({}) should be a directory".format(save_directory))
@@ -464,7 +469,7 @@ class PreTrainedTokenizer(object):
        """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
            and special token mappings.

-           Please use :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full Tokenizer state if you want to reload it using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
+           Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method.
        """
        raise NotImplementedError
@@ -521,6 +526,30 @@ class PreTrainedTokenizer(object):
         return len(to_add_tokens)

+    def num_added_tokens(self, pair=False):
+        """
+        Returns the number of added tokens when encoding a sequence with special tokens.
+
+        Note:
+            This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
+            inside your training loop.
+
+        Args:
+            pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
+                number of added tokens in the case of a single sequence if set to False.
+
+        Returns:
+            Number of tokens added to sequences
+        """
+        if pair:
+            initial_tokens_len = len(self.encode("This is a sequence") + self.encode("This is another"))
+            final_tokens_len = len(self.encode("This is a sequence", "This is another", add_special_tokens=True))
+        else:
+            initial_tokens_len = len(self.encode("This is a sequence"))
+            final_tokens_len = len(self.encode("This is a sequence", add_special_tokens=True))
+
+        return final_tokens_len - initial_tokens_len
+
     def add_special_tokens(self, special_tokens_dict):
         """
@@ -666,38 +695,185 @@ class PreTrainedTokenizer(object):
     def _convert_token_to_id(self, token):
         raise NotImplementedError

-    def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
+    def encode(self,
+               text,
+               text_pair=None,
+               add_special_tokens=False,
+               max_length=None,
+               stride=0,
+               truncate_first_sequence=True,
+               return_tensors=None,
+               **kwargs):
         """
         Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
         Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.

         Args:
-            text: The first sequence to be encoded.
-            text_pair: Optional second sequence to be encoded.
+            text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
+                the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+                method)
+            text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
+                string using the `tokenize` method) or a list of integers (tokenized string ids using the
+                `convert_tokens_to_ids` method)
             add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
                 to their model.
+            max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
+                If there are overflowing tokens, those will be added to the returned dictionary
+            stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
+                from the main sequence returned. The value of this argument defined the number of additional tokens.
+            truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
+                will be truncated.
+            return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
+                or PyTorch torch.Tensor instead of a list of python integers.
             **kwargs: passed to the `self.tokenize()` method
         """
-        if text_pair is None:
-            if add_special_tokens:
-                return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
-            else:
-                return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
-
-        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
-        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
-
-        if add_special_tokens:
-            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
-        else:
-            return first_sentence_tokens, second_sentence_tokens
+        encoded_inputs = self.encode_plus(text,
+                                          text_pair=text_pair,
+                                          max_length=max_length,
+                                          add_special_tokens=add_special_tokens,
+                                          stride=stride,
+                                          truncate_first_sequence=truncate_first_sequence,
+                                          return_tensors=return_tensors,
+                                          **kwargs)
+
+        return encoded_inputs["input_ids"]
+
+    def encode_plus(self,
+                    text,
+                    text_pair=None,
+                    add_special_tokens=False,
+                    max_length=None,
+                    stride=0,
+                    truncate_first_sequence=True,
+                    return_tensors=None,
+                    **kwargs):
+        """
+        Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
+        the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
+
+        Args:
+            text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
+                the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+                method)
+            text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
+                string using the `tokenize` method) or a list of integers (tokenized string ids using the
+                `convert_tokens_to_ids` method)
+            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
+                to their model.
+            max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
+                If there are overflowing tokens, those will be added to the returned dictionary
+            stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
+                from the main sequence returned. The value of this argument defined the number of additional tokens.
+            truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
+                will be truncated.
+            return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
+                or PyTorch torch.Tensor instead of a list of python integers.
+            **kwargs: passed to the `self.tokenize()` method
+        """
+
+        def get_input_ids(text):
+            if isinstance(text, six.string_types):
+                return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], six.string_types):
+                return self.convert_tokens_to_ids(text)
+            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+                return text
+            else:
+                raise ValueError("Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.")
+
+        first_ids = get_input_ids(text)
+        second_ids = get_input_ids(text_pair) if text_pair is not None else None
+
+        return self.prepare_for_model(first_ids,
+                                      pair_ids=second_ids,
+                                      max_length=max_length,
+                                      add_special_tokens=add_special_tokens,
+                                      stride=stride,
+                                      truncate_first_sequence=truncate_first_sequence,
+                                      return_tensors=return_tensors)
+
+    def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
+                          truncate_first_sequence=True, return_tensors=None):
+        """
+        Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
+        It adds special tokens, truncates
+        sequences if overflowing while taking into account the special tokens and manages a window stride for
+        overflowing tokens
+
+        Args:
+            ids: list of tokenized input ids. Can be obtained from a string by chaining the
+                `tokenize` and `convert_tokens_to_ids` methods.
+            pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
+                `tokenize` and `convert_tokens_to_ids` methods.
+            max_length: maximum length of the returned list. Will truncate by taking into account the special tokens.
+            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
+                to their model.
+            stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential
+                list of inputs.
+            truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
+                alongside a specified `max_length`, will truncate the first sequence if the total size is superior
+                than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
+            return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
+                or PyTorch torch.Tensor instead of a list of python integers.
+
+        Return:
+            a dictionary containing the `input_ids` as well as the `overflowing_tokens` if a `max_length` was given.
+        """
+        pair = bool(pair_ids is not None)
+        len_ids = len(ids)
+        len_pair_ids = len(pair_ids) if pair else 0
+
+        encoded_inputs = {}
+        if max_length:
+            n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0
+            if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
+                logger.warning(
+                    "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length."
+                    "This pair of sequences will not be truncated.")
+            else:
+                if n_added_tokens + len_ids + len_pair_ids > max_length:
+                    if truncate_first_sequence or not pair:
+                        encoded_inputs["overflowing_tokens"] = ids[max_length - len_pair_ids - n_added_tokens - stride:]
+                        ids = ids[:max_length - len_pair_ids - n_added_tokens]
+                    elif not truncate_first_sequence and pair:
+                        encoded_inputs["overflowing_tokens"] = pair_ids[max_length - len_ids - n_added_tokens - stride:]
+                        pair_ids = pair_ids[:max_length - len_ids - n_added_tokens]
+                    else:
+                        logger.warning(
+                            "Cannot truncate second sequence as it is not provided. No truncation.")
+
+        if add_special_tokens:
+            sequence = self.add_special_tokens_sequence_pair(ids, pair_ids) if pair else self.add_special_tokens_single_sequence(ids)
+            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) if pair else [0] * len(sequence)
+        else:
+            sequence = ids + pair_ids if pair else ids
+            token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
+
+        if return_tensors == 'tf' and is_tf_available():
+            sequence = tf.constant([sequence])
+            token_type_ids = tf.constant([token_type_ids])
+        elif return_tensors == 'pt' and is_torch_available():
+            sequence = torch.tensor([sequence])
+            token_type_ids = torch.tensor([token_type_ids])
+        elif return_tensors is not None:
+            logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))
+
+        encoded_inputs["input_ids"] = sequence
+        encoded_inputs["token_type_ids"] = token_type_ids
+
+        return encoded_inputs
+
+    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
+        logger.warning("This tokenizer does not make use of special tokens.")
+        return [0] * len(token_ids_0) + [1] * len(token_ids_1)

-    def add_special_tokens_single_sentence(self, token_ids):
+    def add_special_tokens_single_sequence(self, token_ids):
         logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
         return token_ids

-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
+    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
         logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
         return token_ids_0 + token_ids_1
@@ -743,7 +919,7 @@ class PreTrainedTokenizer(object):
         # To avoid mixing byte-level and unicode for byte-level BPT
         # we need to build string separatly for added tokens and byte-level tokens
-        # cf. https://github.com/huggingface/pytorch-transformers/issues/1133
+        # cf. https://github.com/huggingface/transformers/issues/1133
         sub_texts = []
         current_sub_text = []
         for token in filtered_tokens:
@@ -760,20 +936,11 @@ class PreTrainedTokenizer(object):
             sub_texts.append(self.convert_tokens_to_string(current_sub_text))
         text = ''.join(sub_texts)

-        if self._sep_token is not None and self._sep_token in text:
-            text = text.replace(self._cls_token, self._sep_token)
-            split_text = list(filter(lambda sentence: len(sentence) > 0, text.split(self._sep_token)))
-            if clean_up_tokenization_spaces:
-                clean_text = [self.clean_up_tokenization(text) for text in split_text]
-                return clean_text
-            else:
-                return split_text
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
         else:
-            if clean_up_tokenization_spaces:
-                clean_text = self.clean_up_tokenization(text)
-                return clean_text
-            else:
-                return text
+            return text

     @property
     def special_tokens_map(self):
...
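`encode` now simply forwards to `encode_plus` and keeps the ``input_ids`` entry, while `encode_plus`/`prepare_for_model` return the full dictionary and handle the optional tensor conversion. A short sketch of the difference, assuming a `bert-base-uncased` tokenizer can be loaded and PyTorch is installed:

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    ids = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
    # -> a plain list of ints laid out as [CLS] A [SEP] B [SEP]

    full = tokenizer.encode_plus("sequence builders", "multi-sequence build", add_special_tokens=True)
    assert full["input_ids"] == ids
    # full["token_type_ids"] marks the two segments with 0s and 1s

    # With return_tensors='pt' the same call returns torch tensors of shape (1, sequence_length).
    batch = tokenizer.encode_plus("sequence builders", "multi-sequence build",
                                  add_special_tokens=True, return_tensors='pt')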
@@ -754,14 +754,14 @@ class XLMTokenizer(PreTrainedTokenizer):
         out_string = ''.join(tokens).replace('</w>', ' ').strip()
         return out_string

-    def add_special_tokens_single_sentence(self, token_ids):
+    def add_special_tokens_single_sequence(self, token_ids):
         """
         Adds special tokens to a sequence for sequence classification tasks.
         An XLM sequence has the following format: [CLS] X [SEP]
         """
         return [self.cls_token_id] + token_ids + [self.sep_token_id]

-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
+    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
         """
         Adds special tokens to a sequence pair for sequence classification tasks.
         An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]
@@ -770,6 +770,18 @@ class XLMTokenizer(PreTrainedTokenizer):
         cls = [self.cls_token_id]
         return cls + token_ids_0 + sep + token_ids_1 + sep

+    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
+        """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
+        An XLM sequence pair mask has the following format:
+        0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
+        | first sequence | second sequence
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
     def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
         if not os.path.isdir(save_directory):
...
@@ -181,24 +181,38 @@ class XLNetTokenizer(PreTrainedTokenizer):
         out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
         return out_string

-    def add_special_tokens_single_sentence(self, token_ids):
+    def add_special_tokens_single_sequence(self, token_ids):
         """
-        Adds special tokens to a sequence pair for sequence classification tasks.
-        An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
+        Adds special tokens to a sequence for sequence classification tasks.
+        An XLNet sequence has the following format: X [SEP][CLS]
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
         return token_ids + sep + cls

-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
+    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
         """
-        Adds special tokens to a sequence for sequence classification tasks.
-        An XLNet sequence has the following format: X [SEP][CLS]
+        Adds special tokens to a sequence pair for sequence classification tasks.
+        An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
         return token_ids_0 + sep + token_ids_1 + sep + cls

+    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
+        """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
+        A BERT sequence pair mask has the following format:
+        0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 2
+        | first sequence | second sequence | CLS segment ID
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+        cls_segment_id = [2]
+
+        return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id
+
     def save_vocabulary(self, save_directory):
         """ Save the sentencepiece vocabulary (copy original file) and special tokens file
             to a directory.
...
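XLNet places the classification token at the end, so its mask carries a third segment id (2) for the trailing [CLS]. A stand-alone sketch of the layout using the toy ids from the test assertions above (4 for [SEP], 3 for [CLS], word ids arbitrary):

    sep_id, cls_id = 4, 3
    token_ids_0 = [10, 11]
    token_ids_1 = [12, 13, 14]

    # Same construction as the two XLNet methods above: A [SEP] B [SEP][CLS].
    input_ids = token_ids_0 + [sep_id] + token_ids_1 + [sep_id] + [cls_id]
    token_type_ids = len(token_ids_0 + [sep_id]) * [0] + len(token_ids_1 + [sep_id]) * [1] + [2]

    assert token_type_ids == [0, 0, 0, 1, 1, 1, 1, 2]
    assert len(input_ids) == len(token_type_ids)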