Commit 9d8fd2d4 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

tokenizer.save_pretrained: only save file if non-empty

parent 6e2c28a1
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Auto Model class. """
+""" Auto Config class. """
 import logging
...
@@ -513,12 +513,10 @@ class PreTrainedTokenizer(object):
         with open(special_tokens_map_file, "w", encoding="utf-8") as f:
             f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
-        with open(added_tokens_file, "w", encoding="utf-8") as f:
-            if self.added_tokens_encoder:
-                out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
-            else:
-                out_str = "{}"
-            f.write(out_str)
+        if len(self.added_tokens_encoder) > 0:
+            with open(added_tokens_file, "w", encoding="utf-8") as f:
+                out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
+                f.write(out_str)
         vocab_files = self.save_vocabulary(save_directory)
...
@@ -33,13 +33,13 @@ class AutoTokenizerTest(unittest.TestCase):
     # @slow
     def test_tokenizer_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
-        for model_name in [x for x in BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys() if "japanese" not in x]:
+        for model_name in (x for x in BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys() if "japanese" not in x):
             tokenizer = AutoTokenizer.from_pretrained(model_name)
             self.assertIsNotNone(tokenizer)
             self.assertIsInstance(tokenizer, BertTokenizer)
             self.assertGreater(len(tokenizer), 0)
-        for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
+        for model_name in GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys():
             tokenizer = AutoTokenizer.from_pretrained(model_name)
             self.assertIsNotNone(tokenizer)
             self.assertIsInstance(tokenizer, GPT2Tokenizer)
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment