Unverified commit 164bdef8 authored by Younes Belkada, committed by GitHub

ENH [`AutoQuantizer`]: enhance trainer + not supported quant methods (#28991)

* enhance trainer + not supported quant methods

* remove all old logic

* add version
parent 1d12b8bc
src/transformers/modeling_utils.py
@@ -4190,6 +4190,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             logger.warning_once(warn_string)
 
+    @property
+    def _is_quantized_training_enabled(self):
+        warnings.warn(
+            "`_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead",
+            FutureWarning,
+        )
+
+        if not hasattr(self, "hf_quantizer"):
+            return False
+
+        return self.hf_quantizer.is_trainable
+
 
 PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
 if PreTrainedModel.push_to_hub.__doc__ is not None:
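Reviewer note: a minimal sketch (not part of this diff) of how trainability is meant to be queried after this change. The checkpoint name and the 4-bit config below are placeholders for illustration; this assumes `bitsandbytes` is installed and a CUDA device is available.

```python
# Illustrative only; "facebook/opt-350m" is a placeholder checkpoint.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# New, recommended check: ask the attached quantizer directly.
supports_training = (
    getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
)

# The old attribute still works, but it now routes through the deprecated
# property added above and emits a FutureWarning.
legacy_flag = model._is_quantized_training_enabled
assert legacy_flag == supports_training
```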
src/transformers/quantizers/base.py
@@ -176,7 +176,6 @@ class HfQuantizer(ABC):
             kwargs (`dict`, *optional*):
                 The keyword arguments that are passed along `_process_model_after_weight_loading`.
         """
-        model._is_quantized_training_enabled = self.is_trainable
         return self._process_model_after_weight_loading(model, **kwargs)
 
     @abstractmethod
src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -289,7 +289,6 @@ class Bnb4BitHfQuantizer(HfQuantizer):
     # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        model._is_quantized_training_enabled = self.is_trainable
         model.is_loaded_in_4bit = True
         model.is_4bit_serializable = self.is_serializable
         return model
src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -205,7 +205,6 @@ class Bnb8BitHfQuantizer(HfQuantizer):
                 unexpected_keys.remove(fp16_statistics_key)
 
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        model._is_quantized_training_enabled = self.is_trainable
         model.is_loaded_in_8bit = True
         model.is_8bit_serializable = self.is_serializable
         return model
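Reviewer note: after this cleanup, the bitsandbytes quantizers still stamp the loading and serialization flags onto the model; only trainability moves to the quantizer. A hypothetical helper (the name `summarize_bnb_state` is made up for illustration, it is not a transformers API) that reads the remaining state:

```python
# Hypothetical helper, not library code: summarizes what the 8-bit quantizer
# leaves on the model after `_process_model_after_weight_loading`.
def summarize_bnb_state(model) -> dict:
    return {
        # Still set directly on the model by the quantizer above:
        "loaded_in_8bit": getattr(model, "is_loaded_in_8bit", False),
        "serializable_8bit": getattr(model, "is_8bit_serializable", False),
        # No longer stamped on the model; resolved via the quantizer:
        "trainable": getattr(model, "hf_quantizer", None) is not None
        and model.hf_quantizer.is_trainable,
    }
```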
src/transformers/trainer.py
@@ -420,6 +420,9 @@ class Trainer:
         _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
             model, "_hf_peft_config_loaded", False
         )
+        _quantization_method_supports_training = (
+            getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
+        )
 
         # At this stage the model is already loaded
         if _is_quantized_and_base_model and not _is_peft_model(model):
@@ -428,10 +431,11 @@ class Trainer:
                 " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
                 " for more details"
             )
-        elif _is_quantized_and_base_model and not getattr(model, "_is_quantized_training_enabled", False):
+        elif _is_quantized_and_base_model and not _quantization_method_supports_training:
             raise ValueError(
-                "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
-                " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
+                f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
+                " but that quantization method does not support training. Please open an issue on GitHub:"
+                f" https://github.com/huggingface/transformers to request training support for {model.hf_quantizer.quantization_config.quant_method}"
             )
 
         self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
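Reviewer note: the new guard restated as a standalone sketch (the function name and signature are made up for illustration), showing both failure modes: a quantized base model without a PEFT adapter, and a quantization method whose quantizer reports `is_trainable == False`.

```python
# Illustrative re-statement of the Trainer guard above; not library code.
def check_quantized_model_is_trainable(model, is_peft_model: bool) -> None:
    is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
        model, "_hf_peft_config_loaded", False
    )
    quantization_method_supports_training = (
        getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
    )

    if is_quantized_and_base_model and not is_peft_model:
        # Quantized weights are frozen; training requires adapters on top.
        raise ValueError(
            "You cannot fine-tune a quantized base model directly; attach a PEFT"
            " adapter (e.g. LoRA) on top of it first."
        )
    if is_quantized_and_base_model and not quantization_method_supports_training:
        quant_method = model.hf_quantizer.quantization_config.quant_method
        raise ValueError(f"Quantization method {quant_method} does not support training.")
```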