Commit 14eef67e authored by Abhishek Rao's avatar Abhishek Rao
Browse files

Fix at config rather than model

parent 296df2b1
...@@ -166,7 +166,7 @@ class PretrainedConfig(object): ...@@ -166,7 +166,7 @@ class PretrainedConfig(object):
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError: except EnvironmentError as e:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map: if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
logger.error( logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format( "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
...@@ -179,7 +179,7 @@ class PretrainedConfig(object): ...@@ -179,7 +179,7 @@ class PretrainedConfig(object):
pretrained_model_name_or_path, pretrained_model_name_or_path,
', '.join(cls.pretrained_config_archive_map.keys()), ', '.join(cls.pretrained_config_archive_map.keys()),
config_file)) config_file))
return None raise e
if resolved_config_file == config_file: if resolved_config_file == config_file:
logger.info("loading configuration file {}".format(config_file)) logger.info("loading configuration file {}".format(config_file))
else: else:
...@@ -473,7 +473,7 @@ class PreTrainedModel(nn.Module): ...@@ -473,7 +473,7 @@ class PreTrainedModel(nn.Module):
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError as e: except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map: if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error( logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format( "Couldn't reach server at '{}' to download pretrained weights.".format(
...@@ -486,7 +486,7 @@ class PreTrainedModel(nn.Module): ...@@ -486,7 +486,7 @@ class PreTrainedModel(nn.Module):
pretrained_model_name_or_path, pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()), ', '.join(cls.pretrained_model_archive_map.keys()),
archive_file)) archive_file))
raise e return None
if resolved_archive_file == archive_file: if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file)) logger.info("loading weights file {}".format(archive_file))
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment