Unverified Commit de6e0db1 authored by Marc Sun, committed by GitHub

[awq] replace scale when we have GELU (#30074)

* fix awq test

* style

* add log

* new fix

* style

* only modifying impacted model in the end

* rename function
parent e0c3cee1
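
For context: AWQ folds per-channel scales into the weights of surrounding linear layers, but activations such as GELU do not commute with scaling, so autoawq instead wraps the activation in a `ScaledActivation` module that divides by the scales at runtime. A minimal sketch of such a wrapper (illustrative only, assuming the same shape conventions as `awq.modules.act`; the class name here is hypothetical):

import torch.nn as nn

class ScaledActivationSketch(nn.Module):
    """Wraps an activation and divides its output by per-channel scales."""

    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        # One scale per output channel of the layer feeding the activation.
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)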
@@ -21,6 +21,7 @@ _import_structure = {
    "awq": [
        "fuse_awq_modules",
        "post_init_awq_exllama_modules",
        "replace_quantization_scales",
        "replace_with_awq_linear",
    ],
    "bitsandbytes": [
@@ -92,6 +93,7 @@ if TYPE_CHECKING:
    from .awq import (
        fuse_awq_modules,
        post_init_awq_exllama_modules,
        replace_quantization_scales,
        replace_with_awq_linear,
    )
    from .bitsandbytes import (
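Note that the new export has to appear twice in the integrations `__init__.py`: once in `_import_structure`, which drives transformers' lazy module loading, and once under `TYPE_CHECKING` for static analyzers. A rough sketch of that pattern (simplified; the real keys and module layout differ):

import sys
from typing import TYPE_CHECKING

from ..utils import _LazyModule

_import_structure = {"awq": ["replace_quantization_scales"]}

if TYPE_CHECKING:
    # Seen only by type checkers; no runtime import cost.
    from .awq import replace_quantization_scales
else:
    # At runtime, swap this module for a proxy that imports
    # submodules lazily on first attribute access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)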
@@ -14,7 +14,7 @@
"AWQ (Activation aware Weight Quantization) integration file"
from ..activations import ACT2FN
from ..modeling_utils import PreTrainedModel
-from ..utils import is_auto_awq_available, is_torch_available
+from ..utils import is_auto_awq_available, is_torch_available, logging
from ..utils.quantization_config import (
    AwqBackendPackingMethod,
    AwqConfig,
@@ -27,6 +27,7 @@ if is_torch_available():
    import torch
    import torch.nn as nn

logger = logging.get_logger(__name__)

AWQ_FUSED_MAPPINGS = {
    "mistral": {
@@ -56,6 +57,34 @@ AWQ_FUSED_MAPPINGS = {
    },
}

AWQ_SCALES_MAPPINGS = {
    "starcoder2": {"act": "act", "layer_before_act": "c_fc"},
    "RefinedWebModel": {"act": "act", "layer_before_act": "dense_h_to_4h"},
    "falcon": {"act": "act", "layer_before_act": "dense_h_to_4h"},
    "mpt": {"act": "act", "layer_before_act": "up_proj"},
    "gptj": {"act": "act", "layer_before_act": "fc_in"},
    "gpt_neox": {"act": "act", "layer_before_act": "dense_h_to_4h"},
    "gpt_bigcode": {"act": "act", "layer_before_act": "c_fc"},
    "bloom": {"act": "gelu_impl", "layer_before_act": "dense_h_to_4h"},
}

def replace_quantization_scales(model, model_type):
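    """
    Wrap the matching activation module in `ScaledActivation` for architectures listed in
    `AWQ_SCALES_MAPPINGS`, so the activation scales stored in AWQ checkpoints have a module to load into.
    """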
    from awq.modules.act import ScaledActivation

    if model_type not in AWQ_SCALES_MAPPINGS:
        return model

    for name, module in model.named_children():
        act_name = AWQ_SCALES_MAPPINGS[model_type]["act"]
        layer_before_act_name = AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"]
        if name == act_name and hasattr(model, layer_before_act_name):
            layer_before_act = getattr(model, layer_before_act_name)
            size = layer_before_act.out_features
            # Placeholder scales; the real values are loaded from the checkpoint afterwards.
            scale_like = torch.ones(size)
            model._modules[name] = ScaledActivation(module, scale_like)
        _ = replace_quantization_scales(module, model_type)

    return model

def replace_with_awq_linear(
    model,
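Taken on its own, the new helper can be exercised like this (a sketch, not from the diff; the checkpoint name is illustrative and autoawq must be installed):

from transformers import AutoModelForCausalLM
from transformers.integrations import replace_quantization_scales

model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder2-3b")
model = replace_quantization_scales(model, model.config.model_type)
# Each MLP activation is now a ScaledActivation initialized with ones,
# ready to be overwritten by the scales stored in an AWQ checkpoint.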
@@ -75,7 +75,7 @@ class AwqQuantizer(HfQuantizer):
        return torch_dtype

    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from ..integrations import get_keys_to_not_convert, replace_with_awq_linear
+        from ..integrations import get_keys_to_not_convert, replace_quantization_scales, replace_with_awq_linear

        self.modules_to_not_convert = get_keys_to_not_convert(model)
@@ -86,6 +86,8 @@ class AwqQuantizer(HfQuantizer):
            model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
        )

        model = replace_quantization_scales(model, model.config.model_type)

        if not has_been_replaced:
            logger.warning(
                "You are loading an AWQ model but no linear modules were found in your model."
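Pieced together from the two hunks above, the quantizer's pre-loading hook now runs roughly as follows (a reconstruction; code elided by the diff is omitted):

def _process_model_before_weight_loading(self, model, **kwargs):
    from ..integrations import get_keys_to_not_convert, replace_quantization_scales, replace_with_awq_linear

    self.modules_to_not_convert = get_keys_to_not_convert(model)

    model, has_been_replaced = replace_with_awq_linear(
        model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
    )

    # New in this commit: install ScaledActivation wrappers before the checkpoint,
    # which holds the activation scales, is loaded into the model.
    model = replace_quantization_scales(model, model.config.model_type)

    if not has_been_replaced:
        logger.warning("You are loading an AWQ model but no linear modules were found in your model.")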
@@ -471,3 +471,22 @@ class AwqFusedTest(unittest.TestCase):
        outputs = model.generate(**inputs, max_new_tokens=12)
        self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_MIXTRAL)

@slow
@require_torch_gpu
@require_auto_awq
@require_accelerate
class AwqScaleTest(unittest.TestCase):
    model_name = "TechxGenus/starcoder2-3b-AWQ"

    def test_load_quantized_model(self):
        """
        Simple test that checks if the scales have been replaced in the quantized model
        """
        from awq.modules.act import ScaledActivation

        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.float16, device_map="cuda"
        )
        self.assertTrue(isinstance(quantized_model.model.layers[0].mlp.act, ScaledActivation))
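
The same check can be reproduced outside the test suite; a sketch assuming a CUDA GPU and autoawq installed (the generated text is not asserted here):

import torch
from awq.modules.act import ScaledActivation
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "TechxGenus/starcoder2-3b-AWQ"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="cuda")
# The MLP activation should now be wrapped in ScaledActivation.
print(isinstance(model.model.layers[0].mlp.act, ScaledActivation))

tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=12)[0], skip_special_tokens=True))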