Unverified Commit 9c3dfa07 authored by Younes Belkada, committed by GitHub

FEAT: Add possibility of skipping modules when quantizing (#248)

parent 78b59d73
 import os
 import json
 import logging
-from typing import Dict
+from typing import Dict, Optional, List
 from dataclasses import dataclass, field, fields
 from transformers.utils.hub import PushToHubMixin, cached_file
@@ -13,6 +13,7 @@ class AwqConfig(PushToHubMixin):
     w_bit: int = field(default=4)
     version: str = field(default="GEMM")
     config_file_name = "quant_config.json"
+    modules_to_not_convert: Optional[List] = None

     def save_pretrained(self, save_dir: str, **kwargs):
         logging.warning(
@@ -76,7 +77,8 @@ class AwqConfig(PushToHubMixin):
             "zero_point": self.zero_point,
             "q_group_size": self.q_group_size,
             "w_bit": self.w_bit,
-            "version": self.version
+            "version": self.version,
+            "modules_to_not_convert": self.modules_to_not_convert,
         }

     def to_transformers_dict(self):
@@ -86,4 +88,5 @@ class AwqConfig(PushToHubMixin):
             "group_size": self.q_group_size,
             "bits": self.w_bit,
             "version": self.version.lower(),
+            "modules_to_not_convert": self.modules_to_not_convert,
         }
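For reference, a sketch of what the two dicts above would contain once `modules_to_not_convert` is set. Only the keys visible in the hunks are shown, and the concrete values (group size, bit width, the `"gate"` pattern) are illustrative placeholders, not part of the diff:

```python
# Illustrative only: possible AwqConfig.to_dict() output after this change.
quant_config_dict = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM",
    "modules_to_not_convert": ["gate"],  # layers whose names contain "gate" are left unquantized
}

# ...and the corresponding to_transformers_dict() keys (version is lower-cased).
transformers_dict = {
    "group_size": 128,
    "bits": 4,
    "version": "gemm",
    "modules_to_not_convert": ["gate"],
}
```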
@@ -49,12 +49,12 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={},
                        calib_data: Union[str, List[str]]="pileval",
-                       split="train", text_column="text", duo_scaling=True):
+                       split="train", text_column="text", duo_scaling=True, modules_to_not_convert=None):
         self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

         quantizer = AwqQuantizer(
             self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
-            self.quant_config.version, calib_data, split, text_column, duo_scaling
+            self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert
         )
         quantizer.quantize()
         self.is_quantized = True
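From the user's side, the new keyword argument flows straight through `quantize()` into the quantizer. A minimal usage sketch, assuming the model id and the `"gate"` pattern (e.g. the router of a Mixtral-style MoE block) purely for illustration:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mixtral-8x7B-v0.1"  # placeholder model id
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# New in this commit: any linear layer whose name contains one of these
# substrings is skipped during quantization (matching uses `key in name`).
model.quantize(
    tokenizer,
    quant_config=quant_config,
    modules_to_not_convert=["gate"],
)
```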
@@ -14,7 +14,7 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,

 class AwqQuantizer:
     def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
-                       calib_data, split, text_column, duo_scaling) -> None:
+                       calib_data, split, text_column, duo_scaling, modules_to_not_convert=None) -> None:
         self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
@@ -25,6 +25,7 @@ class AwqQuantizer:
         self.split = split
         self.text_column = text_column
         self.duo_scaling = duo_scaling
+        self.modules_to_not_convert = modules_to_not_convert if modules_to_not_convert is not None else []
         self.modules, self.module_kwargs, self.inps = self.init_quant()

     def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
@@ -68,6 +69,13 @@

         return w

+    def _exclude_layers_to_not_quantize(self, linear_layers):
+        filtered_layers = {}
+        for name, linear_layer in linear_layers.items():
+            if not any(key in name for key in self.modules_to_not_convert):
+                filtered_layers[name] = linear_layer
+        return filtered_layers
+
     def quantize(self):
         for i in tqdm(range(len(self.modules)), desc="AWQ"):
             # Move module and inputs to correct device
@@ -80,6 +88,10 @@

             # [STEP 1]: Get layer, extract linear modules, extract input features
             named_linears = get_named_linears(self.modules[i])
+
+            # Filter out the linear layers we don't want to quantize
+            named_linears = self._exclude_layers_to_not_quantize(named_linears)
+
             input_feat = self._get_input_feat(self.modules[i], named_linears)
             clear_memory()
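To see the name-based filtering in isolation, here is a self-contained sketch of what `_exclude_layers_to_not_quantize` does to a dict of named linear layers; the layer names and sizes are made up for the example:

```python
import torch.nn as nn

named_linears = {
    "block_sparse_moe.gate": nn.Linear(16, 8),
    "self_attn.q_proj": nn.Linear(16, 16),
    "self_attn.k_proj": nn.Linear(16, 16),
}
modules_to_not_convert = ["gate"]

# Same substring test as in the new helper: keep a layer only if no
# excluded pattern occurs anywhere in its name.
filtered = {
    name: layer
    for name, layer in named_linears.items()
    if not any(key in name for key in modules_to_not_convert)
}

print(list(filtered))  # ['self_attn.q_proj', 'self_attn.k_proj']
```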
@@ -53,6 +53,8 @@ def apply_scale(module, scales_list, input_feat_dict=None):

         # apply the scaling to input feat if given; prepare it for clipping
         if input_feat_dict is not None:
             for layer_name in layer_names:
+                # Skip the modules that are not quantized
+                if layer_name in input_feat_dict:
                     inp = input_feat_dict[layer_name]
                     inp.div_(scales.view(1, -1).to(inp.device))
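The guard is needed because excluded layers are filtered out before input features are collected, so their names never appear in `input_feat_dict`, and indexing it unconditionally would raise a `KeyError`. A tiny sketch under that assumption (names and shapes are invented):

```python
import torch

# Only quantized layers had input features recorded; the excluded
# "block_sparse_moe.gate" is absent from the dict.
input_feat_dict = {"self_attn.q_proj": torch.ones(4, 8)}
scales = torch.full((8,), 2.0)

for layer_name in ["self_attn.q_proj", "block_sparse_moe.gate"]:
    if layer_name in input_feat_dict:  # skip modules that were not quantized
        inp = input_feat_dict[layer_name]
        inp.div_(scales.view(1, -1).to(inp.device))
```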