Commit 1bbdbacd authored by thomwolf's avatar thomwolf
Browse files

update __init__ and saving

parent 031ad4eb
......@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
# Files and general utilities
from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME, MODEL_CARD_NAME,
is_tf_available, is_torch_available)
from .data import (is_sklearn_available,
......
......@@ -67,14 +67,14 @@ class ModelCard(object):
logger.error("Can't set {} with value {} for {}".format(key, value, self))
raise err
def save_pretrained(self, save_directory):
""" Save a model card object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~transformers.ModelCard.from_pretrained` class method.
def save_pretrained(self, save_directory_or_file):
""" Save a model card object to the directory or file `save_directory_or_file`.
"""
assert os.path.isdir(save_directory), "Saving path should be a directory where the model card can be saved"
# If we save using the predefined names, we can load using `from_pretrained`
output_model_card_file = os.path.join(save_directory, MODEL_CARD_NAME)
if os.path.isdir(save_directory_or_file):
# If we save using the predefined names, we can load using `from_pretrained`
output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
else:
output_model_card_file = save_directory_or_file
self.to_json_file(output_model_card_file)
logger.info("Model card saved in {}".format(output_model_card_file))
......@@ -139,8 +139,9 @@ class ModelCard(object):
model_card_file = pretrained_model_name_or_path
else:
model_card_file = hf_bucket_url(pretrained_model_name_or_path, postfix=MODEL_CARD_NAME)
# redirect to the cache, if necessary
try:
# Load from URL or cache if already cached
resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download)
if resolved_model_card_file == model_card_file:
......@@ -163,6 +164,7 @@ class ModelCard(object):
', '.join(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
model_card_file, MODEL_CARD_NAME))
logger.warning("Creating an empty model card.")
# We fall back on creating an empty model card
model_card = cls()
......@@ -171,6 +173,7 @@ class ModelCard(object):
"model card file is not a valid JSON file. "
"Please check network or file content here: {}.".format(model_card_file, resolved_model_card_file))
logger.warning("Creating an empty model card.")
# We fall back on creating an empty model card
model_card = cls()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment