Unverified Commit 3319eb54 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: legacy mode is only triggered when `generation_config` is untouched (#25962)

parent 18abc756
......@@ -55,12 +55,10 @@ When you load a model explicitly, you can inspect the generation configuration t
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
>>> model.generation_config # doctest: +IGNORE_RESULT
>>> model.generation_config
GenerationConfig {
"_from_model_config": true,
"bos_token_id": 50256,
"eos_token_id": 50256,
"transformers_version": "4.26.0.dev0"
}
```
......
......@@ -34,6 +34,7 @@ from ..utils import (
logger = logging.get_logger(__name__)
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
class GenerationConfig(PushToHubMixin):
......@@ -315,20 +316,19 @@ class GenerationConfig(PushToHubMixin):
# Validate the values of the attributes
self.validate(is_init=True)
def __hash__(self):
return hash(self.to_json_string(ignore_metadata=True))
def __eq__(self, other):
if not isinstance(other, GenerationConfig):
return False
self_dict = self.__dict__.copy()
other_dict = other.__dict__.copy()
# ignore metadata
for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"):
self_dict.pop(metadata_field, None)
other_dict.pop(metadata_field, None)
return self_dict == other_dict
self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True)
other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True)
return self_without_metadata == other_without_metadata
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
def validate(self, is_init=False):
"""
......@@ -729,7 +729,9 @@ class GenerationConfig(PushToHubMixin):
else:
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
return cls.from_dict(config_dict, **kwargs)
config = cls.from_dict(config_dict, **kwargs)
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
return config
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
......@@ -814,8 +816,12 @@ class GenerationConfig(PushToHubMixin):
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
# Fields to ignore at serialization time
if "_commit_hash" in output:
del output["_commit_hash"]
if "_original_object_hash" in output:
del output["_original_object_hash"]
# Transformers version when serializing this file
output["transformers_version"] = __version__
......@@ -823,7 +829,7 @@ class GenerationConfig(PushToHubMixin):
self.dict_torch_dtype_to_str(output)
return output
def to_json_string(self, use_diff: bool = True) -> str:
def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:
"""
Serializes this instance to a JSON string.
......@@ -831,6 +837,8 @@ class GenerationConfig(PushToHubMixin):
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
is serialized to JSON string.
ignore_metadata (`bool`, *optional*, defaults to `False`):
Whether to ignore the metadata fields present in the instance
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
......@@ -839,6 +847,11 @@ class GenerationConfig(PushToHubMixin):
config_dict = self.to_diff_dict()
else:
config_dict = self.to_dict()
if ignore_metadata:
for metadata_field in METADATA_FIELDS:
config_dict.pop(metadata_field, None)
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
......@@ -882,6 +895,7 @@ class GenerationConfig(PushToHubMixin):
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
setattr(config, attr, decoder_config[attr])
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
return config
def update(self, **kwargs):
......
......@@ -310,16 +310,20 @@ class FlaxGenerationMixin:
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if self.generation_config._from_model_config:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# two conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
self.generation_config
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self.generation_config = new_generation_config
generation_config = self.generation_config
......
......@@ -716,16 +716,20 @@ class TFGenerationMixin:
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if self.generation_config._from_model_config:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# two conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
self.generation_config
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self.generation_config = new_generation_config
generation_config = self.generation_config
......
......@@ -1409,16 +1409,20 @@ class GenerationMixin:
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if self.generation_config._from_model_config:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# two conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
self.generation_config
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self.generation_config = new_generation_config
generation_config = self.generation_config
......
......@@ -2880,7 +2880,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
# Generation config max_length != 20 -> no warning
with warnings.catch_warnings(record=True) as warning_list:
# generation_config is modified -> legacy mode is disabled = generation_config takes precedence
model.generation_config.max_length = 10
model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence
model.generate(input_ids)
self.assertEqual(len(warning_list), 0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment