Unverified Commit 164bdef8 authored by Younes Belkada, committed by GitHub

ENH [`AutoQuantizer`]: enhance trainer + not supported quant methods (#28991)

* enhance trainer + not supported quant methods

* remove all old logic

* add version
parent 1d12b8bc
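
In practice, this change replaces the per-model `_is_quantized_training_enabled` flag with a lookup on the attached quantizer. A minimal sketch of the check from user code, assuming a bitsandbytes 4-bit load (the checkpoint name is illustrative, not part of this commit):

# Sketch only: illustrates the attribute this PR makes authoritative.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# The Trainer now derives trainability from the quantizer itself rather than
# from the deprecated `_is_quantized_training_enabled` model attribute.
if model.hf_quantizer.is_trainable:
    print("This quantization method supports (adapter-based) fine-tuning.")
else:
    print(f"{model.hf_quantizer.quantization_config.quant_method} does not support training.")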
@@ -4190,6 +4190,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         logger.warning_once(warn_string)
 
+    @property
+    def _is_quantized_training_enabled(self):
+        logger.warning(
+            "`_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead",
+            FutureWarning,
+        )
+
+        if not hasattr(self, "hf_quantizer"):
+            return False
+
+        return self.hf_quantizer.is_trainable
+
 
 PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
 if PreTrainedModel.push_to_hub.__doc__ is not None:
......
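Downstream code that still reads the old attribute keeps working through the deprecation property above, but the warning points at the replacement. A hedged migration sketch (`model` is assumed to be any quantized PreTrainedModel):

# Old pattern: still works, but now routes through the deprecated property above.
# trainable = model._is_quantized_training_enabled

# Replacement suggested by the deprecation message:
trainable = model.hf_quantizer.is_trainable if hasattr(model, "hf_quantizer") else False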
@@ -176,7 +176,6 @@ class HfQuantizer(ABC):
             kwargs (`dict`, *optional*):
                 The keyword arguments that are passed along `_process_model_after_weight_loading`.
         """
-        model._is_quantized_training_enabled = self.is_trainable
         return self._process_model_after_weight_loading(model, **kwargs)
 
     @abstractmethod
......
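With the assignment above removed, the post-loading hook no longer stamps a trainability flag onto the model; callers ask the quantizer directly. A minimal sketch, assuming `quantizer` is a concrete HfQuantizer and `model` a freshly loaded PreTrainedModel (the enclosing public method in the base class is `postprocess_model`):

# Before this commit, post-processing also set model._is_quantized_training_enabled.
# After it, only the quantizer knows whether training is supported.
model = quantizer.postprocess_model(model)
can_finetune = quantizer.is_trainable  # same value the Trainer now reads via model.hf_quantizer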
@@ -289,7 +289,6 @@ class Bnb4BitHfQuantizer(HfQuantizer):
     # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        model._is_quantized_training_enabled = self.is_trainable
         model.is_loaded_in_4bit = True
         model.is_4bit_serializable = self.is_serializable
         return model
......
@@ -205,7 +205,6 @@ class Bnb8BitHfQuantizer(HfQuantizer):
             unexpected_keys.remove(fp16_statistics_key)
 
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        model._is_quantized_training_enabled = self.is_trainable
         model.is_loaded_in_8bit = True
         model.is_8bit_serializable = self.is_serializable
         return model
......
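Both bitsandbytes quantizers keep setting their load and serialization flags; only the trainability flag disappears from the model. A short sketch of what remains observable after an 8-bit load (checkpoint name illustrative):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_8bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

print(model_8bit.is_loaded_in_8bit)      # still set in _process_model_after_weight_loading
print(model_8bit.is_8bit_serializable)   # still set in _process_model_after_weight_loading
# The removed line is the only change: trainability is no longer mirrored onto the model.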
@@ -420,6 +420,9 @@ class Trainer:
         _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
             model, "_hf_peft_config_loaded", False
         )
+        _quantization_method_supports_training = (
+            getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
+        )
 
         # At this stage the model is already loaded
         if _is_quantized_and_base_model and not _is_peft_model(model):
@@ -428,10 +431,11 @@ class Trainer:
                 " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
                 " for more details"
             )
-        elif _is_quantized_and_base_model and not getattr(model, "_is_quantized_training_enabled", False):
+        elif _is_quantized_and_base_model and not _quantization_method_supports_training:
             raise ValueError(
-                "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
-                " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
+                f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
+                " but that quantization method does not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers"
+                f" to request support for training with {model.hf_quantizer.quantization_config.quant_method}"
             )
 
         self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
......
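The first branch of the check still points users to PEFT; a hedged sketch of the supported fine-tuning path that satisfies both branches (checkpoint, LoRA settings, and output directory are placeholders, not part of this commit):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
# Wrapping with a PEFT adapter makes _is_peft_model(model) True, so the first
# ValueError branch is skipped; bitsandbytes 4-bit reports is_trainable = True,
# so the new _quantization_method_supports_training branch is skipped as well.
model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))

trainer = Trainer(model=model, args=TrainingArguments(output_dir="out"))
# trainer.train() would proceed once a train_dataset is supplied.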