Unverified Commit 3e41cf13 authored by Joao Gante, committed by GitHub

Generate: Load generation config when `device_map` is passed (#25413)

parent d0839f1a
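
The diff below renames the locals used for `accelerate`'s device-map handling from `kwargs` to `device_map_kwargs`, so the user-supplied `kwargs` are no longer clobbered before they reach `GenerationConfig.from_pretrained`, narrows the surrounding exception handler to `OSError`, and guards generation-config loading on `pretrained_model_name_or_path is not None`. A minimal sketch of the behavior being fixed, using the tiny checkpoint referenced by the new test (`device_map="auto"` requires the `accelerate` package; the printed value depends on that repo's `generation_config.json`):

```python
# Sketch of the behavior covered by this commit: with `device_map` passed,
# `from_pretrained` should still pick up the repo's generation_config.json.
# The checkpoint below is the tiny test model referenced in the new test.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "joaogante/tiny-random-gpt2-with-generation-config", device_map="auto"
)

# After this fix, the loaded generation config reflects the file on the Hub
# (its dummy `transformers_version` field is set to "foo").
print(model.generation_config.transformers_version)
```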
@@ -2849,9 +2849,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     "'sequential'."
                 )
 
-            kwargs = {"no_split_module_classes": no_split_modules}
+            device_map_kwargs = {"no_split_module_classes": no_split_modules}
             if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
-                kwargs["special_dtypes"] = special_dtypes
+                device_map_kwargs["special_dtypes"] = special_dtypes
             elif len(special_dtypes) > 0:
                 logger.warning(
                     "This model has some weights that should be kept in higher precision, you need to upgrade "
@@ -2863,12 +2863,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     dtype=target_dtype,
                     low_zero=(device_map == "balanced_low_0"),
                     max_memory=max_memory,
-                    **kwargs,
+                    **device_map_kwargs,
                 )
-            kwargs["max_memory"] = max_memory
+            device_map_kwargs["max_memory"] = max_memory
 
             # Make sure tied weights are tied before creating the device map.
             model.tie_weights()
-            device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
+            device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
             if load_in_8bit or load_in_4bit:
                 # The LM head / tied weights or any last module can stay on disk / CPU
@@ -2966,7 +2966,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         model.eval()
 
         # If it is a model with generation capabilities, attempt to load the generation config
-        if model.can_generate():
+        if model.can_generate() and pretrained_model_name_or_path is not None:
             try:
                 model.generation_config = GenerationConfig.from_pretrained(
                     pretrained_model_name_or_path,
@@ -2982,7 +2982,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     _from_pipeline=from_pipeline,
                     **kwargs,
                 )
-            except (OSError, TypeError):
+            except OSError:
                 logger.info(
                     "Generation config file not found, using a generation config created from the model config."
                 )
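
Narrowing `except (OSError, TypeError)` to `except OSError` matters because, with the old reuse of the `kwargs` name shown in the earlier hunks, `GenerationConfig.from_pretrained` received device-map options, raised `TypeError`, and the handler silently fell back to a default generation config. A hedged, self-contained sketch of that masking pattern; `load_config` is a hypothetical stand-in, not a transformers API:

```python
# Hypothetical stand-in for GenerationConfig.from_pretrained: it rejects
# unexpected keyword arguments the same way a real signature mismatch would.
def load_config(name_or_path, **kwargs):
    if "no_split_module_classes" in kwargs:
        raise TypeError("load_config() got an unexpected keyword argument 'no_split_module_classes'")
    return {"source": name_or_path, **kwargs}


# The device-map path used to repurpose the same local name, clobbering the
# user kwargs that should have been forwarded to the config loader.
kwargs = {"no_split_module_classes": ["GPT2Block"]}

try:
    config = load_config("some/model", **kwargs)
except (OSError, TypeError):
    # Old behavior: the TypeError is indistinguishable from "no config file",
    # so the generation config silently falls back to defaults.
    config = {"source": "defaults"}

print(config)  # -> {'source': 'defaults'}
```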
@@ -2990,10 +2990,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
 
         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
-            kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
+            device_map_kwargs = {
+                "device_map": device_map,
+                "offload_dir": offload_folder,
+                "offload_index": offload_index,
+            }
             if "skip_keys" in inspect.signature(dispatch_model).parameters:
-                kwargs["skip_keys"] = model._skip_keys_device_placement
-            dispatch_model(model, **kwargs)
+                device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
+            dispatch_model(model, **device_map_kwargs)
 
         if output_loading_info:
             if loading_info is None:
...
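
The hunks above keep compatibility with older `accelerate` releases by checking `inspect.signature(...)` before forwarding optional keyword arguments (`special_dtypes`, `skip_keys`). A minimal sketch of that feature-detection pattern, with hypothetical names (`call_backend`, `newer_only_flag`) standing in for accelerate's functions:

```python
import inspect


def call_backend(data, *, newer_only_flag=False):
    """Stand-in for a third-party function whose signature may vary by version."""
    return (data, newer_only_flag)


def safe_call(data, flag_value):
    kwargs = {}
    if "newer_only_flag" in inspect.signature(call_backend).parameters:
        # The installed version supports the keyword, so it is safe to pass it;
        # on older versions the keyword is simply omitted instead of erroring.
        kwargs["newer_only_flag"] = flag_value
    return call_backend(data, **kwargs)


print(safe_call([1, 2, 3], True))
```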
@@ -1036,6 +1036,20 @@ class ModelUtilsTest(TestCasePlus):
 
         self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
 
+    def test_generation_config_is_loaded_with_model(self):
+        # Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a dummy
+        # `transformers_version` field set to `foo`. If loading the file fails, this test also fails.
+
+        # 1. Load without further parameters
+        model = AutoModelForCausalLM.from_pretrained("joaogante/tiny-random-gpt2-with-generation-config")
+        self.assertEqual(model.generation_config.transformers_version, "foo")
+
+        # 2. Load with `device_map`
+        model = AutoModelForCausalLM.from_pretrained(
+            "joaogante/tiny-random-gpt2-with-generation-config", device_map="auto"
+        )
+        self.assertEqual(model.generation_config.transformers_version, "foo")
+
 
 @require_torch
 @is_staging_test
...