"Resource/vscode:/vscode.git/clone" did not exist on "cdb77bf7670bd8db9c38cb1b445e9bb49e88559d"
Unverified Commit 4fc708f9 authored by Ilyas Moutawwakil, committed by GitHub

Exllama kernels support for AWQ models (#28634)



* added exllama kernels support for awq models

* doc

* style

* Update src/transformers/modeling_utils.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* refactor

* moved exllama post init to after device dispatching

* bump autoawq version

* added exllama test

* style

* configurable exllama kernels

* copy exllama_config from gptq

* moved exllama version check to post init

* moved to quantization dockerfile

---------
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 81c8191b
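
Based on the new test and configuration options in this diff, selecting the Exllama kernels when loading a GEMM-packed AWQ checkpoint looks roughly like the sketch below; the checkpoint name and the `device_map` value are illustrative placeholders, not part of this commit.

from transformers import AutoModelForCausalLM, AwqConfig

# Request the Exllama kernels for a GEMM-packed AWQ model (requires autoawq >= 0.2.0).
quantization_config = AwqConfig(version="exllama")

model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Mistral-7B-Instruct-v0.1-AWQ",  # placeholder checkpoint, any GEMM-packed AWQ model should work
    quantization_config=quantization_config,
    device_map="auto",  # the Exllama post-init runs after device dispatching (see the quantizer change below)
)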
@@ -43,7 +43,7 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt...
 RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
 # Add autoawq for quantization testing
-RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl
+RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl
 # When installing in editable mode, `transformers` is not recognized as a package.
 # this line must be added in order for python to be aware of transformers.
......
@@ -18,7 +18,11 @@ from ..utils import _LazyModule
 _import_structure = {
     "aqlm": ["replace_with_aqlm_linear"],
-    "awq": ["fuse_awq_modules", "replace_with_awq_linear"],
+    "awq": [
+        "fuse_awq_modules",
+        "post_init_awq_exllama_modules",
+        "replace_with_awq_linear",
+    ],
     "bitsandbytes": [
         "get_keys_to_not_convert",
         "replace_8bit_linear",
@@ -82,7 +86,11 @@ _import_structure = {
 if TYPE_CHECKING:
     from .aqlm import replace_with_aqlm_linear
-    from .awq import fuse_awq_modules, replace_with_awq_linear
+    from .awq import (
+        fuse_awq_modules,
+        post_init_awq_exllama_modules,
+        replace_with_awq_linear,
+    )
     from .bitsandbytes import (
         get_keys_to_not_convert,
         replace_8bit_linear,
......
@@ -15,7 +15,12 @@
 from ..activations import ACT2FN
 from ..modeling_utils import PreTrainedModel
 from ..utils import is_auto_awq_available, is_torch_available
-from ..utils.quantization_config import AwqBackendPackingMethod, AwqConfig, AWQLinearVersion
+from ..utils.quantization_config import (
+    AwqBackendPackingMethod,
+    AwqConfig,
+    AWQLinearVersion,
+    ExllamaVersion,
+)


 if is_torch_available():
@@ -91,13 +96,30 @@ def replace_with_awq_linear(
         )

     if backend == AwqBackendPackingMethod.AUTOAWQ:
-        from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-    elif backend == AwqBackendPackingMethod.LLMAWQ:
-        from awq.quantize.qmodule import WQLinear
-
-    if backend == AwqBackendPackingMethod.AUTOAWQ:
-        target_cls = WQLinear_GEMM if quantization_config.version == AWQLinearVersion.GEMM else WQLinear_GEMV
+        if quantization_config.version == AWQLinearVersion.GEMM:
+            from awq.modules.linear.gemm import WQLinear_GEMM
+
+            target_cls = WQLinear_GEMM
+        elif quantization_config.version == AWQLinearVersion.GEMV:
+            from awq.modules.linear.gemv import WQLinear_GEMV
+
+            target_cls = WQLinear_GEMV
+        elif quantization_config.version == AWQLinearVersion.EXLLAMA:
+            if quantization_config.exllama_config["version"] == ExllamaVersion.ONE:
+                from awq.modules.linear.exllama import WQLinear_Exllama
+
+                target_cls = WQLinear_Exllama
+            elif quantization_config.exllama_config["version"] == ExllamaVersion.TWO:
+                from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
+
+                target_cls = WQLinear_ExllamaV2
+            else:
+                raise ValueError(f"Unrecognized Exllama version: {quantization_config.exllama_config['version']}")
+        else:
+            raise ValueError(f"Unrecognized AWQ version: {quantization_config.version}")
     else:
+        from awq.quantize.qmodule import WQLinear
+
         target_cls = WQLinear

     for name, module in model.named_children():
@@ -372,3 +394,28 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na...
         setattr(parent, child_name, fused_attention_layer.to(previous_device))

         del q_proj, k_proj, v_proj, o_proj
+
+
+def post_init_awq_exllama_modules(model, exllama_config):
+    """
+    Runs post init for Exllama layers which performs:
+        - Weights unpacking, reordering and repacking
+        - Devices scratch space allocation
+    """
+
+    if exllama_config["version"] == ExllamaVersion.ONE:
+        from awq.modules.linear.exllama import exllama_post_init
+
+        model = exllama_post_init(model)
+    elif exllama_config["version"] == ExllamaVersion.TWO:
+        from awq.modules.linear.exllamav2 import exllamav2_post_init
+
+        model = exllamav2_post_init(
+            model,
+            max_input_len=exllama_config["max_input_len"],
+            max_batch_size=exllama_config["max_batch_size"],
+        )
+    else:
+        raise ValueError(f"Unrecognized Exllama version: {exllama_config['version']}")
+
+    return model
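
For models assembled outside `from_pretrained`, the same post-init can in principle be applied by hand once the weights sit on their devices; a minimal sketch, assuming `model` is an Exllama-flavoured AWQ model that has already been dispatched to GPU (the config values below are simply the documented defaults):

from transformers.integrations import post_init_awq_exllama_modules
from transformers.utils.quantization_config import ExllamaVersion

# Unpacks, reorders and repacks the quantized weights and allocates the device scratch space.
exllama_config = {"version": ExllamaVersion.TWO, "max_input_len": 2048, "max_batch_size": 8}
model = post_init_awq_exllama_modules(model, exllama_config)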
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel

 from ..utils import is_accelerate_available, is_auto_awq_available, is_torch_available, logging
+from ..utils.quantization_config import AWQLinearVersion


 if is_torch_available():
@@ -98,12 +99,22 @@ class AwqQuantizer(HfQuantizer):
             model = fuse_awq_modules(model, self.quantization_config)
             model._awq_is_fused = True  # TODO: consider storing this flag in model.config instead

+        if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
+            from ..integrations import post_init_awq_exllama_modules
+
+            model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config)
+
     @property
     def is_serializable(self):
         # AWQ through auto-awq has been always serializable, except if the model is fused.
         if self.quantization_config.do_fuse:
             logger.warning("You cannot save an AWQ model that uses fused modules!")
             return False

+        if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
+            logger.warning("You cannot save an AWQ model that uses Exllama backend!")
+            return False
+
         return True

     @property
......
@@ -44,6 +44,7 @@ class QuantizationMethod(str, Enum):
 class AWQLinearVersion(str, Enum):
     GEMM = "gemm"
     GEMV = "gemv"
+    EXLLAMA = "exllama"

     @staticmethod
     def from_str(version: str):
@@ -52,6 +53,8 @@ class AWQLinearVersion(str, Enum):
             return AWQLinearVersion.GEMM
         elif version == "gemv":
             return AWQLinearVersion.GEMV
+        elif version == "exllama":
+            return AWQLinearVersion.EXLLAMA
         else:
             raise ValueError(f"Unknown AWQLinearVersion {version}")
@@ -606,7 +609,7 @@ class AwqConfig(QuantizationConfigMixin):
             Whether to use zero point quantization.
         version (`AWQLinearVersion`, *optional*, defaults to `AWQLinearVersion.GEMM`):
             The version of the quantization algorithm to use. GEMM is better for big batch_size (e.g. >= 8) otherwise,
-            GEMV is better (e.g. < 8 )
+            GEMV is better (e.g. < 8 ). GEMM models are compatible with Exllama kernels.
         backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`):
             The quantization backend. Some models might be quantized using `llm-awq` backend. This is useful for users
             that quantize their own models using `llm-awq` library.
@@ -620,6 +623,10 @@ class AwqConfig(QuantizationConfigMixin):
             The list of modules to not quantize, useful for quantizing models that explicitly require to have
             some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
             Note you cannot quantize directly with transformers, please refer to `AutoAWQ` documentation for quantizing HF models.
+        exllama_config (`Dict[str, Any]`, *optional*):
+            You can specify the version of the exllama kernel through the `version` key, the maximum sequence
+            length through the `max_input_len` key, and the maximum batch size through the `max_batch_size` key.
+            Defaults to `{"version": 2, "max_input_len": 2048, "max_batch_size": 8}` if unset.
     """

     def __init__(
@@ -633,6 +640,7 @@ class AwqConfig(QuantizationConfigMixin):
         fuse_max_seq_len: Optional[int] = None,
         modules_to_fuse: Optional[dict] = None,
         modules_to_not_convert: Optional[List] = None,
+        exllama_config: Optional[Dict[str, int]] = None,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.AWQ
@@ -644,6 +652,7 @@ class AwqConfig(QuantizationConfigMixin):
         self.backend = backend
         self.fuse_max_seq_len = fuse_max_seq_len
         self.modules_to_not_convert = modules_to_not_convert
+        self.exllama_config = exllama_config

         self.modules_to_fuse = modules_to_fuse
         if do_fuse is None:
@@ -667,9 +676,9 @@ class AwqConfig(QuantizationConfigMixin):
             )

         self.version = AWQLinearVersion.from_str(self.version)
-        if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
+        if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA]:
             raise ValueError(
-                f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
+                f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA] - not recognized version {self.version}"
             )

         if self.backend == AwqBackendPackingMethod.LLMAWQ:
@@ -724,9 +733,34 @@ class AwqConfig(QuantizationConfigMixin):
                     f"Required fields are missing in the fusing mapping, required fields are {required_keys}"
                 )

+        if self.version == AWQLinearVersion.EXLLAMA:
+            awq_version_supports_exllama = False
+            MIN_AWQ_VERSION = "0.2.0"
+            if is_auto_awq_available():
+                awq_version_supports_exllama = version.parse(importlib.metadata.version("autoawq")) >= version.parse(
+                    MIN_AWQ_VERSION
+                )
+
+            if not awq_version_supports_exllama:
+                raise ValueError(
+                    f"You current version of `autoawq` does not support exllama backend, "
+                    f"please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
+                )
+
+            if self.exllama_config is None:
+                self.exllama_config = {"version": ExllamaVersion.TWO, "max_input_len": 2048, "max_batch_size": 8}
+            else:
+                if "version" not in self.exllama_config:
+                    raise ValueError("`exllama_config` needs to have a `version` key.")
+                elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
+                    exllama_version = self.exllama_config["version"]
+                    raise ValueError(
+                        f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
+                    )
+
     def get_loading_attributes(self):
         attibutes_dict = copy.deepcopy(self.__dict__)
-        loading_attibutes = ["do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
+        loading_attibutes = ["version", "do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
         loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
         return loading_attibutes_dict
......
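
The `exllama_config` dictionary documented above can also be set explicitly instead of relying on the defaults; a small sketch, where the keys mirror the docstring (`version`, `max_input_len`, `max_batch_size`) and the concrete numbers are purely illustrative:

from transformers import AwqConfig

# Pick the Exllama v2 kernels and size their scratch buffers for longer prompts and smaller batches.
quantization_config = AwqConfig(
    version="exllama",
    exllama_config={"version": 2, "max_input_len": 4096, "max_batch_size": 4},
)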
@@ -192,6 +192,20 @@ class AwqTest(unittest.TestCase):
         output = quantized_model.generate(**input_ids, max_new_tokens=40)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_BF16)

+    def test_quantized_model_exllama(self):
+        """
+        Simple test that checks if the quantized model is working properly with exllama backend
+        """
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+        quantization_config = AwqConfig(version="exllama")
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name, quantization_config=quantization_config
+        ).to(torch_device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=40)
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
     def test_quantized_model_no_device_map(self):
         """
         Simple test that checks if the quantized model is working properly
......