Unverified commit 9c3dfa07, authored by Younes Belkada, committed by GitHub

FEAT: Add possibility of skipping modules when quantizing (#248)

parent 78b59d73
import os
import json
import logging
-from typing import Dict
+from typing import Dict, Optional, List
from dataclasses import dataclass, field, fields
from transformers.utils.hub import PushToHubMixin, cached_file

@@ -13,6 +13,7 @@ class AwqConfig(PushToHubMixin):
    w_bit: int = field(default=4)
    version: str = field(default="GEMM")
    config_file_name = "quant_config.json"
+    modules_to_not_convert: Optional[List] = None

    def save_pretrained(self, save_dir: str, **kwargs):
        logging.warning(
@@ -76,7 +77,8 @@ class AwqConfig(PushToHubMixin):
            "zero_point": self.zero_point,
            "q_group_size": self.q_group_size,
            "w_bit": self.w_bit,
-            "version": self.version
+            "version": self.version,
+            "modules_to_not_convert": self.modules_to_not_convert,
        }

    def to_transformers_dict(self):
@@ -86,4 +88,5 @@ class AwqConfig(PushToHubMixin):
            "group_size": self.q_group_size,
            "bits": self.w_bit,
            "version": self.version.lower(),
+            "modules_to_not_convert": self.modules_to_not_convert,
        }
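
The new modules_to_not_convert field is persisted alongside the other quantization settings, so it survives a save/load round trip and is also exported to the transformers-style config. Below is a minimal sketch of that round trip; the import path and the "gate" entry are assumptions for illustration, not part of this commit.

# Sketch of the new field round-tripping through AwqConfig.
# The import path and the "gate" pattern are illustrative assumptions.
from awq.models._config import AwqConfig  # actual module path may differ

quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM",
    "modules_to_not_convert": ["gate"],
}

config = AwqConfig.from_dict(quant_config)
print(config.modules_to_not_convert)                             # ["gate"]
print(config.to_transformers_dict()["modules_to_not_convert"])   # ["gate"]
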
@@ -49,12 +49,12 @@ class BaseAWQForCausalLM(nn.Module):

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config={},
                       calib_data: Union[str, List[str]]="pileval",
-                      split="train", text_column="text", duo_scaling=True):
+                      split="train", text_column="text", duo_scaling=True, modules_to_not_convert=None):
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        quantizer = AwqQuantizer(
            self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
-            self.quant_config.version, calib_data, split, text_column, duo_scaling
+            self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert
        )
        quantizer.quantize()

        self.is_quantized = True
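
From the caller's side, the exclusion list is just an extra keyword argument to quantize(). A usage sketch follows, assuming the usual AutoAWQ entry point; the checkpoint path and the "gate" pattern (e.g. to keep MoE router layers in full precision) are illustrative.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mixtral-8x7B-v0.1"   # illustrative checkpoint
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Any linear layer whose qualified name contains "gate" is left unquantized.
model.quantize(tokenizer, quant_config=quant_config, modules_to_not_convert=["gate"])
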
@@ -14,7 +14,7 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,

class AwqQuantizer:
    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
-                       calib_data, split, text_column, duo_scaling) -> None:
+                       calib_data, split, text_column, duo_scaling, modules_to_not_convert=None) -> None:
        self.awq_model = awq_model
        self.model = model
        self.tokenizer = tokenizer
@@ -25,6 +25,7 @@ class AwqQuantizer:
        self.split = split
        self.text_column = text_column
        self.duo_scaling = duo_scaling
+        self.modules_to_not_convert = modules_to_not_convert if modules_to_not_convert is not None else []
        self.modules, self.module_kwargs, self.inps = self.init_quant()

    def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
@@ -68,6 +69,13 @@ class AwqQuantizer:
        return w

+    def _exclude_layers_to_not_quantize(self, linear_layers):
+        filtered_layers = {}
+        for name, linear_layer in linear_layers.items():
+            if not any(key in name for key in self.modules_to_not_convert):
+                filtered_layers[name] = linear_layer
+        return filtered_layers
+
    def quantize(self):
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # Move module and inputs to correct device
@@ -80,6 +88,10 @@ class AwqQuantizer:

            # [STEP 1]: Get layer, extract linear modules, extract input features
            named_linears = get_named_linears(self.modules[i])
+
+            # Filter out the linear layers we don't want to quantize
+            named_linears = self._exclude_layers_to_not_quantize(named_linears)
+
            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()
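
The exclusion is a plain substring match against the fully qualified layer names returned by get_named_linears, so one short pattern can skip the same layer in every transformer block. A standalone sketch of the rule, with made-up layer names:

# Standalone illustration of the exclusion rule; layer names are made up.
modules_to_not_convert = ["gate"]

named_linears = {
    "block_sparse_moe.gate": "<nn.Linear>",
    "block_sparse_moe.experts.0.w1": "<nn.Linear>",
    "self_attn.q_proj": "<nn.Linear>",
}

filtered = {
    name: layer
    for name, layer in named_linears.items()
    if not any(key in name for key in modules_to_not_convert)
}

print(sorted(filtered))  # ['block_sparse_moe.experts.0.w1', 'self_attn.q_proj']
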
@@ -53,6 +53,8 @@ def apply_scale(module, scales_list, input_feat_dict=None):

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
+                # Skip the modules that are not quantized
+                if layer_name in input_feat_dict:
                    inp = input_feat_dict[layer_name]
                    inp.div_(scales.view(1, -1).to(inp.device))
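
Because excluded layers never have their input activations recorded, apply_scale now only rescales tensors that actually appear in input_feat_dict. A minimal sketch of that guard, with illustrative dictionary contents:

import torch

# "block_sparse_moe.gate" was excluded earlier, so it has no recorded activations.
input_feat_dict = {"self_attn.q_proj": torch.ones(3, 2)}
scales = torch.tensor([2.0, 4.0])

for layer_name in ["self_attn.q_proj", "block_sparse_moe.gate"]:
    # Skip modules whose activations were never collected.
    if layer_name in input_feat_dict:
        inp = input_feat_dict[layer_name]
        inp.div_(scales.view(1, -1).to(inp.device))
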