Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
031ad4eb
Commit
031ad4eb
authored
Dec 16, 2019
by
thomwolf
Browse files
improving JSON error messages (for model card and configurations)
parent
db0a9ee6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
8 deletions
+19
-8
transformers/configuration_utils.py
transformers/configuration_utils.py
+11
-4
transformers/model_card.py
transformers/model_card.py
+8
-4
No files found.
transformers/configuration_utils.py
View file @
031ad4eb
...
@@ -151,10 +151,14 @@ class PretrainedConfig(object):
...
@@ -151,10 +151,14 @@ class PretrainedConfig(object):
config_file
=
pretrained_model_name_or_path
config_file
=
pretrained_model_name_or_path
else
:
else
:
config_file
=
hf_bucket_url
(
pretrained_model_name_or_path
,
postfix
=
CONFIG_NAME
)
config_file
=
hf_bucket_url
(
pretrained_model_name_or_path
,
postfix
=
CONFIG_NAME
)
# redirect to the cache, if necessary
try
:
try
:
# Load from URL or cache if already cached
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
)
proxies
=
proxies
,
resume_download
=
resume_download
)
# Load config
config
=
cls
.
from_json_file
(
resolved_config_file
)
except
EnvironmentError
:
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
msg
=
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
msg
=
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
...
@@ -168,15 +172,18 @@ class PretrainedConfig(object):
...
@@ -168,15 +172,18 @@ class PretrainedConfig(object):
config_file
,
CONFIG_NAME
)
config_file
,
CONFIG_NAME
)
raise
EnvironmentError
(
msg
)
raise
EnvironmentError
(
msg
)
except
json
.
JSONDecodeError
:
msg
=
"Couldn't reach server at '{}' to download configuration file or "
\
"configuration file is not a valid JSON file. "
\
"Please check network or file content here: {}."
.
format
(
config_file
,
resolved_config_file
)
raise
EnvironmentError
(
msg
)
if
resolved_config_file
==
config_file
:
if
resolved_config_file
==
config_file
:
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
else
:
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
config_file
,
resolved_config_file
))
config_file
,
resolved_config_file
))
# Load config
config
=
cls
.
from_json_file
(
resolved_config_file
)
if
hasattr
(
config
,
'pruned_heads'
):
if
hasattr
(
config
,
'pruned_heads'
):
config
.
pruned_heads
=
dict
((
int
(
key
),
value
)
for
key
,
value
in
config
.
pruned_heads
.
items
())
config
.
pruned_heads
=
dict
((
int
(
key
),
value
)
for
key
,
value
in
config
.
pruned_heads
.
items
())
...
...
transformers/model_card.py
View file @
031ad4eb
...
@@ -132,7 +132,7 @@ class ModelCard(object):
...
@@ -132,7 +132,7 @@ class ModelCard(object):
if
pretrained_model_name_or_path
in
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
:
if
pretrained_model_name_or_path
in
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
:
# For simplicity we use the same pretrained url than the configuration files but with a different suffix (model_card.json)
# For simplicity we use the same pretrained url than the configuration files but with a different suffix (model_card.json)
model_card_file
=
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
model_card_file
=
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
model_card_file
.
replace
(
CONFIG_NAME
,
MODEL_CARD_NAME
)
model_card_file
=
model_card_file
.
replace
(
CONFIG_NAME
,
MODEL_CARD_NAME
)
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
model_card_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
MODEL_CARD_NAME
)
model_card_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
MODEL_CARD_NAME
)
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
)
or
is_remote_url
(
pretrained_model_name_or_path
):
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
)
or
is_remote_url
(
pretrained_model_name_or_path
):
...
@@ -143,13 +143,11 @@ class ModelCard(object):
...
@@ -143,13 +143,11 @@ class ModelCard(object):
try
:
try
:
resolved_model_card_file
=
cached_path
(
model_card_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
resolved_model_card_file
=
cached_path
(
model_card_file
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
)
proxies
=
proxies
,
resume_download
=
resume_download
)
if
resolved_model_card_file
==
model_card_file
:
if
resolved_model_card_file
==
model_card_file
:
logger
.
info
(
"loading model card file {}"
.
format
(
model_card_file
))
logger
.
info
(
"loading model card file {}"
.
format
(
model_card_file
))
else
:
else
:
logger
.
info
(
"loading model card file {} from cache at {}"
.
format
(
logger
.
info
(
"loading model card file {} from cache at {}"
.
format
(
model_card_file
,
resolved_model_card_file
))
model_card_file
,
resolved_model_card_file
))
# Load model card
# Load model card
model_card
=
cls
.
from_json_file
(
resolved_model_card_file
)
model_card
=
cls
.
from_json_file
(
resolved_model_card_file
)
...
@@ -164,9 +162,15 @@ class ModelCard(object):
...
@@ -164,9 +162,15 @@ class ModelCard(object):
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
', '
.
join
(
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
.
keys
()),
', '
.
join
(
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
.
keys
()),
model_card_file
,
MODEL_CARD_NAME
))
model_card_file
,
MODEL_CARD_NAME
))
logger
.
warning
(
"Creating an empty model card."
)
logger
.
warning
(
"Creating an empty model card."
)
# We fall back on creating an empty model card
model_card
=
cls
()
except
json
.
JSONDecodeError
:
logger
.
warning
(
"Couldn't reach server at '{}' to download model card file or "
"model card file is not a valid JSON file. "
"Please check network or file content here: {}."
.
format
(
model_card_file
,
resolved_model_card_file
))
logger
.
warning
(
"Creating an empty model card."
)
# We fall back on creating an empty model card
# We fall back on creating an empty model card
model_card
=
cls
()
model_card
=
cls
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment