Unverified commit 13ef14e1 authored by Patrick von Platen, committed by GitHub

Fix config silent copy in from_pretrained (#27043)

* Fix config modeling utils

* fix more

* fix attn mask bug

* Update src/transformers/modeling_utils.py
parent 9da45171
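
For context on the fix: some model classes copy the incoming config inside `__init__`, so any attribute later written through the local `config` variable in `from_pretrained` silently misses `model.config` (or vice versa). Below is a minimal, self-contained sketch of the failure mode and the rebinding fix; the `ToyConfig`/`ToyModel` names are hypothetical stand-ins, not transformers classes.

```python
import copy


class ToyConfig:
    """Hypothetical stand-in for a PretrainedConfig."""

    def __init__(self):
        self.quantization_config = None


class ToyModel:
    """Hypothetical stand-in for a model whose __init__ copies its config."""

    def __init__(self, config):
        self.config = copy.deepcopy(config)  # the silent copy


config = ToyConfig()
model = ToyModel(config)

# Before the fix: this writes to the caller's object, not to model.config.
config.quantization_config = {"load_in_8bit": True}
assert model.config.quantization_config is None  # update silently lost

# The fix: rebind the local name to the config the model actually holds,
# so every later write lands on the object the model carries.
config = model.config
config.quantization_config = {"load_in_8bit": True}
assert model.config.quantization_config == {"load_in_8bit": True}
```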
@@ -3135,6 +3135,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         with ContextManagers(init_contexts):
             model = cls(config, *model_args, **model_kwargs)
 
+        # make sure we use the model's config since the __init__ call might have copied it
+        config = model.config
+
         # Check first if we are `from_pt`
         if use_keep_in_fp32_modules:
             if is_accelerate_available():
@@ -3193,7 +3196,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 importlib.metadata.version("bitsandbytes")
             ) >= version.parse("0.37.0")
 
-            model.config.quantization_config = quantization_config
+            config.quantization_config = quantization_config
             model.is_8bit_serializable = is_8bit_serializable
 
         if load_in_8bit and torch_dtype is None:
@@ -3423,7 +3426,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if cls.main_input_name != "input_ids":
                 raise RuntimeError("We can only quantize pure text model.")
             quantizer.quantize_model(model, quantization_config.tokenizer)
-            model.config.quantization_config = GPTQConfig.from_dict(quantizer.to_dict())
+            config.quantization_config = GPTQConfig.from_dict(quantizer.to_dict())
             model._is_quantized_training_enabled = True
         if quantization_method_from_config == QuantizationMethod.GPTQ:
             model = quantizer.post_init_model(model)
...
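
Why the second and third hunks can drop the `model.config.` prefix: after the rebinding added in the first hunk, the local `config` and `model.config` refer to the same object, so writing through either name updates both. A minimal sketch with hypothetical `SimpleNamespace` stand-ins (not transformers objects):

```python
from types import SimpleNamespace

model = SimpleNamespace(config=SimpleNamespace(quantization_config=None))

# The rebinding from the first hunk: local name and attribute now alias.
config = model.config
assert config is model.config

# Either spelling updates the single shared object, which is what the
# second and third hunks rely on.
config.quantization_config = {"bits": 4}  # hypothetical value
assert model.config.quantization_config == {"bits": 4}
```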