"examples/trials/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "b8d19e4531f1955b0c256f0887494d8eb96f42ff"
Commit f5397ffc authored by thomwolf's avatar thomwolf
Browse files

update loading logics

parent 271f2136
...@@ -163,7 +163,7 @@ if _tf_available and _torch_available: ...@@ -163,7 +163,7 @@ if _tf_available and _torch_available:
# Files and general utilities # Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings, cached_path, add_start_docstrings, add_end_docstrings,
WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME) WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
def is_torch_available(): def is_torch_available():
return _torch_available return _torch_available
......
...@@ -49,6 +49,7 @@ except (AttributeError, ImportError): ...@@ -49,6 +49,7 @@ except (AttributeError, ImportError):
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = 'tf_model.h5'
TF_WEIGHTS_NAME = 'model.ckpt' TF_WEIGHTS_NAME = 'model.ckpt'
CONFIG_NAME = "config.json" CONFIG_NAME = "config.json"
......
...@@ -24,7 +24,7 @@ import os ...@@ -24,7 +24,7 @@ import os
import tensorflow as tf import tensorflow as tf
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -205,38 +205,49 @@ class TFPreTrainedModel(tf.keras.Model): ...@@ -205,38 +205,49 @@ class TFPreTrainedModel(tf.keras.Model):
model_kwargs = kwargs model_kwargs = kwargs
# Load model # Load model
if pretrained_model_name_or_path in cls.pretrained_model_archive_map: if pretrained_model_name_or_path is not None:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
if from_pt:
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
else:
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map: if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error( archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
"Couldn't reach server at '{}' to download pretrained weights.".format( elif os.path.isdir(pretrained_model_name_or_path):
archive_file)) if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
# Load from a TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
raise EnvironmentError("Error no file named {} found in directory {}".format(
tuple(WEIGHTS_NAME, TF2_WEIGHTS_NAME),
pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
else:
raise EnvironmentError("Error file {} not found".format(pretrained_model_name_or_path))
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError as e:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
raise e
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else: else:
logger.error( logger.info("loading weights file {} from cache at {}".format(
"Model name '{}' was not found in model name list ({}). " archive_file, resolved_archive_file))
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
return None
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else: else:
logger.info("loading weights file {} from cache at {}".format( resolved_archive_file = None
archive_file, resolved_archive_file))
# Instantiate model. # Instantiate model.
model = cls(config, *model_args, **model_kwargs) model = cls(config, *model_args, **model_kwargs)
......
...@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss ...@@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss
from torch.nn import functional as F from torch.nn import functional as F
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -294,11 +294,19 @@ class PreTrainedModel(nn.Module): ...@@ -294,11 +294,19 @@ class PreTrainedModel(nn.Module):
if pretrained_model_name_or_path in cls.pretrained_model_archive_map: if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path] archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path): elif os.path.isdir(pretrained_model_name_or_path):
if from_tf: if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
# Directly load from a TensorFlow checkpoint # Load from a TF 1.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
else: elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
# Load from a TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
raise EnvironmentError("Error no file named {} found in directory {}".format(
tuple(WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"),
pretrained_model_name_or_path))
elif os.path.isfile(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path archive_file = pretrained_model_name_or_path
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment