Unverified Commit a989c6c6 authored by Omar Sanseviero's avatar Omar Sanseviero Committed by GitHub
Browse files

Don't allow passing `load_in_8bit` and `load_in_4bit` at the same time (#28266)



* Update quantization_config.py

* Style

* Protect from setting directly

* add tests

* Update tests/quantization/bnb/test_4bit.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent cd2eb8cb
...@@ -212,8 +212,12 @@ class BitsAndBytesConfig(QuantizationConfigMixin): ...@@ -212,8 +212,12 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
**kwargs, **kwargs,
): ):
self.quant_method = QuantizationMethod.BITS_AND_BYTES self.quant_method = QuantizationMethod.BITS_AND_BYTES
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit if load_in_4bit and load_in_8bit:
raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
self._load_in_8bit = load_in_8bit
self._load_in_4bit = load_in_4bit
self.llm_int8_threshold = llm_int8_threshold self.llm_int8_threshold = llm_int8_threshold
self.llm_int8_skip_modules = llm_int8_skip_modules self.llm_int8_skip_modules = llm_int8_skip_modules
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
...@@ -232,6 +236,26 @@ class BitsAndBytesConfig(QuantizationConfigMixin): ...@@ -232,6 +236,26 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
self.post_init() self.post_init()
@property
def load_in_4bit(self):
    """Whether this config enables 4-bit quantized loading."""
    return self._load_in_4bit

@load_in_4bit.setter
def load_in_4bit(self, value: bool):
    """Set the 4-bit flag, refusing to enable it while 8-bit loading is active."""
    # Guard: 4-bit and 8-bit quantization are mutually exclusive.
    if value and self.load_in_8bit:
        raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
    self._load_in_4bit = value
@property
def load_in_8bit(self):
    """Whether this config enables 8-bit quantized loading."""
    return self._load_in_8bit

@load_in_8bit.setter
def load_in_8bit(self, value: bool):
    """Set the 8-bit flag, refusing to enable it while 4-bit loading is active."""
    # Guard: 8-bit and 4-bit quantization are mutually exclusive.
    if value and self.load_in_4bit:
        raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
    self._load_in_8bit = value
def post_init(self): def post_init(self):
r""" r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
......
...@@ -648,3 +648,18 @@ class GPTSerializationTest(BaseSerializationTest): ...@@ -648,3 +648,18 @@ class GPTSerializationTest(BaseSerializationTest):
""" """
model_name = "gpt2-xl" model_name = "gpt2-xl"
@require_bitsandbytes
@require_accelerate
@require_torch_gpu
@slow
class Bnb4BitTestBasicConfigTest(unittest.TestCase):
    """Checks that 4-bit and 8-bit quantized loading are mutually exclusive."""

    def test_load_in_4_and_8_bit_fails(self):
        """Passing both flags to `from_pretrained` must raise a ValueError."""
        with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
            AutoModelForCausalLM.from_pretrained("facebook/opt-125m", load_in_4bit=True, load_in_8bit=True)

    def test_set_load_in_8_bit(self):
        """Flipping `load_in_8bit` on an existing 4-bit config must raise a ValueError."""
        config = BitsAndBytesConfig(load_in_4bit=True)
        with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
            config.load_in_8bit = True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment