Commit 328afb70 authored by thomwolf

cleaning up tokenizer tests structure (at last) - last remaining ppb refs

parent 00132b7a
@@ -345,8 +345,13 @@ tokenizer = BertTokenizer.from_pretrained('./my_saved_model_directory/')

### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules

-The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer.
-The new optimizer `AdamW` matches PyTorch `Adam` optimizer API.
+The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer, which has a few differences:
+
+- it only implements the weight decay correction,
+- schedules are now external (see below),
+- gradient clipping is now also external (see below).
+
+The new `AdamW` optimizer matches the PyTorch `Adam` optimizer API and lets you use standard PyTorch or Apex methods for the schedule and gradient clipping.

The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore.
@@ -355,6 +360,7 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch
```python
# Parameters:
lr = 1e-3
+max_grad_norm = 1.0
num_total_steps = 1000
num_warmup_steps = 100
warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1
@@ -374,6 +380,7 @@ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_tot
for batch in train_data:
    loss = model(batch)
    loss.backward()
+    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
    scheduler.step()
    optimizer.step()
```
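For reference, here is a minimal end-to-end sketch of the new-style setup implied by the truncated hunk above. It assumes a `model` and `train_data` that already exist; the `AdamW`/`WarmupLinearSchedule` construction uses the classes exported by this release, and `correct_bias=False` is the flag for mimicking `BertAdam`'s behaviour of skipping the bias correction.

```python
import torch
from pytorch_transformers import AdamW, WarmupLinearSchedule

# Placeholders: `model` is any nn.Module that maps a batch to a scalar loss,
# `train_data` is any iterable of batches.
lr = 1e-3
max_grad_norm = 1.0
num_total_steps = 1000
num_warmup_steps = 100

optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # no bias correction, as in BertAdam
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)

for batch in train_data:
    loss = model(batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # clipping is external now
    scheduler.step()
    optimizer.step()
    optimizer.zero_grad()
```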
......
@@ -68,8 +68,13 @@ tokenizer = BertTokenizer.from_pretrained('./my_saved_model_directory/')

### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules

-The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer.
-The new optimizer `AdamW` matches PyTorch `Adam` optimizer API.
+The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer, which has a few differences:
+
+- it only implements the weight decay correction,
+- schedules are now external (see below),
+- gradient clipping is now also external (see below).
+
+The new `AdamW` optimizer matches the PyTorch `Adam` optimizer API and lets you use standard PyTorch or Apex methods for the schedule and gradient clipping.

The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore.
@@ -78,6 +83,7 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch
```python
# Parameters:
lr = 1e-3
+max_grad_norm = 1.0
num_total_steps = 1000
num_warmup_steps = 100
warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1
@@ -97,6 +103,7 @@ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_tot
for batch in train_data:
    loss = model(batch)
    loss.backward()
+    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
    scheduler.step()
    optimizer.step()
```
@@ -122,7 +122,7 @@ Here is the recommended way of saving the model, configuration and vocabulary to

.. code-block:: python

-    from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
+    from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME

    output_dir = "./models/"
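For context, a sketch of the save pattern this snippet belongs to (the surrounding doc is truncated here; `model` and `tokenizer` are assumed to have been created earlier, e.g. with `from_pretrained()`):

```python
import os
import torch
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME

output_dir = "./models/"
os.makedirs(output_dir, exist_ok=True)

# Unwrap DataParallel/DistributedDataParallel before saving, if needed.
model_to_save = model.module if hasattr(model, 'module') else model

torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
with open(os.path.join(output_dir, CONFIG_NAME), 'w') as f:
    f.write(model_to_save.config.to_json_string())
tokenizer.save_vocabulary(output_dir)
```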
......
@@ -74,7 +74,7 @@ according to a ``BertConfig`` class and then saved to disk under the filename ``

.. code-block:: python

-    from pytorch_pretrained_bert import BertModel, BertTokenizer, BertConfig
+    from pytorch_transformers import BertModel, BertTokenizer, BertConfig
    import torch

    enc = BertTokenizer.from_pretrained("bert-base-uncased")
@@ -105,6 +105,9 @@ according to a ``BertConfig`` class and then saved to disk under the filename ``
    # The model needs to be in evaluation mode
    model.eval()

+    # If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
+    model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
+
    # Creating the trace
    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
    torch.jit.save(traced_model, "traced_bert.pt")
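A short usage sketch for the saved trace (it assumes `tokens_tensor` and `segments_tensors` are built the same way as for tracing above):

```python
import torch

# Reload the trace produced above and run it like a regular module.
loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()

with torch.no_grad():
    outputs = loaded_model(tokens_tensor, segments_tensors)
```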
......
@@ -39,4 +39,4 @@ from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,

from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
                           WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
-from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
+from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
@@ -20,7 +20,7 @@ import argparse

import torch
import numpy as np
import tensorflow as tf
-from pytorch_pretrained_bert.modeling import BertModel
+from pytorch_transformers.modeling import BertModel


def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
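A hypothetical call to the converter defined above, assuming it runs inside the same conversion script (so `convert_pytorch_checkpoint_to_tf` is in scope); the paths and model name are placeholders:

```python
from pytorch_transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
convert_pytorch_checkpoint_to_tf(model, ckpt_dir="./tf_checkpoint", model_name="bert-base-uncased")
```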
......
@@ -38,10 +38,13 @@ except ImportError:
try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BERT_CACHE = Path(
-        os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
+        os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
except (AttributeError, ImportError):
-    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
-                                               default_cache_path)
+    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
+                                              os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                                        default_cache_path))
+
+PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -70,7 +73,7 @@ def filename_to_url(filename, cache_dir=None):
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
@@ -98,7 +101,7 @@ def cached_path(url_or_filename, cache_dir=None):
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
@@ -187,7 +190,7 @@ def get_from_cache(url, cache_dir=None):
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
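To illustrate the cache lookup order introduced above: the new `PYTORCH_TRANSFORMERS_CACHE` environment variable wins, the old `PYTORCH_PRETRAINED_BERT_CACHE` is still honoured, and the module-level default is the final fallback. The default path shown here is only a stand-in; the real value lives in `file_utils.py`.

```python
import os

default_cache_path = os.path.join(os.path.expanduser("~"), ".pytorch_pretrained_bert")  # stand-in value

cache_dir = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
                      os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
print("Models will be cached in:", cache_dir)
```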
......
@@ -24,30 +24,37 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer,
                                                     _is_control, _is_punctuation,
                                                     _is_whitespace, VOCAB_FILES_NAMES)

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases


-class TokenizationTest(unittest.TestCase):
+class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):

-    def test_full_tokenizer(self):
+    tokenizer_class = BertTokenizer
+
+    def setUp(self):
+        super(BertTokenizationTest, self).setUp()
+
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing", ",", "low", "lowest",
        ]
-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
-                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-            input_text = u"UNwant\u00E9d,running"
-            output_text = u"unwanted, running"
+    def get_tokenizer(self):
+        return BertTokenizer.from_pretrained(self.tmpdirname)

-            create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)
+    def get_input_output_texts(self):
+        input_text = u"UNwant\u00E9d,running"
+        output_text = u"unwanted, running"
+        return input_text, output_text

-            tokenizer = BertTokenizer(vocab_file)
+    def test_full_tokenizer(self):
+        tokenizer = BertTokenizer(self.vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

    def test_chinese(self):
        tokenizer = BasicTokenizer()
......
@@ -20,42 +20,49 @@ import json

from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases


-class GPT2TokenizationTest(unittest.TestCase):
+class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):

-    def test_full_tokenizer(self):
-        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+    tokenizer_class = GPT2Tokenizer
+
+    def setUp(self):
+        super(GPT2TokenizationTest, self).setUp()
+
+        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "lo", "low", "er",
                 "low", "lowest", "newer", "wider", "<unk>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
-        special_tokens_map = {"unk_token": "<unk>"}
+        self.special_tokens_map = {"unk_token": "<unk>"}

-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-            with open(vocab_file, "w") as fp:
-                fp.write(json.dumps(vocab_tokens))
-            with open(merges_file, "w") as fp:
-                fp.write("\n".join(merges))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+        with open(self.vocab_file, "w") as fp:
+            fp.write(json.dumps(vocab_tokens))
+        with open(self.merges_file, "w") as fp:
+            fp.write("\n".join(merges))

-            input_text = u"lower newer"
-            output_text = u"lower<unk>newer"
+    def get_tokenizer(self):
+        return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)

-            create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
+    def get_input_output_texts(self):
+        input_text = u"lower newer"
+        output_text = u"lower<unk>newer"
+        return input_text, output_text

-        tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
+    def test_full_tokenizer(self):
+        tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
        text = "lower"
        bpe_tokens = ["low", "er"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + [tokenizer.unk_token]
        input_bpe_tokens = [13, 12, 17]
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)


if __name__ == '__main__':
......
@@ -20,13 +20,17 @@ import json

from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases


-class OpenAIGPTTokenizationTest(unittest.TestCase):
+class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):

-    def test_full_tokenizer(self):
-        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+    tokenizer_class = OpenAIGPTTokenizer
+
+    def setUp(self):
+        super(OpenAIGPTTokenizationTest, self).setUp()
+
+        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "w</w>", "r</w>", "t</w>",
                 "lo", "low", "er</w>",
@@ -34,30 +38,34 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]

-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-            with open(vocab_file, "w") as fp:
-                fp.write(json.dumps(vocab_tokens))
-            with open(merges_file, "w") as fp:
-                fp.write("\n".join(merges))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+        with open(self.vocab_file, "w") as fp:
+            fp.write(json.dumps(vocab_tokens))
+        with open(self.merges_file, "w") as fp:
+            fp.write("\n".join(merges))

-            input_text = u"lower newer"
-            output_text = u"lower newer"
+    def get_tokenizer(self):
+        return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname)

-            create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
+    def get_input_output_texts(self):
+        input_text = u"lower newer"
+        output_text = u"lower newer"
+        return input_text, output_text

-        tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
-        text = "lower"
-        bpe_tokens = ["low", "er</w>"]
-        tokens = tokenizer.tokenize(text)
-        self.assertListEqual(tokens, bpe_tokens)
+    def test_full_tokenizer(self):
+        tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
+
+        text = "lower"
+        bpe_tokens = ["low", "er</w>"]
+        tokens = tokenizer.tokenize(text)
+        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + ["<unk>"]
        input_bpe_tokens = [14, 15, 20]
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)


if __name__ == '__main__':
......
@@ -19,6 +19,7 @@ import sys

from io import open
import tempfile
import shutil
+import unittest

if sys.version_info[0] == 2:
    import cPickle as pickle
@@ -36,113 +37,124 @@ else:
    unicode = str


-def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
-
-    before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
-
-    with TemporaryDirectory() as tmpdirname:
-        tokenizer.save_pretrained(tmpdirname)
-        tokenizer = tokenizer.from_pretrained(tmpdirname)
-
-    after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
-    tester.assertListEqual(before_tokens, after_tokens)
-
-
-def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
-    tester.assertIsNotNone(tokenizer)
-
-    text = u"Munich and Berlin are nice cities"
-    subwords = tokenizer.tokenize(text)
-
-    with TemporaryDirectory() as tmpdirname:
-
-        filename = os.path.join(tmpdirname, u"tokenizer.bin")
-        pickle.dump(tokenizer, open(filename, "wb"))
-
-        tokenizer_new = pickle.load(open(filename, "rb"))
-
-    subwords_loaded = tokenizer_new.tokenize(text)
-
-    tester.assertListEqual(subwords, subwords_loaded)
-
-
-def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
-
-    vocab_size = tokenizer.vocab_size
-    all_size = len(tokenizer)
-
-    tester.assertNotEqual(vocab_size, 0)
-    tester.assertEqual(vocab_size, all_size)
-
-    new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
-    added_toks = tokenizer.add_tokens(new_toks)
-    vocab_size_2 = tokenizer.vocab_size
-    all_size_2 = len(tokenizer)
-
-    tester.assertNotEqual(vocab_size_2, 0)
-    tester.assertEqual(vocab_size, vocab_size_2)
-    tester.assertEqual(added_toks, len(new_toks))
-    tester.assertEqual(all_size_2, all_size + len(new_toks))
-
-    tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
-    tester.assertGreaterEqual(len(tokens), 4)
-    tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
-    tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
-
-    new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
-                  'pad_token': "<<<<<|||>|>>>>|>"}
-    added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
-    vocab_size_3 = tokenizer.vocab_size
-    all_size_3 = len(tokenizer)
-
-    tester.assertNotEqual(vocab_size_3, 0)
-    tester.assertEqual(vocab_size, vocab_size_3)
-    tester.assertEqual(added_toks_2, len(new_toks_2))
-    tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
-
-    tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
-
-    tester.assertGreaterEqual(len(tokens), 6)
-    tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
-    tester.assertGreater(tokens[0], tokens[1])
-    tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
-    tester.assertGreater(tokens[-2], tokens[-3])
-    tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
-    tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
-
-
-def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
-
-    tokens = tokenizer.tokenize(input_text)
-    ids = tokenizer.convert_tokens_to_ids(tokens)
-    ids_2 = tokenizer.encode(input_text)
-    tester.assertListEqual(ids, ids_2)
-
-    tokens_2 = tokenizer.convert_ids_to_tokens(ids)
-    text_2 = tokenizer.decode(ids)
-
-    tester.assertEqual(text_2, output_text)
-
-    tester.assertNotEqual(len(tokens_2), 0)
-    tester.assertIsInstance(text_2, (str, unicode))
-
-
-def create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
-    weights_list = list(tokenizer_class.max_model_input_sizes.keys())
-    weights_lists_2 = []
-    for file_id, map_list in tokenizer_class.pretrained_vocab_files_map.items():
-        weights_lists_2.append(list(map_list.keys()))
-
-    for weights_list_2 in weights_lists_2:
-        tester.assertListEqual(weights_list, weights_list_2)
-
-
-def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
-    create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
-    create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
-    create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
-    create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
-    create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
+class CommonTestCases:
+
+    class CommonTokenizerTester(unittest.TestCase):
+
+        tokenizer_class = None
+
+        def setUp(self):
+            self.tmpdirname = tempfile.mkdtemp()
+
+        def tearDown(self):
+            shutil.rmtree(self.tmpdirname)
+
+        def get_tokenizer(self):
+            raise NotImplementedError
+
+        def get_input_output_texts(self):
+            raise NotImplementedError
+
+        def test_save_and_load_tokenizer(self):
+            tokenizer = self.get_tokenizer()
+
+            before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
+
+            with TemporaryDirectory() as tmpdirname:
+                tokenizer.save_pretrained(tmpdirname)
+                tokenizer = tokenizer.from_pretrained(tmpdirname)
+
+            after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
+            self.assertListEqual(before_tokens, after_tokens)
+
+        def test_pickle_tokenizer(self):
+            tokenizer = self.get_tokenizer()
+            self.assertIsNotNone(tokenizer)
+
+            text = u"Munich and Berlin are nice cities"
+            subwords = tokenizer.tokenize(text)
+
+            with TemporaryDirectory() as tmpdirname:
+
+                filename = os.path.join(tmpdirname, u"tokenizer.bin")
+                pickle.dump(tokenizer, open(filename, "wb"))
+
+                tokenizer_new = pickle.load(open(filename, "rb"))
+
+            subwords_loaded = tokenizer_new.tokenize(text)
+
+            self.assertListEqual(subwords, subwords_loaded)
+
+        def test_add_tokens_tokenizer(self):
+            tokenizer = self.get_tokenizer()
+
+            vocab_size = tokenizer.vocab_size
+            all_size = len(tokenizer)
+
+            self.assertNotEqual(vocab_size, 0)
+            self.assertEqual(vocab_size, all_size)
+
+            new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
+            added_toks = tokenizer.add_tokens(new_toks)
+            vocab_size_2 = tokenizer.vocab_size
+            all_size_2 = len(tokenizer)
+
+            self.assertNotEqual(vocab_size_2, 0)
+            self.assertEqual(vocab_size, vocab_size_2)
+            self.assertEqual(added_toks, len(new_toks))
+            self.assertEqual(all_size_2, all_size + len(new_toks))
+
+            tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
+            self.assertGreaterEqual(len(tokens), 4)
+            self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+
+            new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
+                          'pad_token': "<<<<<|||>|>>>>|>"}
+            added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
+            vocab_size_3 = tokenizer.vocab_size
+            all_size_3 = len(tokenizer)
+
+            self.assertNotEqual(vocab_size_3, 0)
+            self.assertEqual(vocab_size, vocab_size_3)
+            self.assertEqual(added_toks_2, len(new_toks_2))
+            self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
+
+            tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
+
+            self.assertGreaterEqual(len(tokens), 6)
+            self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+            self.assertGreater(tokens[0], tokens[1])
+            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+            self.assertGreater(tokens[-2], tokens[-3])
+            self.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
+            self.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
+
+        def test_required_methods_tokenizer(self):
+            tokenizer = self.get_tokenizer()
+            input_text, output_text = self.get_input_output_texts()
+
+            tokens = tokenizer.tokenize(input_text)
+            ids = tokenizer.convert_tokens_to_ids(tokens)
+            ids_2 = tokenizer.encode(input_text)
+            self.assertListEqual(ids, ids_2)
+
+            tokens_2 = tokenizer.convert_ids_to_tokens(ids)
+            text_2 = tokenizer.decode(ids)
+
+            self.assertEqual(text_2, output_text)
+
+            self.assertNotEqual(len(tokens_2), 0)
+            self.assertIsInstance(text_2, (str, unicode))
+
+        def test_pretrained_model_lists(self):
+            weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
+            weights_lists_2 = []
+            for file_id, map_list in self.tokenizer_class.pretrained_vocab_files_map.items():
+                weights_lists_2.append(list(map_list.keys()))
+
+            for weights_list_2 in weights_lists_2:
+                self.assertListEqual(weights_list, weights_list_2)
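To make the contract of the new base class explicit, here is a hypothetical test module written against it. `MyTokenizer` and `MyTokenizationTest` do not exist in the library; they stand in for whatever tokenizer a new test would cover, mirroring the BERT/GPT-2 tests above.

```python
import os

from .tokenization_tests_commons import CommonTestCases


class MyTokenizationTest(CommonTestCases.CommonTokenizerTester):

    tokenizer_class = MyTokenizer  # hypothetical tokenizer under test

    def setUp(self):
        super(MyTokenizationTest, self).setUp()
        # Write the vocabulary/merges files MyTokenizer needs into self.tmpdirname here,
        # exactly as the concrete test classes above do.

    def get_tokenizer(self):
        return MyTokenizer.from_pretrained(self.tmpdirname)

    def get_input_output_texts(self):
        return u"lower newer", u"lower newer"
```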
@@ -20,32 +20,39 @@ from io import open

from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases


-class TransfoXLTokenizationTest(unittest.TestCase):
+class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):

-    def test_full_tokenizer(self):
+    tokenizer_class = TransfoXLTokenizer
+
+    def setUp(self):
+        super(TransfoXLTokenizationTest, self).setUp()
+
        vocab_tokens = [
            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
            "running", ",", "low", "l",
        ]
-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
-                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-            input_text = u"<unk> UNwanted , running"
-            output_text = u"<unk> unwanted, running"
+    def get_tokenizer(self):
+        return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)

-            create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
+    def get_input_output_texts(self):
+        input_text = u"<unk> UNwanted , running"
+        output_text = u"<unk> unwanted, running"
+        return input_text, output_text

-        tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
+    def test_full_tokenizer(self):
+        tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True)

        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

    def test_full_tokenizer_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=True)
......
@@ -20,12 +20,16 @@ import json

from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases


-class XLMTokenizationTest(unittest.TestCase):
+class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):

-    def test_full_tokenizer(self):
-        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+    tokenizer_class = XLMTokenizer
+
+    def setUp(self):
+        super(XLMTokenizationTest, self).setUp()
+
+        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "w</w>", "r</w>", "t</w>",
                 "lo", "low", "er</w>",
@@ -33,30 +37,34 @@ class XLMTokenizationTest(unittest.TestCase):
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]

-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-            with open(vocab_file, "w") as fp:
-                fp.write(json.dumps(vocab_tokens))
-            with open(merges_file, "w") as fp:
-                fp.write("\n".join(merges))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+        with open(self.vocab_file, "w") as fp:
+            fp.write(json.dumps(vocab_tokens))
+        with open(self.merges_file, "w") as fp:
+            fp.write("\n".join(merges))

-            input_text = u"lower newer"
-            output_text = u"lower newer"
+    def get_tokenizer(self):
+        return XLMTokenizer.from_pretrained(self.tmpdirname)

-            create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
+    def get_input_output_texts(self):
+        input_text = u"lower newer"
+        output_text = u"lower newer"
+        return input_text, output_text

-        tokenizer = XLMTokenizer(vocab_file, merges_file)
+    def test_full_tokenizer(self):
+        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+        tokenizer = XLMTokenizer(self.vocab_file, self.merges_file)

        text = "lower"
        bpe_tokens = ["low", "er</w>"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + ["<unk>"]
        input_bpe_tokens = [14, 15, 20]
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)


if __name__ == '__main__':
......
@@ -19,48 +19,58 @@ import unittest

from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'fixtures/test_sentencepiece.model')


-class XLNetTokenizationTest(unittest.TestCase):
+class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
+
+    tokenizer_class = XLNetTokenizer
+
+    def setUp(self):
+        super(XLNetTokenizationTest, self).setUp()
+
+        # We have a SentencePiece fixture for testing
+        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
+        tokenizer.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self):
+        return XLNetTokenizer.from_pretrained(self.tmpdirname)
+
+    def get_input_output_texts(self):
+        input_text = u"This is a test"
+        output_text = u"This is a test"
+        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

-        with TemporaryDirectory() as tmpdirname:
-            tokenizer.save_pretrained(tmpdirname)
-
-            input_text = u"This is a test"
-            output_text = u"This is a test"
-
-            create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
-
        tokens = tokenizer.tokenize(u'This is a test')
        self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(
            ids, [8, 21, 84, 55, 24, 19, 7, 0,
                  602, 347, 347, 347, 3, 12, 66,
                  46, 72, 80, 6, 0, 4])

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                           u'or', u'n', SPIECE_UNDERLINE + u'in',
                                           SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
                                           SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                                           SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                                           u'<unk>', u'.'])

    def test_tokenizer_lower(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
......
@@ -86,7 +86,7 @@ def whitespace_tokenize(text):
class BertTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a BertTokenizer.
-    :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
+    :class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file
......
@@ -125,42 +125,34 @@ class PreTrainedTokenizer(object):

    @bos_token.setter
    def bos_token(self, value):
-        self.add_tokens([value])
        self._bos_token = value

    @eos_token.setter
    def eos_token(self, value):
-        self.add_tokens([value])
        self._eos_token = value

    @unk_token.setter
    def unk_token(self, value):
-        self.add_tokens([value])
        self._unk_token = value

    @sep_token.setter
    def sep_token(self, value):
-        self.add_tokens([value])
        self._sep_token = value

    @pad_token.setter
    def pad_token(self, value):
-        self.add_tokens([value])
        self._pad_token = value

    @cls_token.setter
    def cls_token(self, value):
-        self.add_tokens([value])
        self._cls_token = value

    @mask_token.setter
    def mask_token(self, value):
-        self.add_tokens([value])
        self._mask_token = value

    @additional_special_tokens.setter
    def additional_special_tokens(self, value):
-        self.add_tokens(value)
        self._additional_special_tokens = value

    def __init__(self, max_len=None, **kwargs):
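The practical effect of dropping those `add_tokens` calls, sketched below with GPT-2 names for illustration (the behaviour described is inferred from this diff): plain attribute assignment now only records the token string, while `add_special_tokens` (see the next hunk) also adds it to the vocabulary and reports how many tokens were added.

```python
from pytorch_transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Only records the token string; the vocabulary is left untouched.
tokenizer.cls_token = '<CLS>'

# Records the token *and* adds it to the vocabulary, returning the count of added tokens.
num_added = tokenizer.add_special_tokens({'cls_token': '<CLS>'})
print(num_added, len(tokenizer))
```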
@@ -179,6 +171,10 @@ class PreTrainedTokenizer(object):

        for key, value in kwargs.items():
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
+                if key == 'additional_special_tokens':
+                    assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
+                else:
+                    assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
                setattr(self, key, value)
@@ -415,15 +411,39 @@ class PreTrainedTokenizer(object):

        Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).

+        Returns:
+            Number of tokens added to the vocabulary.
+
+        Examples::
+
+            # Let's see how to add a new classification token to GPT-2
+            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+            model = GPT2Model.from_pretrained('gpt2')
+
+            special_tokens_dict = {'cls_token': '<CLS>'}
+
+            num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
+            print('We have added', num_added_toks, 'tokens')
+            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expects the full size of the new vocabulary, i.e. the length of the tokenizer.
+
+            assert tokenizer.cls_token == '<CLS>'
        """
        if not special_tokens_dict:
            return 0

+        added_tokens = 0
        for key, value in special_tokens_dict.items():
            assert key in self.SPECIAL_TOKENS_ATTRIBUTES
+            if key == 'additional_special_tokens':
+                assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
+                added_tokens += self.add_tokens(value)
+            else:
+                assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
+                added_tokens += self.add_tokens([value])
            logger.info("Assigning %s to the %s key of the tokenizer", value, key)
            setattr(self, key, value)

+        return added_tokens
+
    def tokenize(self, text, **kwargs):
        """ Converts a string into a sequence of tokens (string), using the tokenizer.
......