Unverified Commit 8dd4ce6f authored by Benjamin Ye's avatar Benjamin Ye Committed by GitHub
Browse files

[`BitsAndBytesConfig`] Warning for unused `kwargs` & safety checkers for...


[`BitsAndBytesConfig`] Warning for unused `kwargs` & safety checkers for `load_in_4bit` and `load_in_8bit` (#29761)

* added safety checkers for load_in_4bit and load_in_8bit on init, as well as their setters

* Update src/transformers/utils/quantization_config.py

typo correction for load_in_8bit setter checks
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 17e4467f
...@@ -278,6 +278,9 @@ class BitsAndBytesConfig(QuantizationConfigMixin): ...@@ -278,6 +278,9 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
else: else:
raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype") raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype")
if kwargs:
logger.warning(f"Unused kwargs: {list(kwargs.keys())}. These kwargs are not used in {self.__class__}.")
self.post_init() self.post_init()
@property @property
...@@ -286,6 +289,9 @@ class BitsAndBytesConfig(QuantizationConfigMixin): ...@@ -286,6 +289,9 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
@load_in_4bit.setter @load_in_4bit.setter
def load_in_4bit(self, value: bool): def load_in_4bit(self, value: bool):
if not isinstance(value, bool):
raise ValueError("load_in_4bit must be a boolean")
if self.load_in_8bit and value: if self.load_in_8bit and value:
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
self._load_in_4bit = value self._load_in_4bit = value
...@@ -296,6 +302,9 @@ class BitsAndBytesConfig(QuantizationConfigMixin): ...@@ -296,6 +302,9 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
@load_in_8bit.setter @load_in_8bit.setter
def load_in_8bit(self, value: bool): def load_in_8bit(self, value: bool):
if not isinstance(value, bool):
raise ValueError("load_in_8bit must be a boolean")
if self.load_in_4bit and value: if self.load_in_4bit and value:
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
self._load_in_8bit = value self._load_in_8bit = value
...@@ -304,6 +313,12 @@ class BitsAndBytesConfig(QuantizationConfigMixin): ...@@ -304,6 +313,12 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
r""" r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
""" """
if not isinstance(self.load_in_4bit, bool):
raise ValueError("load_in_4bit must be a boolean")
if not isinstance(self.load_in_8bit, bool):
raise ValueError("load_in_8bit must be a boolean")
if not isinstance(self.llm_int8_threshold, float): if not isinstance(self.llm_int8_threshold, float):
raise ValueError("llm_int8_threshold must be a float") raise ValueError("llm_int8_threshold must be a float")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment