"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "99a90e43d421369357815b21771f5211c2528667"
Unverified Commit f34372a9 authored by Patrick von Platen, committed by GitHub

[PretrainedConfig] Fix save pretrained config for edge case (#7943)



* fix config save

* add test

* add config class variable and another test

* line break

* fix fsmt and typo

* god am I making many errors today :-/

* Update src/transformers/configuration_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent cc2e312c
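The edge case: `PretrainedConfig.to_diff_dict` used to compare each value only against the base `PretrainedConfig` defaults, so an attribute explicitly set back to the base default was dropped from the saved JSON even when the config class itself defines a different default, and on reload the class default silently won. A minimal sketch of the failure, mirroring the ProphetNet test added at the end of this diff (it assumes, as that test implies, that `ProphetNetConfig`'s own default for `add_cross_attention` differs from the base default of `False`):

    import tempfile

    from transformers import ProphetNetConfig

    config = ProphetNetConfig()
    config.add_cross_attention = False  # explicitly pin the *base-class* default value

    with tempfile.TemporaryDirectory() as tmp_dirname:
        config.save_pretrained(tmp_dirname)
        reloaded = ProphetNetConfig.from_pretrained(tmp_dirname)

    # Before this fix, add_cross_attention matched the base PretrainedConfig
    # default and was omitted from config.json, so the class-specific default
    # resurfaced on reload; with the fix, the explicit False round-trips.
    assert reloaded.add_cross_attention is False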
@@ -70,6 +70,7 @@ class EncoderDecoderConfig(PretrainedConfig):
         >>> model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)
     """
 
     model_type = "encoder_decoder"
+    is_composition = True
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
...
...@@ -126,9 +126,9 @@ class FSMTConfig(PretrainedConfig): ...@@ -126,9 +126,9 @@ class FSMTConfig(PretrainedConfig):
# update the defaults from config file # update the defaults from config file
def __init__( def __init__(
self, self,
langs, langs=["en", "de"],
src_vocab_size, src_vocab_size=42024,
tgt_vocab_size, tgt_vocab_size=42024,
activation_function="relu", activation_function="relu",
d_model=1024, d_model=1024,
max_length=200, max_length=200,
......
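The FSMT change exists so that `FSMTConfig()` can be constructed with no arguments, which the new `check_config_can_be_init_without_params` test below requires of every non-composition config. A quick check against the defaults added in this diff:

    from transformers import FSMTConfig

    config = FSMTConfig()  # now valid with no arguments
    assert config.langs == ["en", "de"]
    assert config.src_vocab_size == 42024
    assert config.tgt_vocab_size == 42024

Note that `langs=["en", "de"]` is a mutable default argument shared across calls; this is harmless as long as no code mutates `config.langs` in place.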
@@ -77,6 +77,7 @@ RAG_CONFIG_DOC = r"""
 @add_start_docstrings(RAG_CONFIG_DOC)
 class RagConfig(PretrainedConfig):
     model_type = "rag"
+    is_composition = True
 
     def __init__(
         self,
...
@@ -41,6 +41,10 @@ class PretrainedConfig(object):
     Class attributes (overridden by derived classes)
         - **model_type** (:obj:`str`): An identifier for the model type, serialized into the JSON file, and used to
           recreate the correct object in :class:`~transformers.AutoConfig`.
+        - **is_composition** (:obj:`bool`): Whether the config class is composed of multiple sub-configs. In this
+          case the config has to be initialized from two or more configs of type
+          :class:`~transformers.PretrainedConfig` like :class:`~transformers.EncoderDecoderConfig` or
+          :class:`~transformers.RagConfig`.
 
     Args:
         name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`):
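Because a composition config has no meaningful zero-argument constructor, `is_composition = True` is what lets `EncoderDecoderConfig` and `RagConfig` opt out of both the class-default comparison and the new zero-argument init test. A sketch of how such a config is normally built, using `from_encoder_decoder_configs` (the documented constructor for encoder-decoder configs):

    from transformers import BertConfig, EncoderDecoderConfig

    # An EncoderDecoderConfig is assembled from two PretrainedConfig instances
    # rather than from flat keyword defaults, so EncoderDecoderConfig() alone
    # cannot produce a usable config.
    config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), BertConfig())
    assert config.is_composition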
@@ -145,6 +149,7 @@ class PretrainedConfig(object):
         use BFloat16 scalars (only used by some TensorFlow models).
     """
     model_type: str = ""
+    is_composition: bool = False
 
     def __init__(self, **kwargs):
         # Attributes with defaults
@@ -476,11 +481,18 @@ class PretrainedConfig(object):
         # get the default config dict
         default_config_dict = PretrainedConfig().to_dict()
 
+        # get class specific config dict
+        class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
+
         serializable_config_dict = {}
 
         # only serialize values that differ from the default config
         for key, value in config_dict.items():
-            if key not in default_config_dict or value != default_config_dict[key]:
+            if (
+                key not in default_config_dict
+                or value != default_config_dict[key]
+                or (key in class_config_dict and value != class_config_dict[key])
+            ):
                 serializable_config_dict[key] = value
 
         return serializable_config_dict
...
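To see what the extra clause buys, here is a worked example with a hypothetical subclass (`MyConfig` is illustrative only, not part of the library) whose class default for `add_cross_attention` (`True`) differs from the base default (`False`):

    from transformers import PretrainedConfig

    class MyConfig(PretrainedConfig):
        model_type = "my-model"

        def __init__(self, add_cross_attention=True, **kwargs):
            super().__init__(add_cross_attention=add_cross_attention, **kwargs)

    config = MyConfig(add_cross_attention=False)

    # Old check: False == base default, so the key was skipped and a reload
    # fell back to the class default True. New check: False != class default,
    # so the key is kept in the serialized diff.
    assert config.to_diff_dict()["add_cross_attention"] is False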
@@ -66,9 +66,16 @@ class ConfigTester(object):
         self.parent.assertEqual(len(config.id2label), 3)
         self.parent.assertEqual(len(config.label2id), 3)
 
+    def check_config_can_be_init_without_params(self):
+        if self.config_class.is_composition:
+            return
+        config = self.config_class()
+        self.parent.assertIsNotNone(config)
+
     def run_common_tests(self):
         self.create_and_test_config_common_properties()
         self.create_and_test_config_to_json_string()
         self.create_and_test_config_to_json_file()
         self.create_and_test_config_from_and_save_pretrained()
         self.create_and_test_config_with_num_labels()
+        self.check_config_can_be_init_without_params()
@@ -901,6 +901,15 @@ class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.check_model_with_attn_mask(*config_and_inputs)
 
+    def test_config_save(self):
+        config = self.model_tester.prepare_config_and_inputs()[0]
+        config.add_cross_attention = False
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            config.save_pretrained(tmp_dirname)
+            config = ProphetNetConfig.from_pretrained(tmp_dirname)
+
+        self.assertFalse(config.add_cross_attention)
+
     @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
     def test_fp16_forward(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
...