Unverified Commit fa21ead7 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`Awq`] Enable the possibility to skip quantization for some target modules (#27950)

* v1

* add docstring

* add tests

* add awq 0.1.8

* oops

* fix test
parent 29e7a1e1
......@@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
RUN python3 -m pip install --no-cache-dir einops
# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp38-cp38-linux_x86_64.whl
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl
# For bettertransformer + gptq
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum
......
......@@ -3575,6 +3575,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if quantization_config is None:
quantization_config = AwqConfig.from_dict(config.quantization_config)
if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
model, has_been_replaced = replace_with_awq_linear(
model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert
)
......
......@@ -564,6 +564,10 @@ class AwqConfig(QuantizationConfigMixin):
The Maximum sequence length to generate when using fusing.
modules_to_fuse (`dict`, *optional*, default to `None`):
Overwrite the natively supported fusing scheme with the one specified by the users.
modules_to_not_convert (`list`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have
some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
Note you cannot quantize directly with transformers, please refer to `AutoAWQ` documentation for quantizing HF models.
"""
def __init__(
......@@ -576,6 +580,7 @@ class AwqConfig(QuantizationConfigMixin):
do_fuse: Optional[bool] = None,
fuse_max_seq_len: Optional[int] = None,
modules_to_fuse: Optional[dict] = None,
modules_to_not_convert: Optional[List] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.AWQ
......@@ -586,6 +591,7 @@ class AwqConfig(QuantizationConfigMixin):
self.version = version
self.backend = backend
self.fuse_max_seq_len = fuse_max_seq_len
self.modules_to_not_convert = modules_to_not_convert
self.modules_to_fuse = modules_to_fuse
if do_fuse is None:
......@@ -638,6 +644,19 @@ class AwqConfig(QuantizationConfigMixin):
f"You current version of `autoawq` does not support module fusing, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
)
if self.modules_to_not_convert is not None:
awq_version_supports_non_conversion = False
MIN_AWQ_VERSION = "0.1.8"
if is_auto_awq_available():
awq_version_supports_non_conversion = version.parse(
importlib.metadata.version("autoawq")
) >= version.parse(MIN_AWQ_VERSION)
if not awq_version_supports_non_conversion:
raise ValueError(
f"You current version of `autoawq` does not support module quantization skipping, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
)
if self.do_fuse and self.modules_to_fuse is not None:
required_keys = [
"hidden_size",
......
......@@ -88,6 +88,7 @@ class AwqConfigTest(unittest.TestCase):
class AwqTest(unittest.TestCase):
model_name = "TheBloke/Mistral-7B-v0.1-AWQ"
dummy_transformers_model_name = "bigscience/bloom-560m"
model_with_no_k_proj_quantized = "hf-internal-testing/opt-125m-awq-no-k-proj"
input_text = "Hello my name is"
......@@ -223,6 +224,24 @@ class AwqTest(unittest.TestCase):
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_quantized_model_no_k_proj_quantized(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
"""
dummy_input = torch.LongTensor([[0, 1, 0]]).to(torch_device)
quantized_model = AutoModelForCausalLM.from_pretrained(self.model_with_no_k_proj_quantized).to(torch_device)
self.assertTrue(isinstance(quantized_model.model.decoder.layers[0].self_attn.k_proj, torch.nn.Linear))
self.assertFalse(isinstance(quantized_model.model.decoder.layers[0].self_attn.v_proj, torch.nn.Linear))
EXPECTED_OUTPUT = torch.LongTensor([[0, 1, 0, 50118, 50118, 133, 248, 12, 134, 16, 10, 372, 2031]]).to(
torch_device
)
output = quantized_model.generate(dummy_input, max_new_tokens=10)
self.assertTrue((EXPECTED_OUTPUT == output).all())
@slow
@require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment