Commit 64e0adda authored by thomwolf

better error message

parent 382e2d1e
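
This commit splits the single try/except that resolved both the weights archive and the configuration file into two separate blocks, one per file, in the from_pretrained methods of the BERT, GPT-2, OpenAI GPT and Transformer-XL pretrained-model classes. When a download or lookup fails, the logged error now names the specific file (weights archive or configuration file) instead of a combined message that mentions both. A minimal sketch of the pattern follows; the helper name resolve_files and its condensed messages are hypothetical, and cached_path is assumed to be importable from the package's file_utils module as in the real modules touched below.

import logging

from pytorch_pretrained_bert.file_utils import cached_path  # assumed import path

logger = logging.getLogger(__name__)


def resolve_files(archive_file, config_file, cache_dir=None):
    # Hypothetical helper mirroring the pattern this commit applies inside
    # from_pretrained: resolve each file in its own try/except so the error
    # message names exactly the file that could not be downloaded or found.
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
    except EnvironmentError:
        logger.error("Couldn't download or find pretrained weights file '%s'.", archive_file)
        return None
    try:
        resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
    except EnvironmentError:
        logger.error("Couldn't download or find model configuration file '%s'.", config_file)
        return None
    return resolved_archive_file, resolved_config_file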
@@ -658,7 +658,6 @@ class BertPreTrainedModel(nn.Module):
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
         except EnvironmentError:
             if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
                 logger.error(
@@ -673,6 +672,22 @@ class BertPreTrainedModel(nn.Module):
                         ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                         archive_file))
             return None
+        try:
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
+        except EnvironmentError:
+            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
+                        config_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find any file "
+                    "associated to this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
+                        config_file))
+            return None
         if resolved_archive_file == archive_file and resolved_config_file == config_file:
             logger.info("loading weights file {}".format(archive_file))
             logger.info("loading configuration file {}".format(config_file))
...
@@ -493,7 +493,6 @@ class GPT2PreTrainedModel(nn.Module):
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
         except EnvironmentError:
             if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
                 logger.error(
@@ -502,10 +501,27 @@ class GPT2PreTrainedModel(nn.Module):
             else:
                 logger.error(
                     "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                    "We assumed '{}' was a path or url but couldn't find file {} "
                     "at this path or url.".format(
                     pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
-                    archive_file, config_file
+                    archive_file
+                    )
+                )
+            return None
+        try:
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
+        except EnvironmentError:
+            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
+                        config_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find file {} "
+                    "at this path or url.".format(
+                    pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
+                    config_file
+                    )
+                )
+            return None
...
@@ -496,7 +496,6 @@ class OpenAIGPTPreTrainedModel(nn.Module):
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
         except EnvironmentError:
             if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
                 logger.error(
@@ -505,10 +504,27 @@ class OpenAIGPTPreTrainedModel(nn.Module):
             else:
                 logger.error(
                     "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                    "We assumed '{}' was a path or url but couldn't find file {} "
                     "at this path or url.".format(
                     pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
-                    archive_file, config_file
+                    archive_file
+                    )
+                )
+            return None
+        try:
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
+        except EnvironmentError:
+            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
+                        config_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find file {} "
+                    "at this path or url.".format(
+                    pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
+                    config_file
+                    )
+                )
+            return None
...
@@ -921,7 +921,6 @@ class TransfoXLPreTrainedModel(nn.Module):
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
         except EnvironmentError:
             if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
                 logger.error(
@@ -930,12 +929,29 @@ class TransfoXLPreTrainedModel(nn.Module):
             else:
                 logger.error(
                     "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path,
-                        archive_file, config_file))
+                    "We assumed '{}' was a path or url but couldn't find file {} "
+                    "at this path or url.".format(
+                    pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
+                    archive_file
+                    )
+                )
+            return None
+        try:
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
+        except EnvironmentError:
+            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
+                        config_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find file {} "
+                    "at this path or url.".format(
+                    pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
+                    config_file
+                    )
+                )
             return None
         if resolved_archive_file == archive_file and resolved_config_file == config_file:
             logger.info("loading weights file {}".format(archive_file))
...
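
With the change applied, loading a pretrained model without network access logs which of the two downloads failed (weights archive or configuration file), and from_pretrained still returns None rather than raising. A short usage sketch, assuming the package's usual top-level import:

from pytorch_pretrained_bert import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
if model is None:
    print("Loading failed; the log above names the file that could not be fetched.")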