"vscode:/vscode.git/clone" did not exist on "7bdfd29a358ddbf9dc2544968f5bf9b700773be9"
Unverified Commit 7576cd38 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Bugfix] Check bnb_4bit_quant_storage for bitsandbytes (#10642)

parent 9a99273b
......@@ -20,6 +20,7 @@ class BitsAndBytesConfig(QuantizationConfig):
load_in_8bit: bool = False,
load_in_4bit: bool = True,
bnb_4bit_compute_dtype: str = "float32",
bnb_4bit_quant_storage: str = "uint8",
bnb_4bit_quant_type: str = "fp4",
bnb_4bit_use_double_quant: bool = False,
llm_int8_enable_fp32_cpu_offload: bool = False,
......@@ -31,6 +32,7 @@ class BitsAndBytesConfig(QuantizationConfig):
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
self.bnb_4bit_quant_type = bnb_4bit_quant_type
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
......@@ -38,10 +40,15 @@ class BitsAndBytesConfig(QuantizationConfig):
self.llm_int8_skip_modules = llm_int8_skip_modules or []
self.llm_int8_threshold = llm_int8_threshold
if self.bnb_4bit_quant_storage not in ["uint8"]:
raise ValueError("Unsupported bnb_4bit_quant_storage: "
f"{self.bnb_4bit_quant_storage}")
def __repr__(self) -> str:
return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
f"load_in_4bit={self.load_in_4bit}, "
f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
......@@ -80,6 +87,9 @@ class BitsAndBytesConfig(QuantizationConfig):
bnb_4bit_compute_dtype = get_safe_value(config,
["bnb_4bit_compute_dtype"],
default_value="float32")
bnb_4bit_quant_storage = get_safe_value(config,
["bnb_4bit_quant_storage"],
default_value="uint8")
bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"],
default_value="fp4")
bnb_4bit_use_double_quant = get_safe_value(
......@@ -99,6 +109,7 @@ class BitsAndBytesConfig(QuantizationConfig):
load_in_8bit=load_in_8bit,
load_in_4bit=load_in_4bit,
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
bnb_4bit_quant_storage=bnb_4bit_quant_storage,
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment