Commit 328afb70 authored by thomwolf

cleaning up tokenizer tests structure (at last) - last remaining ppb refs

parent 00132b7a
@@ -345,8 +345,13 @@ tokenizer = BertTokenizer.from_pretrained('./my_saved_model_directory/')
### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules

-The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer.
-The new optimizer `AdamW` matches the PyTorch `Adam` optimizer API.
+The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer, which has a few differences:
+- it only implements the weight decay correction,
+- schedules are now external (see below),
+- gradient clipping is now also external (see below).
+The new `AdamW` optimizer matches the PyTorch `Adam` optimizer API and lets you use standard PyTorch or apex methods for the schedule and clipping.

The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore.
@@ -355,6 +360,7 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch

```python
# Parameters:
lr = 1e-3
+max_grad_norm = 1.0
num_total_steps = 1000
num_warmup_steps = 100
warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1

@@ -374,6 +380,7 @@ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_tot
for batch in train_data:
    loss = model(batch)
    loss.backward()
+    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
    scheduler.step()
    optimizer.step()
```
...
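To see the whole migration in one place, here is a self-contained sketch that assembles the pieces shown in the hunk above around a dummy model. The `torch.nn.Linear` stand-in, the random `train_data`, and the trailing `optimizer.zero_grad()` call are illustrative additions, not part of this diff.

```python
import torch
from pytorch_transformers import AdamW, WarmupLinearSchedule

# Dummy stand-ins so the sketch runs end to end (not part of the diff).
model = torch.nn.Linear(10, 1)
train_data = [torch.randn(8, 10) for _ in range(100)]

lr = 1e-3
max_grad_norm = 1.0
num_total_steps = len(train_data)
num_warmup_steps = 10

optimizer = AdamW(model.parameters(), lr=lr)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)

for batch in train_data:
    loss = model(batch).pow(2).mean()  # placeholder loss
    loss.backward()
    # Gradient clipping now happens outside the optimizer.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scheduler.step()   # same call ordering as in the snippet above
    optimizer.step()
    optimizer.zero_grad()
```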
@@ -68,8 +68,13 @@ tokenizer = BertTokenizer.from_pretrained('./my_saved_model_directory/')
### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules

-The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer.
-The new optimizer `AdamW` matches the PyTorch `Adam` optimizer API.
+The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer, which has a few differences:
+- it only implements the weight decay correction,
+- schedules are now external (see below),
+- gradient clipping is now also external (see below).
+The new `AdamW` optimizer matches the PyTorch `Adam` optimizer API and lets you use standard PyTorch or apex methods for the schedule and clipping.

The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore.
@@ -78,6 +83,7 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch

```python
# Parameters:
lr = 1e-3
+max_grad_norm = 1.0
num_total_steps = 1000
num_warmup_steps = 100
warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1

@@ -97,6 +103,7 @@ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_tot
for batch in train_data:
    loss = model(batch)
    loss.backward()
+    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
    scheduler.step()
    optimizer.step()
```
@@ -122,7 +122,7 @@ Here is the recommended way of saving the model, configuration and vocabulary to

.. code-block:: python

-    from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
+    from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME

    output_dir = "./models/"
...
@@ -74,7 +74,7 @@ according to a ``BertConfig`` class and then saved to disk under the filename ``

.. code-block:: python

-    from pytorch_pretrained_bert import BertModel, BertTokenizer, BertConfig
+    from pytorch_transformers import BertModel, BertTokenizer, BertConfig
    import torch

    enc = BertTokenizer.from_pretrained("bert-base-uncased")

@@ -105,6 +105,9 @@ according to a ``BertConfig`` class and then saved to disk under the filename ``

    # The model needs to be in evaluation mode
    model.eval()

+    # If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
+    model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
+
    # Creating the trace
    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
    torch.jit.save(traced_model, "traced_bert.pt")
...
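As a complement to the `torchscript=True` flag added above, here is a minimal sketch of loading the saved trace back and running it. The dummy input shapes are arbitrary, chosen only to match a BERT-style `(input_ids, token_type_ids)` call signature.

```python
import torch

# Reload the trace written by torch.jit.save above.
loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()

# Dummy inputs standing in for tokens_tensor / segments_tensors.
tokens_tensor = torch.ones(1, 12, dtype=torch.long)
segments_tensors = torch.zeros(1, 12, dtype=torch.long)

with torch.no_grad():
    outputs = loaded_model(tokens_tensor, segments_tensors)
```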
@@ -39,4 +39,4 @@ from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
                           WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
-from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
+from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
@@ -20,7 +20,7 @@ import argparse
import torch
import numpy as np
import tensorflow as tf
-from pytorch_pretrained_bert.modeling import BertModel
+from pytorch_transformers.modeling import BertModel

def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
...
@@ -38,10 +38,13 @@ except ImportError:
try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BERT_CACHE = Path(
-        os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
+        os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
except (AttributeError, ImportError):
-    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
-                                              default_cache_path)
+    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
+                                              os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                                        default_cache_path))
+
+PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@@ -70,7 +73,7 @@ def filename_to_url(filename, cache_dir=None):
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

@@ -98,7 +101,7 @@ def cached_path(url_or_filename, cache_dir=None):
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):

@@ -187,7 +190,7 @@ def get_from_cache(url, cache_dir=None):
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
...
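The net effect of the `file_utils.py` changes above is that `PYTORCH_TRANSFORMERS_CACHE` now takes precedence while the old `PYTORCH_PRETRAINED_BERT_CACHE` variable remains a fallback. A small sketch of how a user would point the cache somewhere else; the directory is an arbitrary example, and the variable has to be set before the import because the default is resolved at import time.

```python
import os

# Set before importing pytorch_transformers; the cache path is read at import time.
os.environ["PYTORCH_TRANSFORMERS_CACHE"] = "/tmp/my_transformers_cache"

from pytorch_transformers import PYTORCH_TRANSFORMERS_CACHE

print(PYTORCH_TRANSFORMERS_CACHE)  # /tmp/my_transformers_cache (a Path object on Python 3)
```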
@@ -24,26 +24,33 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer,
                                                     _is_control, _is_punctuation,
                                                     _is_whitespace, VOCAB_FILES_NAMES)
-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

-class TokenizationTest(unittest.TestCase):
+class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
+
+    tokenizer_class = BertTokenizer

-    def test_full_tokenizer(self):
+    def setUp(self):
+        super(BertTokenizationTest, self).setUp()
+
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing", ",", "low", "lowest",
        ]
-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
-                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+    def get_tokenizer(self):
+        return BertTokenizer.from_pretrained(self.tmpdirname)
+
+    def get_input_output_texts(self):
        input_text = u"UNwant\u00E9d,running"
        output_text = u"unwanted, running"
+        return input_text, output_text

-            create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)
-
-            tokenizer = BertTokenizer(vocab_file)
+    def test_full_tokenizer(self):
+        tokenizer = BertTokenizer(self.vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
...
@@ -20,33 +20,40 @@ import json
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

-class GPT2TokenizationTest(unittest.TestCase):
+class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
+
+    tokenizer_class = GPT2Tokenizer

-    def test_full_tokenizer(self):
-        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+    def setUp(self):
+        super(GPT2TokenizationTest, self).setUp()
+
+        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "lo", "low", "er",
                 "low", "lowest", "newer", "wider", "<unk>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
-        special_tokens_map = {"unk_token": "<unk>"}
-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-            with open(vocab_file, "w") as fp:
-                fp.write(json.dumps(vocab_tokens))
-            with open(merges_file, "w") as fp:
-                fp.write("\n".join(merges))
+        self.special_tokens_map = {"unk_token": "<unk>"}
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+        with open(self.vocab_file, "w") as fp:
+            fp.write(json.dumps(vocab_tokens))
+        with open(self.merges_file, "w") as fp:
+            fp.write("\n".join(merges))
+
+    def get_tokenizer(self):
+        return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
+
+    def get_input_output_texts(self):
        input_text = u"lower newer"
        output_text = u"lower<unk>newer"
+        return input_text, output_text

-            create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
-
-            tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
+    def test_full_tokenizer(self):
+        tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
        text = "lower"
        bpe_tokens = ["low", "er"]
        tokens = tokenizer.tokenize(text)
...
@@ -20,13 +20,17 @@ import json
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

-class OpenAIGPTTokenizationTest(unittest.TestCase):
+class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
+
+    tokenizer_class = OpenAIGPTTokenizer

-    def test_full_tokenizer(self):
-        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+    def setUp(self):
+        super(OpenAIGPTTokenizationTest, self).setUp()
+
+        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "w</w>", "r</w>", "t</w>",
                 "lo", "low", "er</w>",

@@ -34,20 +38,24 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-            with open(vocab_file, "w") as fp:
-                fp.write(json.dumps(vocab_tokens))
-            with open(merges_file, "w") as fp:
-                fp.write("\n".join(merges))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+        with open(self.vocab_file, "w") as fp:
+            fp.write(json.dumps(vocab_tokens))
+        with open(self.merges_file, "w") as fp:
+            fp.write("\n".join(merges))
+
+    def get_tokenizer(self):
+        return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname)
+
+    def get_input_output_texts(self):
        input_text = u"lower newer"
        output_text = u"lower newer"
+        return input_text, output_text

-            create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
-
-            tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
+    def test_full_tokenizer(self):
+        tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
        text = "lower"
        bpe_tokens = ["low", "er</w>"]
...
@@ -19,6 +19,7 @@ import sys
from io import open
import tempfile
import shutil
+import unittest

if sys.version_info[0] == 2:
    import cPickle as pickle

@@ -36,8 +37,26 @@ else:
    unicode = str

-def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
+class CommonTestCases:
+
+    class CommonTokenizerTester(unittest.TestCase):
+
+        tokenizer_class = None
+
+        def setUp(self):
+            self.tmpdirname = tempfile.mkdtemp()
+
+        def tearDown(self):
+            shutil.rmtree(self.tmpdirname)
+
+        def get_tokenizer(self):
+            raise NotImplementedError
+
+        def get_input_output_texts(self):
+            raise NotImplementedError
+
+        def test_save_and_load_tokenizer(self):
+            tokenizer = self.get_tokenizer()

            before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")

@@ -46,11 +65,11 @@ def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, *
                tokenizer = tokenizer.from_pretrained(tmpdirname)
                after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
-    tester.assertListEqual(before_tokens, after_tokens)
+                self.assertListEqual(before_tokens, after_tokens)

-def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
-    tester.assertIsNotNone(tokenizer)
+        def test_pickle_tokenizer(self):
+            tokenizer = self.get_tokenizer()
+            self.assertIsNotNone(tokenizer)

            text = u"Munich and Berlin are nice cities"
            subwords = tokenizer.tokenize(text)

@@ -64,32 +83,32 @@ def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs
                subwords_loaded = tokenizer_new.tokenize(text)
-    tester.assertListEqual(subwords, subwords_loaded)
+                self.assertListEqual(subwords, subwords_loaded)

-def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
+        def test_add_tokens_tokenizer(self):
+            tokenizer = self.get_tokenizer()

            vocab_size = tokenizer.vocab_size
            all_size = len(tokenizer)

-    tester.assertNotEqual(vocab_size, 0)
-    tester.assertEqual(vocab_size, all_size)
+            self.assertNotEqual(vocab_size, 0)
+            self.assertEqual(vocab_size, all_size)

            new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
            added_toks = tokenizer.add_tokens(new_toks)
            vocab_size_2 = tokenizer.vocab_size
            all_size_2 = len(tokenizer)

-    tester.assertNotEqual(vocab_size_2, 0)
-    tester.assertEqual(vocab_size, vocab_size_2)
-    tester.assertEqual(added_toks, len(new_toks))
-    tester.assertEqual(all_size_2, all_size + len(new_toks))
+            self.assertNotEqual(vocab_size_2, 0)
+            self.assertEqual(vocab_size, vocab_size_2)
+            self.assertEqual(added_toks, len(new_toks))
+            self.assertEqual(all_size_2, all_size + len(new_toks))

            tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")

-    tester.assertGreaterEqual(len(tokens), 4)
-    tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
-    tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+            self.assertGreaterEqual(len(tokens), 4)
+            self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)

            new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
                          'pad_token': "<<<<<|||>|>>>>|>"}

@@ -97,52 +116,45 @@ def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kw
            vocab_size_3 = tokenizer.vocab_size
            all_size_3 = len(tokenizer)

-    tester.assertNotEqual(vocab_size_3, 0)
-    tester.assertEqual(vocab_size, vocab_size_3)
-    tester.assertEqual(added_toks_2, len(new_toks_2))
-    tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
+            self.assertNotEqual(vocab_size_3, 0)
+            self.assertEqual(vocab_size, vocab_size_3)
+            self.assertEqual(added_toks_2, len(new_toks_2))
+            self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

            tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")

-    tester.assertGreaterEqual(len(tokens), 6)
-    tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
-    tester.assertGreater(tokens[0], tokens[1])
-    tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
-    tester.assertGreater(tokens[-2], tokens[-3])
-    tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
-    tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
+            self.assertGreaterEqual(len(tokens), 6)
+            self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+            self.assertGreater(tokens[0], tokens[1])
+            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+            self.assertGreater(tokens[-2], tokens[-3])
+            self.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
+            self.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))

-def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
+        def test_required_methods_tokenizer(self):
+            tokenizer = self.get_tokenizer()
+            input_text, output_text = self.get_input_output_texts()

            tokens = tokenizer.tokenize(input_text)
            ids = tokenizer.convert_tokens_to_ids(tokens)
            ids_2 = tokenizer.encode(input_text)
-    tester.assertListEqual(ids, ids_2)
+            self.assertListEqual(ids, ids_2)

            tokens_2 = tokenizer.convert_ids_to_tokens(ids)
            text_2 = tokenizer.decode(ids)

-    tester.assertEqual(text_2, output_text)
-    tester.assertNotEqual(len(tokens_2), 0)
-    tester.assertIsInstance(text_2, (str, unicode))
+            self.assertEqual(text_2, output_text)
+            self.assertNotEqual(len(tokens_2), 0)
+            self.assertIsInstance(text_2, (str, unicode))

-def create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
-    weights_list = list(tokenizer_class.max_model_input_sizes.keys())
+        def test_pretrained_model_lists(self):
+            weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
            weights_lists_2 = []
-    for file_id, map_list in tokenizer_class.pretrained_vocab_files_map.items():
+            for file_id, map_list in self.tokenizer_class.pretrained_vocab_files_map.items():
                weights_lists_2.append(list(map_list.keys()))

            for weights_list_2 in weights_lists_2:
-        tester.assertListEqual(weights_list, weights_list_2)
+                self.assertListEqual(weights_list, weights_list_2)

-def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
-    create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
-    create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
-    create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
-    create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
-    create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
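The shared tests above turn the old `create_and_check_*` helpers into regular `unittest` methods, so any subclass that sets `tokenizer_class`, writes its fixture files into `self.tmpdirname` in `setUp`, and implements `get_tokenizer` / `get_input_output_texts` inherits all of them automatically. A hedged sketch of running just those shared tests for one converted subclass; the `pytorch_transformers.tests` module path is an assumption about where the test package lives.

```python
import unittest

# Assumes the tests ship inside the package as pytorch_transformers/tests/.
from pytorch_transformers.tests.tokenization_bert_test import BertTokenizationTest

# Load every test_* method inherited from CommonTokenizerTester plus the
# class's own test_full_tokenizer, then run them with the text runner.
suite = unittest.TestLoader().loadTestsFromTestCase(BertTokenizationTest)
unittest.TextTestRunner(verbosity=2).run(suite)
```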
@@ -20,26 +20,33 @@ from io import open
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

-class TransfoXLTokenizationTest(unittest.TestCase):
+class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
+
+    tokenizer_class = TransfoXLTokenizer

-    def test_full_tokenizer(self):
+    def setUp(self):
+        super(TransfoXLTokenizationTest, self).setUp()
+
        vocab_tokens = [
            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
            "running", ",", "low", "l",
        ]
-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
-                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+    def get_tokenizer(self):
+        return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)
+
+    def get_input_output_texts(self):
        input_text = u"<unk> UNwanted , running"
        output_text = u"<unk> unwanted, running"
+        return input_text, output_text

-            create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
-
-            tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
+    def test_full_tokenizer(self):
+        tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True)

        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
...
@@ -20,12 +20,16 @@ import json
from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

-class XLMTokenizationTest(unittest.TestCase):
+class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
+
+    tokenizer_class = XLMTokenizer

-    def test_full_tokenizer(self):
-        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+    def setUp(self):
+        super(XLMTokenizationTest, self).setUp()
+
+        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "w</w>", "r</w>", "t</w>",
                 "lo", "low", "er</w>",

@@ -33,20 +37,24 @@ class XLMTokenizationTest(unittest.TestCase):
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-            with open(vocab_file, "w") as fp:
-                fp.write(json.dumps(vocab_tokens))
-            with open(merges_file, "w") as fp:
-                fp.write("\n".join(merges))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+        with open(self.vocab_file, "w") as fp:
+            fp.write(json.dumps(vocab_tokens))
+        with open(self.merges_file, "w") as fp:
+            fp.write("\n".join(merges))
+
+    def get_tokenizer(self):
+        return XLMTokenizer.from_pretrained(self.tmpdirname)
+
+    def get_input_output_texts(self):
        input_text = u"lower newer"
        output_text = u"lower newer"
+        return input_text, output_text

-            create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
-
-            tokenizer = XLMTokenizer(vocab_file, merges_file)
+    def test_full_tokenizer(self):
+        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+        tokenizer = XLMTokenizer(self.vocab_file, self.merges_file)
        text = "lower"
        bpe_tokens = ["low", "er</w>"]
...
@@ -19,23 +19,33 @@ import unittest
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'fixtures/test_sentencepiece.model')

-class XLNetTokenizationTest(unittest.TestCase):
+class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
+
+    tokenizer_class = XLNetTokenizer

-    def test_full_tokenizer(self):
+    def setUp(self):
+        super(XLNetTokenizationTest, self).setUp()
+
+        # We have a SentencePiece fixture for testing
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
+        tokenizer.save_pretrained(self.tmpdirname)

-        with TemporaryDirectory() as tmpdirname:
-            tokenizer.save_pretrained(tmpdirname)
+    def get_tokenizer(self):
+        return XLNetTokenizer.from_pretrained(self.tmpdirname)

+    def get_input_output_texts(self):
        input_text = u"This is a test"
        output_text = u"This is a test"
+        return input_text, output_text

-            create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
+    def test_full_tokenizer(self):
+        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize(u'This is a test')
        self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
...
@@ -86,7 +86,7 @@ def whitespace_tokenize(text):
class BertTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a BertTokenizer.
-    :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
+    :class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file
...
@@ -125,42 +125,34 @@ class PreTrainedTokenizer(object):
    @bos_token.setter
    def bos_token(self, value):
-        self.add_tokens([value])
        self._bos_token = value

    @eos_token.setter
    def eos_token(self, value):
-        self.add_tokens([value])
        self._eos_token = value

    @unk_token.setter
    def unk_token(self, value):
-        self.add_tokens([value])
        self._unk_token = value

    @sep_token.setter
    def sep_token(self, value):
-        self.add_tokens([value])
        self._sep_token = value

    @pad_token.setter
    def pad_token(self, value):
-        self.add_tokens([value])
        self._pad_token = value

    @cls_token.setter
    def cls_token(self, value):
-        self.add_tokens([value])
        self._cls_token = value

    @mask_token.setter
    def mask_token(self, value):
-        self.add_tokens([value])
        self._mask_token = value

    @additional_special_tokens.setter
    def additional_special_tokens(self, value):
-        self.add_tokens(value)
        self._additional_special_tokens = value

    def __init__(self, max_len=None, **kwargs):

@@ -179,6 +171,10 @@ class PreTrainedTokenizer(object):
        for key, value in kwargs.items():
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
+                if key == 'additional_special_tokens':
+                    assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
+                else:
+                    assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
                setattr(self, key, value)
@@ -415,15 +411,39 @@ class PreTrainedTokenizer(object):
        Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).

+        Returns:
+            Number of tokens added to the vocabulary.
+
+        Examples::
+
+            # Let's see how to add a new classification token to GPT-2
+            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+            model = GPT2Model.from_pretrained('gpt2')
+
+            special_tokens_dict = {'cls_token': '<CLS>'}
+
+            num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
+            print('We have added', num_added_toks, 'tokens')
+            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+
+            assert tokenizer.cls_token == '<CLS>'
        """
        if not special_tokens_dict:
            return 0

+        added_tokens = 0
        for key, value in special_tokens_dict.items():
            assert key in self.SPECIAL_TOKENS_ATTRIBUTES
+            if key == 'additional_special_tokens':
+                assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
+                added_tokens += self.add_tokens(value)
+            else:
+                assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
+                added_tokens += self.add_tokens([value])
            logger.info("Assigning %s to the %s key of the tokenizer", value, key)
            setattr(self, key, value)

+        return added_tokens

    def tokenize(self, text, **kwargs):
        """ Converts a string into a sequence of tokens (string), using the tokenizer.
...
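Putting the two `tokenization_utils.py` changes together: the special-token setters no longer call `add_tokens`, so plain attribute assignment only records the string, while `add_special_tokens` now both registers the token in the vocabulary and reports how many entries were added. A short usage sketch restating the docstring example above; the printed count and download behaviour depend on your environment.

```python
from pytorch_transformers import GPT2Tokenizer, GPT2Model

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# Register '<CLS>' as the cls_token *and* add it to the vocabulary.
num_added_toks = tokenizer.add_special_tokens({'cls_token': '<CLS>'})
print('We have added', num_added_toks, 'tokens')

# The embedding matrix must be resized to the new vocabulary size.
model.resize_token_embeddings(len(tokenizer))

assert tokenizer.cls_token == '<CLS>'
```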