"vscode:/vscode.git/clone" did not exist on "1f47a24aa1fbaefcac9d0ffbfa5fff30867a8c92"
Unverified Commit 7a9f1b5c authored by Kevin Canwen Xu's avatar Kevin Canwen Xu Committed by GitHub
Browse files

Store transformers version info when saving the model (#9421)



* Store transformers version info when saving the model

* Store transformers version info when saving the model

* fix format

* fix format

* fix format

* Update src/transformers/configuration_utils.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* Update configuration_utils.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
parent ecfcac22
...@@ -21,6 +21,7 @@ import json ...@@ -21,6 +21,7 @@ import json
import os import os
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, Tuple, Union
from . import __version__
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
from .utils import logging from .utils import logging
...@@ -234,6 +235,9 @@ class PretrainedConfig(object): ...@@ -234,6 +235,9 @@ class PretrainedConfig(object):
# Name or path to the pretrained checkpoint # Name or path to the pretrained checkpoint
self._name_or_path = str(kwargs.pop("name_or_path", "")) self._name_or_path = str(kwargs.pop("name_or_path", ""))
# Drop the transformers version info
kwargs.pop("transformers_version", None)
# Additional attributes without default values # Additional attributes without default values
for key, value in kwargs.items(): for key, value in kwargs.items():
try: try:
...@@ -520,6 +524,7 @@ class PretrainedConfig(object): ...@@ -520,6 +524,7 @@ class PretrainedConfig(object):
for key, value in config_dict.items(): for key, value in config_dict.items():
if ( if (
key not in default_config_dict key not in default_config_dict
or key == "transformers_version"
or value != default_config_dict[key] or value != default_config_dict[key]
or (key in class_config_dict and value != class_config_dict[key]) or (key in class_config_dict and value != class_config_dict[key])
): ):
...@@ -537,6 +542,10 @@ class PretrainedConfig(object): ...@@ -537,6 +542,10 @@ class PretrainedConfig(object):
output = copy.deepcopy(self.__dict__) output = copy.deepcopy(self.__dict__)
if hasattr(self.__class__, "model_type"): if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type output["model_type"] = self.__class__.model_type
# Transformers version when serializing the model
output["transformers_version"] = __version__
return output return output
def to_json_string(self, use_diff: bool = True) -> str: def to_json_string(self, use_diff: bool = True) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment