Unverified Commit 4132a028 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #29 from huggingface/first-release

First release
parents 02173a1a 47a7d4ec
...@@ -20,7 +20,7 @@ import unittest ...@@ -20,7 +20,7 @@ import unittest
import torch import torch
import optimization from pytorch_pretrained_bert import BertAdam
class OptimizationTest(unittest.TestCase): class OptimizationTest(unittest.TestCase):
...@@ -34,7 +34,7 @@ class OptimizationTest(unittest.TestCase): ...@@ -34,7 +34,7 @@ class OptimizationTest(unittest.TestCase):
target = torch.tensor([0.4, 0.2, -0.5]) target = torch.tensor([0.4, 0.2, -0.5])
criterion = torch.nn.MSELoss(reduction='elementwise_mean') criterion = torch.nn.MSELoss(reduction='elementwise_mean')
# No warmup, constant schedule, no gradient clipping # No warmup, constant schedule, no gradient clipping
optimizer = optimization.BERTAdam(params=[w], lr=2e-1, optimizer = BertAdam(params=[w], lr=2e-1,
weight_decay_rate=0.0, weight_decay_rate=0.0,
max_grad_norm=-1) max_grad_norm=-1)
for _ in range(100): for _ in range(100):
......
...@@ -19,7 +19,8 @@ from __future__ import print_function ...@@ -19,7 +19,8 @@ from __future__ import print_function
import os import os
import unittest import unittest
import tokenization from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer,
_is_whitespace, _is_control, _is_punctuation)
class TokenizationTest(unittest.TestCase): class TokenizationTest(unittest.TestCase):
...@@ -34,7 +35,7 @@ class TokenizationTest(unittest.TestCase): ...@@ -34,7 +35,7 @@ class TokenizationTest(unittest.TestCase):
vocab_file = vocab_writer.name vocab_file = vocab_writer.name
tokenizer = tokenization.FullTokenizer(vocab_file) tokenizer = BertTokenizer(vocab_file)
os.remove(vocab_file) os.remove(vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
...@@ -44,14 +45,14 @@ class TokenizationTest(unittest.TestCase): ...@@ -44,14 +45,14 @@ class TokenizationTest(unittest.TestCase):
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_chinese(self): def test_chinese(self):
tokenizer = tokenization.BasicTokenizer() tokenizer = BasicTokenizer()
self.assertListEqual( self.assertListEqual(
tokenizer.tokenize(u"ah\u535A\u63A8zz"), tokenizer.tokenize(u"ah\u535A\u63A8zz"),
[u"ah", u"\u535A", u"\u63A8", u"zz"]) [u"ah", u"\u535A", u"\u63A8", u"zz"])
def test_basic_tokenizer_lower(self): def test_basic_tokenizer_lower(self):
tokenizer = tokenization.BasicTokenizer(do_lower_case=True) tokenizer = BasicTokenizer(do_lower_case=True)
self.assertListEqual( self.assertListEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
...@@ -59,7 +60,7 @@ class TokenizationTest(unittest.TestCase): ...@@ -59,7 +60,7 @@ class TokenizationTest(unittest.TestCase):
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
def test_basic_tokenizer_no_lower(self): def test_basic_tokenizer_no_lower(self):
tokenizer = tokenization.BasicTokenizer(do_lower_case=False) tokenizer = BasicTokenizer(do_lower_case=False)
self.assertListEqual( self.assertListEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
...@@ -74,7 +75,7 @@ class TokenizationTest(unittest.TestCase): ...@@ -74,7 +75,7 @@ class TokenizationTest(unittest.TestCase):
vocab = {} vocab = {}
for (i, token) in enumerate(vocab_tokens): for (i, token) in enumerate(vocab_tokens):
vocab[token] = i vocab[token] = i
tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) tokenizer = WordpieceTokenizer(vocab=vocab)
self.assertListEqual(tokenizer.tokenize(""), []) self.assertListEqual(tokenizer.tokenize(""), [])
...@@ -85,46 +86,32 @@ class TokenizationTest(unittest.TestCase): ...@@ -85,46 +86,32 @@ class TokenizationTest(unittest.TestCase):
self.assertListEqual( self.assertListEqual(
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
def test_convert_tokens_to_ids(self):
vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing"
]
vocab = {}
for (i, token) in enumerate(vocab_tokens):
vocab[token] = i
self.assertListEqual(
tokenization.convert_tokens_to_ids(
vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
def test_is_whitespace(self): def test_is_whitespace(self):
self.assertTrue(tokenization._is_whitespace(u" ")) self.assertTrue(_is_whitespace(u" "))
self.assertTrue(tokenization._is_whitespace(u"\t")) self.assertTrue(_is_whitespace(u"\t"))
self.assertTrue(tokenization._is_whitespace(u"\r")) self.assertTrue(_is_whitespace(u"\r"))
self.assertTrue(tokenization._is_whitespace(u"\n")) self.assertTrue(_is_whitespace(u"\n"))
self.assertTrue(tokenization._is_whitespace(u"\u00A0")) self.assertTrue(_is_whitespace(u"\u00A0"))
self.assertFalse(tokenization._is_whitespace(u"A")) self.assertFalse(_is_whitespace(u"A"))
self.assertFalse(tokenization._is_whitespace(u"-")) self.assertFalse(_is_whitespace(u"-"))
def test_is_control(self): def test_is_control(self):
self.assertTrue(tokenization._is_control(u"\u0005")) self.assertTrue(_is_control(u"\u0005"))
self.assertFalse(tokenization._is_control(u"A")) self.assertFalse(_is_control(u"A"))
self.assertFalse(tokenization._is_control(u" ")) self.assertFalse(_is_control(u" "))
self.assertFalse(tokenization._is_control(u"\t")) self.assertFalse(_is_control(u"\t"))
self.assertFalse(tokenization._is_control(u"\r")) self.assertFalse(_is_control(u"\r"))
def test_is_punctuation(self): def test_is_punctuation(self):
self.assertTrue(tokenization._is_punctuation(u"-")) self.assertTrue(_is_punctuation(u"-"))
self.assertTrue(tokenization._is_punctuation(u"$")) self.assertTrue(_is_punctuation(u"$"))
self.assertTrue(tokenization._is_punctuation(u"`")) self.assertTrue(_is_punctuation(u"`"))
self.assertTrue(tokenization._is_punctuation(u".")) self.assertTrue(_is_punctuation(u"."))
self.assertFalse(tokenization._is_punctuation(u"A")) self.assertFalse(_is_punctuation(u"A"))
self.assertFalse(tokenization._is_punctuation(u" ")) self.assertFalse(_is_punctuation(u" "))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment