Commit 69da972a authored by thomwolf

Added a test for tokenizer configuration serialization and debugged it

parent 88111de0
@@ -41,8 +41,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-    def get_tokenizer(self):
-        return BertTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"UNwant\u00E9d,running"
...
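Passing `**kwargs` through the per-model `get_tokenizer` helpers lets the shared test suite build tokenizers with non-default settings. A minimal sketch of the new call pattern inside a test (the `max_len=42` value mirrors the common test further down):

    # kwargs are forwarded straight to BertTokenizer.from_pretrained
    tokenizer = self.get_tokenizer(max_len=42)
    self.assertEqual(tokenizer.max_len, 42)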
@@ -44,8 +44,9 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))

-    def get_tokenizer(self):
-        return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
+    def get_tokenizer(self, **kwargs):
+        kwargs.update(self.special_tokens_map)
+        return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"lower newer"
...
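For the GPT-2 test (the RoBERTa test below uses the same pattern), the helper now merges the test's `special_tokens_map` into whatever the shared tests pass before forwarding everything to `from_pretrained`. A sketch of the merge; the `unk_token` value is illustrative, not taken from the commit:

    kwargs = {"max_len": 42}               # passed in by a shared test
    kwargs.update({"unk_token": "<unk>"})  # the test's special_tokens_map (illustrative)
    # both end up forwarded: GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)

Note that with `kwargs.update(...)`, the test's special tokens take precedence over a caller argument of the same name.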
@@ -45,8 +45,8 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))

-    def get_tokenizer(self):
-        return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"lower newer"
...
@@ -43,8 +43,9 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))

-    def get_tokenizer(self):
-        return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
+    def get_tokenizer(self, **kwargs):
+        kwargs.update(self.special_tokens_map)
+        return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"lower newer"
...
@@ -49,14 +49,19 @@ class CommonTestCases:
         def tearDown(self):
             shutil.rmtree(self.tmpdirname)

-        def get_tokenizer(self):
+        def get_tokenizer(self, **kwargs):
             raise NotImplementedError

         def get_input_output_texts(self):
             raise NotImplementedError

         def test_save_and_load_tokenizer(self):
+            # safety check on max_len default value so we are sure the test works
             tokenizer = self.get_tokenizer()
+            self.assertNotEqual(tokenizer.max_len, 42)
+
+            # Now let's start the test
+            tokenizer = self.get_tokenizer(max_len=42)

             before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
@@ -64,8 +69,12 @@ class CommonTestCases:
             tokenizer.save_pretrained(tmpdirname)
             tokenizer = tokenizer.from_pretrained(tmpdirname)

             after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
             self.assertListEqual(before_tokens, after_tokens)
+            self.assertEqual(tokenizer.max_len, 42)
+
+            tokenizer = tokenizer.from_pretrained(tmpdirname, max_len=43)
+            self.assertEqual(tokenizer.max_len, 43)

         def test_pickle_tokenizer(self):
             tokenizer = self.get_tokenizer()
...
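Outside the test harness, the round trip this test now covers looks roughly like the following. This is a sketch only: it assumes the pretrained `bert-base-uncased` files can be downloaded and that `./my_tokenizer` is a writable local directory.

    from pytorch_transformers import BertTokenizer  # or `transformers`, depending on the installed version

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', max_len=42)
    tokenizer.save_pretrained('./my_tokenizer')            # writes the vocab files plus tokenizer_config.json

    reloaded = BertTokenizer.from_pretrained('./my_tokenizer')
    assert reloaded.max_len == 42                          # restored from the saved configuration

    reloaded = BertTokenizer.from_pretrained('./my_tokenizer', max_len=43)
    assert reloaded.max_len == 43                          # explicit kwargs override the saved value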
@@ -37,8 +37,9 @@ class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-    def get_tokenizer(self):
-        return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)
+    def get_tokenizer(self, **kwargs):
+        kwargs['lower_case'] = True
+        return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"<unk> UNwanted , running"
...
@@ -44,8 +44,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))

-    def get_tokenizer(self):
-        return XLMTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"lower newer"
...
@@ -35,8 +35,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
         tokenizer.save_pretrained(self.tmpdirname)

-    def get_tokenizer(self):
-        return XLNetTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"This is a test"
...
@@ -332,7 +332,7 @@ class PreTrainedTokenizer(object):
         tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
         if tokenizer_config_file is not None:
             init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8"))
-            saved_init_inputs = init_kwargs.pop('init_inputs', [])
+            saved_init_inputs = init_kwargs.pop('init_inputs', ())
             if not init_inputs:
                 init_inputs = saved_init_inputs
             else:
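In context (the surrounding lines are not part of this hunk), the loading side amounts to: values read back from tokenizer_config.json become defaults, and anything passed explicitly to `from_pretrained` wins. A simplified sketch of that logic, not the verbatim implementation:

    init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8"))
    saved_init_inputs = init_kwargs.pop('init_inputs', ())  # default changed from [] to () in this commit
    if not init_inputs:
        init_inputs = saved_init_inputs
    init_kwargs.update(kwargs)  # kwargs given to from_pretrained override the saved configuration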
@@ -399,6 +399,8 @@ class PreTrainedTokenizer(object):
         tokenizer_config = copy.deepcopy(self.init_kwargs)
         tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs)
+        for file_id in self.vocab_files_names.keys():
+            tokenizer_config.pop(file_id, None)

         with open(tokenizer_config_file, 'w', encoding='utf-8') as f:
             f.write(json.dumps(tokenizer_config, ensure_ascii=False))
...
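The new loop strips the vocabulary-file paths (which `from_pretrained` injects into `init_kwargs`) out of the configuration before it is written, so tokenizer_config.json no longer carries machine-specific paths. A quick way to check the effect (sketch; `./my_tokenizer` is the hypothetical directory from the earlier example, saved with `max_len=42`):

    import json
    import os

    with open(os.path.join('./my_tokenizer', 'tokenizer_config.json'), encoding='utf-8') as f:
        config = json.load(f)

    assert 'vocab_file' not in config  # file-path entries are popped before serialization
    assert config['max_len'] == 42     # the remaining init kwargs are still saved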