Commit 757750d6 authored by thomwolf

fix tests

parent a99b9717
@@ -60,7 +60,7 @@ This package comprises the following classes that can be imported in Python and
- `BertTokenizer` - perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.
- One optimizer:
-- `BERTAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
+- `BertAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
- A configuration class:
- `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilities to read and write from JSON configuration files.
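As a quick orientation for how the renamed classes fit together, here is a hedged, self-contained sketch; the hyper-parameter values, step count and input shape are illustrative and not taken from the commit (the tokenizer is exercised in the test changes further down):

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertModel, BertAdam

# Configuration object holding the model hyper-parameters (illustrative values)
config = BertConfig(vocab_size_or_config_json_file=30522, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12,
                    intermediate_size=3072)
model = BertModel(config)

# BertAdam: Adam with the weight decay fix, warmup and linear learning-rate decay
optimizer = BertAdam(model.parameters(), lr=5e-5, warmup=0.1, t_total=1000)

# A forward pass returns all encoder layers and the pooled [CLS] output
input_ids = torch.randint(0, 30522, (1, 8))
all_encoder_layers, pooled_output = model(input_ids)
```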
@@ -155,7 +155,7 @@ Here is a detailed documentation of the classes in the package and how to use th
| [Loading Google AI's pre-trained weigths](#Loading-Google-AIs-pre-trained-weigths-and-PyTorch-dump) | How to load Google AI's pre-trained weights or a PyTorch saved instance |
| [PyTorch models](#PyTorch-models) | API of the six PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` |
| [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class |
-| [Optimizer: `BERTAdam`](#Optimizer-BERTAdam) | API of the `BERTAdam` class |
+| [Optimizer: `BertAdam`](#Optimizer-BertAdam) | API of the `BertAdam` class |
### Loading Google AI's pre-trained weigths and PyTorch dump
@@ -294,12 +294,12 @@ and three methods:
Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) for the details of the `BasicTokenizer` and `WordpieceTokenizer` classes. In general it is recommended to use `BertTokenizer` unless you know what you are doing.
-### Optimizer: `BERTAdam`
+### Optimizer: `BertAdam`
-`BERTAdam` is a `torch.optimizer` adapted to be closer to the optimizer used in the TensorFlow implementation of Bert. The differences with the PyTorch Adam optimizer are the following:
+`BertAdam` is a `torch.optimizer` adapted to be closer to the optimizer used in the TensorFlow implementation of Bert. The differences with the PyTorch Adam optimizer are the following:
-- BERTAdam implements weight decay fix,
+- BertAdam implements weight decay fix,
-- BERTAdam doesn't compensate for bias as in the regular Adam optimizer.
+- BertAdam doesn't compensate for bias as in the regular Adam optimizer.
The optimizer accepts the following arguments:
...
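Concretely, the constructor can be exercised on a toy problem much like the updated optimization test further down; a minimal, hedged sketch (initial values, learning rate, warmup fraction and step count are illustrative):

```python
import torch
from pytorch_pretrained_bert import BertAdam

# Toy parameter and target in the spirit of the optimization test below
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss()

num_train_steps = 100
optimizer = BertAdam(params=[w],
                     lr=2e-1,                  # peak learning rate
                     warmup=0.1,               # warm up over the first 10% of t_total
                     t_total=num_train_steps)  # total steps, then linear decay to 0

for _ in range(num_train_steps):
    loss = criterion(w, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```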
@@ -29,7 +29,7 @@ from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import convert_to_unicode, BertTokenizer
-from pytorch_pretrained_bert.modeling import BertConfig, BertModel
+from pytorch_pretrained_bert.modeling import BertModel
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
...
@@ -31,8 +31,8 @@ from torch.utils.data import TensorDataset, DataLoader, RandomSampler, Sequentia
from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import printable_text, convert_to_unicode, BertTokenizer
-from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification
-from pytorch_pretrained_bert.optimization import BERTAdam
+from pytorch_pretrained_bert.modeling import BertForSequenceClassification
+from pytorch_pretrained_bert.optimization import BertAdam
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
@@ -512,7 +512,7 @@ def main():
        {'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
        ]
-    optimizer = BERTAdam(optimizer_grouped_parameters,
+    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)
...
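For context, the grouping above splits the model parameters into a set that receives weight decay and a set that does not; the `no_decay` list is defined just before this hunk and is not shown here. A hedged reconstruction of the pattern, using a stand-in model and a substring match on parameter names to express the intent (the hunk itself literally checks `n not in no_decay`), is:

```python
import torch
from pytorch_pretrained_bert import BertAdam

# Stand-in model; in the script above this is the BertForSequenceClassification instance
model = torch.nn.Linear(4, 2)

# Assumed contents of no_decay (defined just above the hunk, not shown in this diff)
no_decay = ['bias', 'gamma', 'beta']
param_optimizer = list(model.named_parameters())
optimizer_grouped_parameters = [
    # regular weights get weight decay ...
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay_rate': 0.01},
    # ... biases and LayerNorm (gamma/beta) parameters do not
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay_rate': 0.0},
]
optimizer = BertAdam(optimizer_grouped_parameters, lr=5e-5, warmup=0.1, t_total=1000)
```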
@@ -33,8 +33,8 @@ from torch.utils.data import TensorDataset, DataLoader, RandomSampler, Sequentia
from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import printable_text, whitespace_tokenize, BasicTokenizer, BertTokenizer
-from pytorch_pretrained_bert.modeling import BertConfig, BertForQuestionAnswering
-from pytorch_pretrained_bert.optimization import BERTAdam
+from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
+from pytorch_pretrained_bert.optimization import BertAdam
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
@@ -847,7 +847,7 @@ def main():
        {'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
        ]
-    optimizer = BERTAdam(optimizer_grouped_parameters,
+    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)
...
@@ -2,4 +2,4 @@ from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
                       BertForMaskedLM, BertForNextSentencePrediction,
                       BertForSequenceClassification, BertForQuestionAnswering)
-from .optimization import BERTAdam
+from .optimization import BertAdam
@@ -41,7 +41,7 @@ SCHEDULES = {
}
-class BERTAdam(Optimizer):
+class BertAdam(Optimizer):
    """Implements BERT version of Adam algorithm with weight decay fix.
    Params:
        lr: learning rate
@@ -73,7 +73,7 @@ class BERTAdam(Optimizer):
        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
                        b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
                        max_grad_norm=max_grad_norm)
-        super(BERTAdam, self).__init__(params, defaults)
+        super(BertAdam, self).__init__(params, defaults)
    def get_lr(self):
        lr = []
...
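The `SCHEDULES` mapping closed by the first hunk above pairs schedule names with learning-rate multiplier functions of training progress. A minimal sketch of the warmup-then-linear-decay behaviour described in the README and in `BertAdam`'s docstring (an illustration of the idea, not a verbatim copy of `optimization.py`; the function name and default are assumptions):

```python
def warmup_linear(x, warmup=0.002):
    """Learning-rate multiplier at training progress x = step / t_total.

    Ramps linearly from 0 to 1 during the warmup fraction, then decays
    linearly back towards 0 over the remaining steps.
    """
    if x < warmup:
        return x / warmup
    return 1.0 - x

# Example: with a 10% warmup over 1000 total steps
t_total = 1000
mults = [warmup_linear(step / t_total, warmup=0.1) for step in range(t_total)]
# multiplier ramps up during the first ~100 steps, then decays towards 0
print(mults[0], mults[50], mults[200], mults[999])  # 0.0, 0.5, 0.8, ~0.0
```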
@@ -22,7 +22,7 @@ import random
import torch
-import modeling
+from pytorch_pretrained_bert import BertConfig, BertModel
class BertModelTest(unittest.TestCase):
@@ -77,8 +77,8 @@ class BertModelTest(unittest.TestCase):
            if self.use_token_type_ids:
                token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-            config = modeling.BertConfig(
-                vocab_size=self.vocab_size,
+            config = BertConfig(
+                vocab_size_or_config_json_file=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
@@ -90,7 +90,7 @@ class BertModelTest(unittest.TestCase):
                type_vocab_size=self.type_vocab_size,
                initializer_range=self.initializer_range)
-            model = modeling.BertModel(config=config)
+            model = BertModel(config=config)
            all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
@@ -112,7 +112,7 @@ class BertModelTest(unittest.TestCase):
        self.run_tester(BertModelTest.BertModelTester(self))
    def test_config_to_json_string(self):
-        config = modeling.BertConfig(vocab_size=99, hidden_size=37)
+        config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["vocab_size"], 99)
        self.assertEqual(obj["hidden_size"], 37)
...
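The renamed `vocab_size_or_config_json_file` argument reflects that `BertConfig` can be built either from an integer vocabulary size or from a JSON configuration file. A hedged round-trip sketch (the file name is a placeholder):

```python
import json
from pytorch_pretrained_bert import BertConfig

# Build a config directly from hyper-parameters, as the test above does ...
config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
print(json.loads(config.to_json_string())["hidden_size"])  # 37

# ... or, as the renamed argument suggests, from a JSON file (placeholder path)
with open("bert_config.json", "w") as f:
    f.write(config.to_json_string())
config_from_file = BertConfig(vocab_size_or_config_json_file="bert_config.json")
assert config_from_file.hidden_size == 37
```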
@@ -20,7 +20,7 @@ import unittest
import torch
-import optimization
+from pytorch_pretrained_bert import BertAdam
class OptimizationTest(unittest.TestCase):
@@ -34,7 +34,7 @@ class OptimizationTest(unittest.TestCase):
        target = torch.tensor([0.4, 0.2, -0.5])
        criterion = torch.nn.MSELoss(reduction='elementwise_mean')
        # No warmup, constant schedule, no gradient clipping
-        optimizer = optimization.BERTAdam(params=[w], lr=2e-1,
+        optimizer = BertAdam(params=[w], lr=2e-1,
                             weight_decay_rate=0.0,
                             max_grad_norm=-1)
        for _ in range(100):
...
@@ -19,7 +19,8 @@ from __future__ import print_function
import os
import unittest
-import tokenization
+from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer,
+                                                  _is_whitespace, _is_control, _is_punctuation)
class TokenizationTest(unittest.TestCase):
@@ -34,7 +35,7 @@ class TokenizationTest(unittest.TestCase):
            vocab_file = vocab_writer.name
-        tokenizer = tokenization.BertTokenizer(vocab_file)
+        tokenizer = BertTokenizer(vocab_file)
        os.remove(vocab_file)
        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
@@ -44,14 +45,14 @@ class TokenizationTest(unittest.TestCase):
            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
    def test_chinese(self):
-        tokenizer = tokenization.BasicTokenizer()
+        tokenizer = BasicTokenizer()
        self.assertListEqual(
            tokenizer.tokenize(u"ah\u535A\u63A8zz"),
            [u"ah", u"\u535A", u"\u63A8", u"zz"])
    def test_basic_tokenizer_lower(self):
-        tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
+        tokenizer = BasicTokenizer(do_lower_case=True)
        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
@@ -59,7 +60,7 @@ class TokenizationTest(unittest.TestCase):
        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
    def test_basic_tokenizer_no_lower(self):
-        tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
+        tokenizer = BasicTokenizer(do_lower_case=False)
        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
@@ -74,7 +75,7 @@ class TokenizationTest(unittest.TestCase):
        vocab = {}
        for (i, token) in enumerate(vocab_tokens):
            vocab[token] = i
-        tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
+        tokenizer = WordpieceTokenizer(vocab=vocab)
        self.assertListEqual(tokenizer.tokenize(""), [])
@@ -85,46 +86,32 @@ class TokenizationTest(unittest.TestCase):
        self.assertListEqual(
            tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
-    def test_convert_tokens_to_ids(self):
-        vocab_tokens = [
-            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
-            "##ing"
-        ]
-        vocab = {}
-        for (i, token) in enumerate(vocab_tokens):
-            vocab[token] = i
-        self.assertListEqual(
-            tokenization.convert_tokens_to_ids(
-                vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
    def test_is_whitespace(self):
-        self.assertTrue(tokenization._is_whitespace(u" "))
-        self.assertTrue(tokenization._is_whitespace(u"\t"))
-        self.assertTrue(tokenization._is_whitespace(u"\r"))
-        self.assertTrue(tokenization._is_whitespace(u"\n"))
-        self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
-        self.assertFalse(tokenization._is_whitespace(u"A"))
-        self.assertFalse(tokenization._is_whitespace(u"-"))
+        self.assertTrue(_is_whitespace(u" "))
+        self.assertTrue(_is_whitespace(u"\t"))
+        self.assertTrue(_is_whitespace(u"\r"))
+        self.assertTrue(_is_whitespace(u"\n"))
+        self.assertTrue(_is_whitespace(u"\u00A0"))
+        self.assertFalse(_is_whitespace(u"A"))
+        self.assertFalse(_is_whitespace(u"-"))
    def test_is_control(self):
-        self.assertTrue(tokenization._is_control(u"\u0005"))
-        self.assertFalse(tokenization._is_control(u"A"))
-        self.assertFalse(tokenization._is_control(u" "))
-        self.assertFalse(tokenization._is_control(u"\t"))
-        self.assertFalse(tokenization._is_control(u"\r"))
+        self.assertTrue(_is_control(u"\u0005"))
+        self.assertFalse(_is_control(u"A"))
+        self.assertFalse(_is_control(u" "))
+        self.assertFalse(_is_control(u"\t"))
+        self.assertFalse(_is_control(u"\r"))
    def test_is_punctuation(self):
-        self.assertTrue(tokenization._is_punctuation(u"-"))
-        self.assertTrue(tokenization._is_punctuation(u"$"))
-        self.assertTrue(tokenization._is_punctuation(u"`"))
-        self.assertTrue(tokenization._is_punctuation(u"."))
-        self.assertFalse(tokenization._is_punctuation(u"A"))
-        self.assertFalse(tokenization._is_punctuation(u" "))
+        self.assertTrue(_is_punctuation(u"-"))
+        self.assertTrue(_is_punctuation(u"$"))
+        self.assertTrue(_is_punctuation(u"`"))
+        self.assertTrue(_is_punctuation(u"."))
+        self.assertFalse(_is_punctuation(u"A"))
+        self.assertFalse(_is_punctuation(u" "))
if __name__ == '__main__':
...
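Putting the tested pieces together, a minimal end-to-end sketch of the tokenizer API exercised above; the vocabulary mirrors the test fixture, with the `","` entry at index 10 inferred from the expected ids:

```python
import os
import tempfile
from pytorch_pretrained_bert.tokenization import BertTokenizer

# A tiny WordPiece vocabulary mirroring the test fixture above
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed",
                "wa", "un", "runn", "##ing", ","]
with tempfile.NamedTemporaryFile("w", delete=False) as vocab_writer:
    vocab_writer.write("".join(tok + "\n" for tok in vocab_tokens))
    vocab_file = vocab_writer.name

tokenizer = BertTokenizer(vocab_file)
os.remove(vocab_file)

tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
print(tokens)                                   # ["un", "##want", "##ed", ",", "runn", "##ing"]
print(tokenizer.convert_tokens_to_ids(tokens))  # [7, 4, 5, 10, 8, 9]
```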