Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e9d0bc02
Unverified
Commit
e9d0bc02
authored
Apr 18, 2020
by
Patrick von Platen
Committed by
GitHub
Apr 17, 2020
Browse files
[Config, Serialization] more readable config serialization (#3797)
* better config serialization * finish configuration utils
parent
8b63a01d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
5 deletions
+38
-5
src/transformers/configuration_utils.py
src/transformers/configuration_utils.py
+38
-5
No files found.
src/transformers/configuration_utils.py
View file @
e9d0bc02
...
@@ -141,7 +141,7 @@ class PretrainedConfig(object):
...
@@ -141,7 +141,7 @@ class PretrainedConfig(object):
# If we save using the predefined names, we can load using `from_pretrained`
# If we save using the predefined names, we can load using `from_pretrained`
output_config_file
=
os
.
path
.
join
(
save_directory
,
CONFIG_NAME
)
output_config_file
=
os
.
path
.
join
(
save_directory
,
CONFIG_NAME
)
self
.
to_json_file
(
output_config_file
)
self
.
to_json_file
(
output_config_file
,
use_diff
=
True
)
logger
.
info
(
"Configuration saved in {}"
.
format
(
output_config_file
))
logger
.
info
(
"Configuration saved in {}"
.
format
(
output_config_file
))
@
classmethod
@
classmethod
...
@@ -353,6 +353,29 @@ class PretrainedConfig(object):
...
@@ -353,6 +353,29 @@ class PretrainedConfig(object):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"{} {}"
.
format
(
self
.
__class__
.
__name__
,
self
.
to_json_string
())
return
"{} {}"
.
format
(
self
.
__class__
.
__name__
,
self
.
to_json_string
())
def
to_diff_dict
(
self
):
"""
Removes all attributes from config which correspond to the default
config attributes for better readability and serializes to a Python
dictionary.
Returns:
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
config_dict
=
self
.
to_dict
()
# get the default config dict
default_config_dict
=
PretrainedConfig
().
to_dict
()
serializable_config_dict
=
{}
# only serialize values that differ from the default config
for
key
,
value
in
config_dict
.
items
():
if
key
not
in
default_config_dict
or
value
!=
default_config_dict
[
key
]:
serializable_config_dict
[
key
]
=
value
return
serializable_config_dict
def
to_dict
(
self
):
def
to_dict
(
self
):
"""
"""
Serializes this instance to a Python dictionary.
Serializes this instance to a Python dictionary.
...
@@ -365,25 +388,35 @@ class PretrainedConfig(object):
...
@@ -365,25 +388,35 @@ class PretrainedConfig(object):
output
[
"model_type"
]
=
self
.
__class__
.
model_type
output
[
"model_type"
]
=
self
.
__class__
.
model_type
return
output
return
output
def
to_json_string
(
self
):
def
to_json_string
(
self
,
use_diff
=
True
):
"""
"""
Serializes this instance to a JSON string.
Serializes this instance to a JSON string.
Args:
use_diff (:obj:`bool`):
If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.
Returns:
Returns:
:obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
:obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
"""
"""
return
json
.
dumps
(
self
.
to_dict
(),
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
if
use_diff
is
True
:
config_dict
=
self
.
to_diff_dict
()
else
:
config_dict
=
self
.
to_dict
()
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
def
to_json_file
(
self
,
json_file_path
):
def
to_json_file
(
self
,
json_file_path
,
use_diff
=
True
):
"""
"""
Save this instance to a json file.
Save this instance to a json file.
Args:
Args:
json_file_path (:obj:`string`):
json_file_path (:obj:`string`):
Path to the JSON file in which this configuration instance's parameters will be saved.
Path to the JSON file in which this configuration instance's parameters will be saved.
use_diff (:obj:`bool`):
If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
"""
"""
with
open
(
json_file_path
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
with
open
(
json_file_path
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
write
(
self
.
to_json_string
())
writer
.
write
(
self
.
to_json_string
(
use_diff
=
use_diff
))
def
update
(
self
,
config_dict
:
Dict
):
def
update
(
self
,
config_dict
:
Dict
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment