Unverified commit f394a2a5 authored by Patrick von Platen, committed by GitHub

[Json configs] Make json prettier for all saved tokenizer files & ensure same json format for all processors (tok + feat_extract) (#17457)

* [Json dump] Make json prettier

* correct more tokenizers

* more patterns

* add aggressive test

* the aggressive test was actually useful :-)

* more tests

* Apply suggestions from code review
parent 6ee1474b
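The core of the change is the same in every file touched below: the compact `json.dumps(...)` calls are replaced by an indented, key-sorted dump with a trailing newline. A minimal before/after sketch (the toy vocab is only an illustrative stand-in for a tokenizer's `self.encoder`):

```python
import json

# Illustrative stand-in for a tokenizer's `self.encoder`.
vocab = {"<unk>": 0, "<s>": 1, "hello": 2}

# Before: everything on one long line, keys in insertion order.
old_format = json.dumps(vocab, ensure_ascii=False)

# After: one key per line, two-space indent, sorted keys, trailing newline.
new_format = json.dumps(vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"

print(old_format)
print(new_format)
```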
@@ -354,6 +354,8 @@ class FeatureExtractionMixin(PushToHubMixin):
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Feature extractor pushed to the hub in this commit: {url}")
return [output_feature_extractor_file]
@classmethod
def get_feature_extractor_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
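The commit title also covers the feature extractors, which should end up with the same two-space, key-sorted JSON. A quick way to inspect the saved `preprocessor_config.json`, sketched with an example checkpoint that is not part of this diff:

```python
import os
import tempfile

from transformers import Wav2Vec2FeatureExtractor

# The checkpoint is only an example; any feature extractor works here.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

with tempfile.TemporaryDirectory() as tmp_dir:
    feature_extractor.save_pretrained(tmp_dir)
    with open(os.path.join(tmp_dir, "preprocessor_config.json"), encoding="utf-8") as f:
        print(f.read())  # multi-line JSON: two-space indent, keys sorted
```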
@@ -318,7 +318,7 @@ class BartTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -220,7 +220,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -345,7 +345,7 @@ class CLIPTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -236,7 +236,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -505,11 +505,11 @@ class FSMTTokenizer(PreTrainedTokenizer):
)
with open(src_vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
with open(tgt_vocab_file, "w", encoding="utf-8") as f:
tgt_vocab = {v: k for k, v in self.decoder.items()}
-f.write(json.dumps(tgt_vocab, ensure_ascii=False))
+f.write(json.dumps(tgt_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merges_file, "w", encoding="utf-8") as writer:
@@ -297,7 +297,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -437,7 +437,7 @@ class LayoutLMv3Tokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -1386,6 +1386,6 @@ class LukeTokenizer(RobertaTokenizer):
)
with open(entity_vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.entity_vocab, ensure_ascii=False))
+f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return vocab_file, merge_file, entity_vocab_file
@@ -1509,7 +1509,7 @@ class MLukeTokenizer(PreTrainedTokenizer):
)
with open(entity_vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.entity_vocab, ensure_ascii=False))
+f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return out_vocab_file, entity_vocab_file
@@ -215,7 +215,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -324,7 +324,7 @@ class RobertaTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -250,7 +250,7 @@ class Speech2Text2Tokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
if self.bpe_ranks is None:
@@ -503,7 +503,7 @@ class TapexTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -603,7 +603,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)
@@ -922,6 +922,6 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)
@@ -568,7 +568,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)
@@ -965,7 +965,7 @@ class XLMTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(self.encoder, ensure_ascii=False))
+f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
@@ -1494,3 +1494,20 @@ def nested_simplify(obj, decimals=3):
        return nested_simplify(obj.item(), decimals)
    else:
        raise Exception(f"Not supported: {type(obj)}")
+
+
+def check_json_file_has_correct_format(file_path):
+    with open(file_path, "r") as f:
+        lines = f.readlines()
+        if len(lines) == 1:
+            # length can only be 1 if the dict is empty
+            assert lines[0].strip() == "{}"
+        else:
+            # otherwise make sure the json has the correct format (at least 3 lines)
+            assert len(lines) >= 3
+            # one key per line, the indent should be 2, min length is 3
+            assert lines[0].strip() == "{"
+            for line in lines[1:-1]:
+                left_indent = len(line) - len(line.lstrip())
+                assert left_indent == 2
+            assert lines[-1].strip() == "}"
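A minimal usage sketch for the helper above, assuming it is importable from the same test utilities; the file content is illustrative:

```python
import json
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    file_path = os.path.join(tmp_dir, "vocab.json")
    # Write the file the same way the tokenizers above now do.
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(json.dumps({"</s>": 1, "<s>": 0}, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
    check_json_file_has_correct_format(file_path)  # passes: "{", two-space-indented keys, "}"
```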
@@ -2118,13 +2118,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
custom_object_save(self, save_directory, config=tokenizer_config)
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(tokenizer_config, ensure_ascii=False))
+out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+f.write(out_str)
logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
# Sanitize AddedTokens in special_tokens_map
write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False)
with open(special_tokens_map_file, "w", encoding="utf-8") as f:
-f.write(json.dumps(write_dict, ensure_ascii=False))
+out_str = json.dumps(write_dict, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+f.write(out_str)
logger.info(f"Special tokens file saved in {special_tokens_map_file}")
file_names = (tokenizer_config_file, special_tokens_map_file)
@@ -2168,7 +2170,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
added_vocab = self.get_added_vocab()
if added_vocab:
with open(added_tokens_file, "w", encoding="utf-8") as f:
-out_str = json.dumps(added_vocab, ensure_ascii=False)
+out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
f.write(out_str)
logger.info(f"added tokens file saved in {added_tokens_file}")
@@ -585,7 +585,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
added_vocab = self.get_added_vocab()
if added_vocab:
with open(added_tokens_file, "w", encoding="utf-8") as f:
-out_str = json.dumps(added_vocab, ensure_ascii=False)
+out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
f.write(out_str)
vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
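Taken together, `save_pretrained` on a slow or fast tokenizer now writes all of its JSON sidecar files (tokenizer_config.json, special_tokens_map.json, added_tokens.json and the vocab files) in the same indented, key-sorted layout. A quick way to see the effect (the checkpoint is only an example):

```python
import os
import tempfile

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint

with tempfile.TemporaryDirectory() as tmp_dir:
    tokenizer.save_pretrained(tmp_dir)
    with open(os.path.join(tmp_dir, "tokenizer_config.json"), encoding="utf-8") as f:
        print(f.read())  # one key per line, two-space indent, sorted keys
```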