Commit f8276008 authored by thomwolf

update readme, file names, removing TF code, moving tests

parent 3c24e4be
@@ -16,17 +16,19 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import six
+import unittest
 import collections
 import json
 import random
 import re
 
-from tensorflow_code import modeling
-import six
-import tensorflow as tf
+import torch
+
+import modeling as modeling
 
 
-class BertModelTest(tf.test.TestCase):
+class BertModelTest(unittest.TestCase):
 
   class BertModelTester(object):
 
     def __init__(self,
@@ -68,18 +70,15 @@ class BertModelTest(tf.test.TestCase):
       self.scope = scope
 
     def create_model(self):
-      input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
-                                           self.vocab_size)
+      input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
       input_mask = None
       if self.use_input_mask:
-        input_mask = BertModelTest.ids_tensor(
-            [self.batch_size, self.seq_length], vocab_size=2)
+        input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
 
       token_type_ids = None
       if self.use_token_type_ids:
-        token_type_ids = BertModelTest.ids_tensor(
-            [self.batch_size, self.seq_length], self.type_vocab_size)
+        token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
 
       config = modeling.BertConfig(
           vocab_size=self.vocab_size,
@@ -94,33 +93,23 @@ class BertModelTest(tf.test.TestCase):
          type_vocab_size=self.type_vocab_size,
          initializer_range=self.initializer_range)
 
-      model = modeling.BertModel(
-          config=config,
-          is_training=self.is_training,
-          input_ids=input_ids,
-          input_mask=input_mask,
-          token_type_ids=token_type_ids,
-          scope=self.scope)
+      model = modeling.BertModel(config=config)
+
+      all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
 
       outputs = {
-          "embedding_output": model.get_embedding_output(),
-          "sequence_output": model.get_sequence_output(),
-          "pooled_output": model.get_pooled_output(),
-          "all_encoder_layers": model.get_all_encoder_layers(),
+          "sequence_output": all_encoder_layers[-1],
+          "pooled_output": pooled_output,
+          "all_encoder_layers": all_encoder_layers,
       }
       return outputs
 
     def check_output(self, result):
-      self.parent.assertAllEqual(
-          result["embedding_output"].shape,
-          [self.batch_size, self.seq_length, self.hidden_size])
-      self.parent.assertAllEqual(
-          result["sequence_output"].shape,
-          [self.batch_size, self.seq_length, self.hidden_size])
-      self.parent.assertAllEqual(result["pooled_output"].shape,
-                                 [self.batch_size, self.hidden_size])
+      self.parent.assertListEqual(
+          list(result["sequence_output"].size()),
+          [self.batch_size, self.seq_length, self.hidden_size])
+      self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
 
   def test_default(self):
     self.run_tester(BertModelTest.BertModelTester(self))
@@ -132,15 +121,11 @@ class BertModelTest(tf.test.TestCase):
     self.assertEqual(obj["hidden_size"], 37)
 
   def run_tester(self, tester):
-    with self.test_session() as sess:
-      ops = tester.create_model()
-      init_op = tf.group(tf.global_variables_initializer(),
-                         tf.local_variables_initializer())
-      sess.run(init_op)
-      output_result = sess.run(ops)
-      tester.check_output(output_result)
-
-      self.assert_all_tensors_reachable(sess, [init_op, ops])
+    output_result = tester.create_model()
+    tester.check_output(output_result)
+
+    # TODO Find PyTorch equivalent of assert_all_tensors_reachable() if necessary
+    # self.assert_all_tensors_reachable(sess, [init_op, ops])
 
   @classmethod
   def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
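Editor's note on the TODO above: PyTorch builds its graph dynamically, so TensorFlow's static-graph reachability check has no direct counterpart. A rough analogue, sketched below (this is an assumption, not part of the commit; `model`, `input_ids`, `token_type_ids`, and `input_mask` are taken to be in scope as in `create_model()`), is to backpropagate from the model outputs and verify that every parameter received a gradient:

# Sketch of a possible PyTorch stand-in for assert_all_tensors_reachable():
# backprop a scalar built from both outputs, then check that every parameter
# got a gradient, i.e. is reachable from the outputs through the autograd graph.
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
loss = all_encoder_layers[-1].sum() + pooled_output.sum()
loss.backward()
for name, param in model.named_parameters():
    assert param.grad is not None, "parameter %s is unreachable from the outputs" % name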
@@ -156,7 +141,8 @@ class BertModelTest(tf.test.TestCase):
     for _ in range(total_dims):
       values.append(rng.randint(0, vocab_size - 1))
 
-    return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
+    # TODO Solve: the returned tensors provoke index out of range errors when passed to the model
+    return torch.tensor(data=values, dtype=torch.int32)
 
   def assert_all_tensors_reachable(self, sess, outputs):
     """Checks that all the tensors in the graph are reachable from outputs."""
@@ -272,4 +258,4 @@ class BertModelTest(tf.test.TestCase):
 
 
 if __name__ == "__main__":
-  tf.test.main()
+  unittest.main()
@@ -20,7 +20,7 @@ import unittest
 
 import torch
 
-import optimization_pytorch as optimization
+import optimization as optimization
 
 
 class OptimizationTest(unittest.TestCase):
...
@@ -17,45 +17,44 @@ from __future__ import division
 from __future__ import print_function
 
 import os
-import tempfile
+import unittest
 
-from tensorflow_code import tokenization
-import tensorflow as tf
+import tokenization as tokenization
 
 
-class TokenizationTest(tf.test.TestCase):
+class TokenizationTest(unittest.TestCase):
 
   def test_full_tokenizer(self):
     vocab_tokens = [
         "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
         "##ing", ","
     ]
-    with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
+    with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
       vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
 
       vocab_file = vocab_writer.name
 
     tokenizer = tokenization.FullTokenizer(vocab_file)
-    os.unlink(vocab_file)
+    os.remove(vocab_file)
 
     tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
-    self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+    self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
 
-    self.assertAllEqual(
+    self.assertListEqual(
         tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
 
   def test_basic_tokenizer_lower(self):
     tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
 
-    self.assertAllEqual(
+    self.assertListEqual(
         tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
         ["hello", "!", "how", "are", "you", "?"])
-    self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
+    self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
 
   def test_basic_tokenizer_no_lower(self):
     tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
 
-    self.assertAllEqual(
+    self.assertListEqual(
        tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
        ["HeLLo", "!", "how", "Are", "yoU", "?"])
@@ -70,13 +69,13 @@ class TokenizationTest(tf.test.TestCase):
      vocab[token] = i
 
    tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
 
-    self.assertAllEqual(tokenizer.tokenize(""), [])
+    self.assertListEqual(tokenizer.tokenize(""), [])
 
-    self.assertAllEqual(
+    self.assertListEqual(
        tokenizer.tokenize("unwanted running"),
        ["un", "##want", "##ed", "runn", "##ing"])
 
-    self.assertAllEqual(
+    self.assertListEqual(
        tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
 
  def test_convert_tokens_to_ids(self):
@@ -89,7 +88,7 @@ class TokenizationTest(tf.test.TestCase):
    for (i, token) in enumerate(vocab_tokens):
      vocab[token] = i
 
-    self.assertAllEqual(
+    self.assertListEqual(
        tokenization.convert_tokens_to_ids(
            vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
@@ -121,5 +120,5 @@ class TokenizationTest(tf.test.TestCase):
    self.assertFalse(tokenization._is_punctuation(u" "))
 
 
-if __name__ == "__main__":
-  tf.test.main()
+if __name__ == '__main__':
+  unittest.main()
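Editor's note on the tokenizer test above: the hard-coded "/tmp/bert_tokenizer_test.txt" path is POSIX-only and can collide between concurrent test runs, which the dropped tempfile approach avoided. A portable sketch (assuming only the same tokenization.FullTokenizer API used in the test) that keeps tempfile without any TensorFlow dependency:

import os
import tempfile

import tokenization

vocab_tokens = [
    "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
    "##ing", ","
]

# Write the vocab to a unique temporary file instead of a fixed /tmp path.
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as vocab_writer:
    vocab_writer.write("".join(token + "\n" for token in vocab_tokens))
    vocab_file = vocab_writer.name

tokenizer = tokenization.FullTokenizer(vocab_file)
os.remove(vocab_file)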