Unverified commit de6e0db1, authored by Marc Sun and committed by GitHub

[awq] replace scale when we have GELU (#30074)

* fix awq test

* style

* add log

* new fix

* style

* only modifying impacted model in the end

* rename function
parent e0c3cee1
@@ -21,6 +21,7 @@ _import_structure = {
     "awq": [
         "fuse_awq_modules",
         "post_init_awq_exllama_modules",
+        "replace_quantization_scales",
         "replace_with_awq_linear",
     ],
     "bitsandbytes": [
@@ -92,6 +93,7 @@ if TYPE_CHECKING:
     from .awq import (
         fuse_awq_modules,
         post_init_awq_exllama_modules,
+        replace_quantization_scales,
         replace_with_awq_linear,
     )
     from .bitsandbytes import (
...
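With the export above in place, the new helper becomes importable alongside the existing public AWQ utilities. A minimal illustration (not part of the diff):

from transformers.integrations import (
    fuse_awq_modules,
    post_init_awq_exllama_modules,
    replace_quantization_scales,
    replace_with_awq_linear,
)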
@@ -14,7 +14,7 @@
 "AWQ (Activation aware Weight Quantization) integration file"
 from ..activations import ACT2FN
 from ..modeling_utils import PreTrainedModel
-from ..utils import is_auto_awq_available, is_torch_available
+from ..utils import is_auto_awq_available, is_torch_available, logging
 from ..utils.quantization_config import (
     AwqBackendPackingMethod,
     AwqConfig,
@@ -27,6 +27,7 @@ if is_torch_available():
     import torch
     import torch.nn as nn
 
+logger = logging.get_logger(__name__)
 
 AWQ_FUSED_MAPPINGS = {
     "mistral": {
@@ -56,6 +57,34 @@ AWQ_FUSED_MAPPINGS = {
     },
 }
 
+AWQ_SCALES_MAPPINGS = {
+    "starcoder2": {"act": "act", "layer_before_act": "c_fc"},
+    "RefinedWebModel": {"act": "act", "layer_before_act": "dense_h_to_4h"},
+    "falcon": {"act": "act", "layer_before_act": "dense_h_to_4h"},
+    "mpt": {"act": "act", "layer_before_act": "up_proj"},
+    "gptj": {"act": "act", "layer_before_act": "fc_in"},
+    "gpt_neox": {"act": "act", "layer_before_act": "dense_h_to_4h"},
+    "gpt_bigcode": {"act": "act", "layer_before_act": "c_fc"},
+    "bloom": {"act": "gelu_impl", "layer_before_act": "dense_h_to_4h"},
+}
+
+
+def replace_quantization_scales(model, model_type):
+    from awq.modules.act import ScaledActivation
+
+    if model_type not in AWQ_SCALES_MAPPINGS:
+        return model
+    for name, module in model.named_children():
+        act_name = AWQ_SCALES_MAPPINGS[model_type]["act"]
+        layer_before_act_name = AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"]
+        if name == act_name and hasattr(model, layer_before_act_name):
+            layer_before_act = getattr(model, AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"])
+            size = layer_before_act.out_features
+            scale_like = torch.ones(size)
+            model._modules[name] = ScaledActivation(module, scale_like)
+        _ = replace_quantization_scales(module, model_type)
+    return model
+
+
 def replace_with_awq_linear(
     model,
...
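To make the intent of `replace_quantization_scales` concrete: for architectures whose MLP uses a plain non-linear activation such as GELU, autoawq stores per-channel scales for that activation inside a `ScaledActivation` wrapper. The helper therefore walks the module tree recursively and, for the model types listed in `AWQ_SCALES_MAPPINGS`, replaces the activation named `act` with a `ScaledActivation` whenever its parent also holds the projection named `layer_before_act`, using unit scales sized to that projection's `out_features`. The following is a minimal sketch, not part of the commit, that mirrors this on a made-up StarCoder2-style MLP; it assumes `autoawq` is installed so that `ScaledActivation` is importable.

import torch
import torch.nn as nn
from awq.modules.act import ScaledActivation  # requires the autoawq package


class ToyStarcoder2MLP(nn.Module):
    # Hypothetical stand-in for the real StarCoder2 MLP; the attribute names
    # match the "starcoder2" entry of AWQ_SCALES_MAPPINGS ("c_fc" before "act").
    def __init__(self, hidden_size=8, intermediate_size=16):
        super().__init__()
        self.c_fc = nn.Linear(hidden_size, intermediate_size)
        self.act = nn.GELU()
        self.c_proj = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x):
        return self.c_proj(self.act(self.c_fc(x)))


mlp = ToyStarcoder2MLP()
# Same replacement the helper performs: wrap the activation that follows c_fc
# in ScaledActivation with unit scales sized to c_fc.out_features; the real
# per-channel scales arrive later, when the AWQ checkpoint weights are loaded.
mlp._modules["act"] = ScaledActivation(mlp.act, torch.ones(mlp.c_fc.out_features))
print(mlp.act)  # ScaledActivation wrapping the original GELU

With unit scales the wrapper is effectively a numerical no-op, which is why it is safe to insert it before the checkpoint's actual scales are loaded.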
@@ -75,7 +75,7 @@ class AwqQuantizer(HfQuantizer):
         return torch_dtype
 
     def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from ..integrations import get_keys_to_not_convert, replace_with_awq_linear
+        from ..integrations import get_keys_to_not_convert, replace_quantization_scales, replace_with_awq_linear
 
         self.modules_to_not_convert = get_keys_to_not_convert(model)
@@ -86,6 +86,8 @@
             model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
         )
 
+        model = replace_quantization_scales(model, model.config.model_type)
+
         if not has_been_replaced:
             logger.warning(
                 "You are loading an AWQ model but no linear modules were found in your model."
...
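With the quantizer change, the wrapping happens inside `_process_model_before_weight_loading`, i.e. after the linear layers have been swapped for AWQ linears and before the checkpoint is deserialized; model types not listed in `AWQ_SCALES_MAPPINGS` pass through unchanged. A minimal sketch of the user-visible effect, assuming a CUDA GPU and reusing the checkpoint from the new test below ("TechxGenus/starcoder2-3b-AWQ"); the prompt is invented for illustration.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TechxGenus/starcoder2-3b-AWQ"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Loading now wraps the MLP GELUs in ScaledActivation automatically, so the
# activation scales shipped in such AWQ checkpoints have a module to load into.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="cuda"
)

inputs = tokenizer("def print_hello_world():", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))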
@@ -471,3 +471,22 @@ class AwqFusedTest(unittest.TestCase):
 
         outputs = model.generate(**inputs, max_new_tokens=12)
         self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_MIXTRAL)
+
+
+@slow
+@require_torch_gpu
+@require_auto_awq
+@require_accelerate
+class AwqScaleTest(unittest.TestCase):
+    model_name = "TechxGenus/starcoder2-3b-AWQ"
+
+    def test_load_quantized_model(self):
+        from awq.modules.act import ScaledActivation
+
+        """
+        Simple test that checks if the scales have been replaced in the quantized model
+        """
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            "TechxGenus/starcoder2-3b-AWQ", torch_dtype=torch.float16, device_map="cuda"
+        )
+        self.assertTrue(isinstance(quantized_model.model.layers[0].mlp.act, ScaledActivation))