Commit 634a3172 authored by LysandreJik

Added integration tests for sequence builders.

parent 22ac004a
@@ -125,6 +125,17 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
self.assertFalse(_is_punctuation(u"A")) self.assertFalse(_is_punctuation(u"A"))
self.assertFalse(_is_punctuation(u" ")) self.assertFalse(_is_punctuation(u" "))

    def test_sequence_builders(self):
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        text = tokenizer.encode("sequence builders")
        text_2 = tokenizer.encode("multi-sequence build")

        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

        assert encoded_sentence == [101] + text + [102]
        assert encoded_pair == [101] + text + [102] + text_2 + [102]

if __name__ == '__main__':
    unittest.main()
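Aside, not part of the diff: in the bert-base-uncased vocabulary, id 101 is [CLS] and id 102 is [SEP], so the two builders produce the standard BERT layouts [CLS] A [SEP] and [CLS] A [SEP] B [SEP]. A minimal stand-alone sketch of the pattern these assertions check (the helper names and hard-coded ids are illustrative, not library API):

    # Illustrative sketch; assumes 101 = [CLS] and 102 = [SEP] (bert-base-uncased).
    CLS, SEP = 101, 102

    def bert_single(ids):
        # [CLS] A [SEP]
        return [CLS] + ids + [SEP]

    def bert_pair(ids_a, ids_b):
        # [CLS] A [SEP] B [SEP]
        return [CLS] + ids_a + [SEP] + ids_b + [SEP]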
@@ -71,10 +71,22 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
            [0, 31414, 232, 328, 2]
        )
        self.assertListEqual(
            tokenizer.encode('Hello world! cécé herlolip 418'),
            [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
        )

    def test_sequence_builders(self):
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

        text = tokenizer.encode("sequence builders")
        text_2 = tokenizer.encode("multi-sequence build")

        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

        assert encoded_sentence == [0] + text + [2]
        assert encoded_pair == [0] + text + [2, 2] + text_2 + [2]

if __name__ == '__main__':
    unittest.main()
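Aside: roberta-base uses <s> (id 0) and </s> (id 2), and the pair builder joins the two sequences with a doubled separator, giving <s> A </s></s> B </s>. A sketch of the asserted pattern (illustrative names, hard-coded ids):

    # Illustrative sketch; assumes 0 = <s> and 2 = </s> (roberta-base).
    BOS, EOS = 0, 2

    def roberta_single(ids):
        # <s> A </s>
        return [BOS] + ids + [EOS]

    def roberta_pair(ids_a, ids_b):
        # <s> A </s></s> B </s> -- note the doubled separator
        return [BOS] + ids_a + [EOS, EOS] + ids_b + [EOS]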
@@ -66,6 +66,17 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    def test_sequence_builders(self):
        tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")

        text = tokenizer.encode("sequence builders")
        text_2 = tokenizer.encode("multi-sequence build")

        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

        assert encoded_sentence == [1] + text + [1]
        assert encoded_pair == [1] + text + [1] + text_2 + [1]

if __name__ == '__main__':
    unittest.main()
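Aside: XLM appears to use the same token, </s> (assumed to be id 1 in the xlm-mlm-en-2048 vocabulary), as both the leading boundary token and the separator, giving </s> A </s> and </s> A </s> B </s>. Sketch (illustrative names, hard-coded id):

    # Illustrative sketch; assumes 1 = </s> (xlm-mlm-en-2048).
    SEP = 1

    def xlm_single(ids):
        # </s> A </s>
        return [SEP] + ids + [SEP]

    def xlm_pair(ids_a, ids_b):
        # </s> A </s> B </s>
        return [SEP] + ids_a + [SEP] + ids_b + [SEP]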
@@ -89,6 +89,18 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
            u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
            SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])

    def test_sequence_builders(self):
        tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")

        text = tokenizer.encode("sequence builders")
        text_2 = tokenizer.encode("multi-sequence build")

        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

        assert encoded_sentence == text + [4, 3]
        assert encoded_pair == text + [4] + text_2 + [4, 3]

if __name__ == '__main__':
    unittest.main()
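Aside: unlike the other three, XLNet appends its special tokens at the end of the input, <sep> (id 4) then <cls> (id 3), giving A <sep> <cls> and A <sep> B <sep> <cls>. Sketch (illustrative names, hard-coded ids):

    # Illustrative sketch; assumes 4 = <sep> and 3 = <cls> (xlnet-base-cased).
    SEP, CLS = 4, 3

    def xlnet_single(ids):
        # A <sep> <cls>  (XLNet places the classification token last)
        return ids + [SEP, CLS]

    def xlnet_pair(ids_a, ids_b):
        # A <sep> B <sep> <cls>
        return ids_a + [SEP] + ids_b + [SEP, CLS]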