Unverified Commit d3dbe179 authored by pppppM, committed by GitHub

[Feature] Support AWQ (#108)

* support kv cache offload

* add dataloader docstring

* complete gitignore

* refactor collect mod fn

* add calibration

* fix lint

* add observers and quantizers

* fix lints

* add global available mixin

* fix lints

* split batch inference

* support smoothquant and awq

* update export kv scales

* fix lints

* fix some bugs

* update weight only usage

* update usage

* auto mapping and support smooth internlm

* trust remote code

* fix num head key error

* fix bias error

* align shape and pack order with llm-awq

* modified according to LZHgrla's comments.

* update gitignore

* fix kv qparams export error

* update usage

* decouple calibrate and awq

* update docstrings

* update api name

* update readme

* update readme

* update readme

* update readme

* update kv_qparams and readme

* fix typos
parent 0d9c6c9d
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Union
from torch import nn
class GlobalAvailMixin:
"""Mixin class to make instances globally available."""
_instances: Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']] = {
'default': {}
}
def global_available(self,
key: Union[str, nn.Module] = 'default',
group: str = 'default') -> None:
"""Make the instance globally available.
Args:
key (Union[str, nn.Module], optional): Key to save the instance.
Defaults to 'default'.
group (str, optional): Group to save the instance.
Defaults to 'default'.
"""
self._save_instance(self, key, group)
@classmethod
def _save_instance(cls,
instance: 'GlobalAvailMixin',
key: Union[str, nn.Module] = 'default',
group: str = 'default') -> None:
"""Save the instance.
Args:
instance (GlobalAvailMixin): Instance to save.
key (Union[str, nn.Module], optional): Key to save the instance.
Defaults to 'default'.
group (str, optional): Group to save the instance.
Defaults to 'default'.
"""
if group not in cls._instances:
assert isinstance(group, str)
cls._instances[group] = {}
cls._instances[group][key] = instance
@classmethod
def find(cls,
key: Union[str, nn.Module] = 'default',
group: str = 'default') -> Union[None, 'GlobalAvailMixin']:
"""Find an instance by its key and group.
Args:
key (Union[str, nn.Module], optional): Key of the instance.
Defaults to 'default'.
group (str, optional): Group of the instance.
Defaults to 'default'.
Returns:
Union[None, GlobalAvailMixin]: The found instance, or None if
it does not exist.
"""
return cls._instances.get(group, {}).get(key)
@classmethod
def find_group(
cls,
group: str) -> Dict[Union[str, nn.Module], 'GlobalAvailMixin']:
"""Find all instances in a group.
Args:
group (str): Group of the instances.
Returns:
Dict[Union[str, nn.Module], GlobalAvailMixin]: All instances in
the group.
"""
return cls._instances.get(group, {})
@classmethod
def instances(
cls) -> Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']]:
"""Get all instances."""
return cls._instances
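
A brief usage sketch of the mixin (the `KVCacheObserver` class and the key strings below are hypothetical, for illustration only; they are not part of the committed file):

```python
# Hypothetical subclass: any class inheriting GlobalAvailMixin can register
# itself under a (group, key) pair and be looked up later without threading
# the instance through call sites.
class KVCacheObserver(GlobalAvailMixin):

    def __init__(self, num_head: int, head_dim: int) -> None:
        self.num_head = num_head
        self.head_dim = head_dim


obs = KVCacheObserver(num_head=32, head_dim=128)
obs.global_available(key='model.layers.0.attention', group='k_cache')

# Elsewhere in the code base, retrieve the same instance by key and group.
same_obs = KVCacheObserver.find('model.layers.0.attention', group='k_cache')
assert same_obs is obs

# Or fetch every instance registered under a group.
k_observers = KVCacheObserver.find_group('k_cache')
```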
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import re
import warnings
from contextlib import contextmanager
from functools import partial
from typing import List
import torch
from torch import nn
from lmdeploy.lite.defaults import KV_CACHE_SIGNATURE, OFFLOAD_MOD
def extract_return_values(module: nn.Module) -> List[str]:
"""Extracts return values from given module's forward method.
Args:
module (nn.Module): Module to inspect
Returns:
list[str]: List of return values
"""
last_line = inspect.getsource(module.forward).rstrip('\n').split('\n')[-1]
pattern = r'return ([\w\s,]+)'
match = re.search(pattern, last_line)
if match:
return_values = match.group(1).split(',')
return [value.strip() for value in return_values]
else:
return []
def find_kv_cache_idx(module: nn.Module) -> int:
"""Finds index of kv cache signature in module's forward parameters."""
signatures = list(inspect.signature(module.forward).parameters.keys())
if KV_CACHE_SIGNATURE not in signatures:
raise ValueError(f'{KV_CACHE_SIGNATURE} not in signatures of '
f'{type(module)} forward.')
return signatures.index(KV_CACHE_SIGNATURE)
def find_modules_by_return_value(model: nn.Module,
value: str) -> List[nn.Module]:
"""Finds modules in model that return given value.
Args:
model (nn.Module): Model to inspect
value (str): Return value to search for
Returns:
list[nn.Module]: List of matching modules
Raises:
ValueError: If no matching modules found
"""
modules = []
for name, module in model.named_modules():
returns = extract_return_values(module)
if value in returns:
print(f'Found {name} returning {value}')
modules.append(module)
if not modules:
error_msg = f'No modules found returning {value}. '
error_msg += 'Please check if the default KV_CACHE_SIGNATURE '
error_msg += f"'{KV_CACHE_SIGNATURE}' matches what is used in your "
error_msg += 'model code. If not, you can modify KV_CACHE_SIGNATURE '
error_msg += 'in `lmdeploy.lite.defaults`.'
raise ValueError(error_msg)
return modules
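
As a sanity check of the source-inspection approach, a toy attention module whose forward literally ends with `return attn_output, past_key_value` is picked up by the helpers above. The module below and the `'past_key_value'` string are illustrative; the real signature string is `lmdeploy.lite.defaults.KV_CACHE_SIGNATURE`.

```python
from torch import nn


class ToyAttention(nn.Module):
    """Toy block whose forward returns the kv cache by name."""

    def forward(self, hidden_states, past_key_value=None):
        attn_output = hidden_states
        past_key_value = (hidden_states, hidden_states)
        return attn_output, past_key_value


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

    def forward(self, x, past_key_value=None):
        return self.attn(x, past_key_value)


model = ToyModel()
# The regex over the last source line of ToyAttention.forward yields
# ['attn_output', 'past_key_value'], so the module matches the search.
assert extract_return_values(model.attn) == ['attn_output', 'past_key_value']
assert find_modules_by_return_value(model, 'past_key_value') == [model.attn]
# find_kv_cache_idx(model.attn) returns 1 when KV_CACHE_SIGNATURE is
# 'past_key_value'.
```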
@contextmanager
def offload_kv_cache(model: nn.Module, device: str = 'cuda') -> None:
    """Offloads kv cache to given device during forward pass.

    Args:
        model (nn.Module): Model for inference
        device (str): Device to offload to

    Yields:
        None
    """
    modules = find_modules_by_return_value(model, KV_CACHE_SIGNATURE)

    original_forwards = {mod: mod.forward for mod in modules}
    input_idxs = {mod: find_kv_cache_idx(mod) for mod in modules}
    output_idxs = {
        mod: extract_return_values(mod).index(KV_CACHE_SIGNATURE)
        for mod in modules
    }

    def wrap_forward(module, *args, **kwargs):
        idx = input_idxs[module]
        if idx >= len(args):
            # kv cache in kwargs
            if KV_CACHE_SIGNATURE in kwargs:
                if kwargs[KV_CACHE_SIGNATURE]:
                    kwargs[KV_CACHE_SIGNATURE] = kwargs[
                        KV_CACHE_SIGNATURE].to(device)
            else:
                raise ValueError(f'No kv cache input found at index {idx}')
        else:
            # kv cache in args
            args = list(args)
            args[idx] = args[idx].to(device)
            args = tuple(args)

        result = original_forwards[module](*args, **kwargs)
        result = list(result)
        idx = output_idxs[module]

        # Move kv cache outputs back to CPU
        key = result[idx][0].to('cpu')
        value = result[idx][1].to('cpu')
        torch.cuda.empty_cache()
        result[idx] = (key, value)
        result = tuple(result)

        return result

    try:
        for module in modules:
            original_forwards[module] = module.forward
            module.forward = partial(wrap_forward, module)
        yield
    finally:
        for module in modules:
            module.forward = original_forwards[module]
            del original_forwards[module]
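
The patching above works because assigning to `module.forward` on an instance shadows the bound method, and `functools.partial` pins the module as the first argument of `wrap_forward`. A stripped-down sketch of the same trick (names here are illustrative, not part of the committed file):

```python
import torch
from functools import partial
from torch import nn

linear = nn.Linear(4, 4)
original_forward = linear.forward


def traced_forward(module, *args, **kwargs):
    # `module` is bound by partial below; callers see an unchanged call
    # signature, so inputs/outputs can be intercepted (e.g. moved between
    # devices) right here.
    print(f'calling {type(module).__name__}')
    return original_forward(*args, **kwargs)


linear.forward = partial(traced_forward, linear)
linear(torch.randn(2, 4))  # prints "calling Linear"

# Restore the original forward, as the context manager does on exit.
linear.forward = original_forward
```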
@contextmanager
def offload_weights(model: nn.Module, device: str = 'cuda') -> None:
    """Offloads specified modules to given device during forward pass.

    Args:
        model (nn.Module): Model for inference
        device (str): Device to offload to

    Yields:
        None
    """
    target_modules = OFFLOAD_MOD

    def before_forward(module: nn.Module, inp: torch.Tensor):
        module.to(device)

    def after_forward(module: nn.Module, inp: torch.Tensor,
                      out: torch.Tensor):
        module.to('cpu')
        torch.cuda.empty_cache()

    def _to_device(m, spec_modules, dev):
        # Move `m` to `dev` recursively, keeping instances of `spec_modules`
        # on CPU so the hooks above can shuttle them on demand.
        if len(spec_modules) == 0 or len(list(m.children())) == 0:
            m.to(dev)
            return
        for child in m.children():
            if isinstance(child, spec_modules):
                child.to('cpu')
            else:
                _to_device(child, spec_modules, dev)
        # m.to(dev)

    warnings.warn('By default, offloading is done for `nn.Linear`. You can '
                  'add modules you want to offload to '
                  '`lmdeploy.lite.defaults.OFFLOAD_MOD`.')
    _to_device(model, target_modules, device)

    handles = []
    for module in model.modules():
        if isinstance(module, target_modules):
            handle1 = module.register_forward_pre_hook(before_forward)
            handle2 = module.register_forward_hook(after_forward)
            handles.extend([handle1, handle2])

    try:
        yield
    finally:
        for handle in handles:
            handle.remove()
        model.to('cpu')
        torch.cuda.empty_cache()
@contextmanager
def memory_efficient_inference(model: nn.Module,
offload: bool = True,
device: str = 'cuda') -> None:
"""Memory efficient inference context manager.
Moves model to device for inference, with option to offload
specific modules.
Args:
model (nn.Module): Model for inference
offload (bool): Whether to offload modules
device (str): Device for inference
Yields:
None
"""
if offload:
warnings.warn('Using offload mode - modules defined in OFFLOAD_MOD '
'will be moved to GPU during forward pass only.')
warnings.warn(
'Using offload mode will incur performance penalty due to '
'frequent CPU-GPU data transfers.')
with torch.inference_mode():
with offload_kv_cache(model, device):
with offload_weights(model, device):
yield
else:
model.to(device)
with torch.inference_mode():
yield
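
A usage sketch for the context manager above, assuming a HuggingFace-style causal LM whose attention blocks return the kv cache (the checkpoint name and prompt are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = 'internlm/internlm-chat-7b'  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True)

inputs = tokenizer('Hello, world!', return_tensors='pt').to('cuda')

# With offload=True, modules listed in OFFLOAD_MOD and the kv cache stay on
# CPU and are shuttled to the GPU only while each module runs, trading speed
# for peak memory.
with memory_efficient_inference(model, offload=True, device='cuda'):
    outputs = model(**inputs)
```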
# Copyright (c) OpenMMLab. All rights reserved.
from .linear import WeightOnlyQLinear
__all__ = ['WeightOnlyQLinear']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Type, TypeVar
import torch
from torch import nn
class WeightOnlyQLinear(nn.Module):
"""This class implements weight only quantization linear.
Args:
w_bit (int): number of bits for quantization.
symmetry (bool): If true, use symmetric quantization,
otherwise use asymmetric quantization.
group_size (int): size of the quantization group.
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (Tensor, optional): Defaults to None.
"""
def __init__(self,
w_bit: int,
symmetry: bool,
group_size: int,
in_features: int,
out_features: int,
bias: Optional[torch.Tensor] = None) -> None:
super().__init__()
if w_bit not in [2, 4, 8]:
raise NotImplementedError('Only 2,4,8 bit are supported for now.')
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
w_pack_oc = out_features // (32 // self.w_bit)
w_inc = in_features
weight = torch.zeros((w_inc, w_pack_oc), dtype=torch.int32)
self.register_buffer('qweight', weight)
if bias:
self.register_buffer('bias', torch.zeros(out_features))
else:
self.bias = None
s_inc = in_features // self.group_size
s_oc = out_features
scales = torch.zeros((s_inc, s_oc), dtype=torch.float16)
self.register_buffer('scales', scales)
if not symmetry:
z_inc = in_features // self.group_size
z_oc = out_features // (32 // self.w_bit)
zeros = torch.zeros((z_inc, z_oc), dtype=torch.int32)
self.register_buffer('qzeros', zeros)
else:
self.qzeros = None
@classmethod
def from_linear(cls: Type['WeightOnlyQLinear'],
linear: nn.Linear,
quantizer: TypeVar('Quantizer'),
awq_layout: bool = True) -> 'WeightOnlyQLinear':
"""Create a WeightOnlyQLinear object from a PyTorch Linear object.
Args:
linear (nn.Linear): PyTorch Linear object.
quantizer (Quantizer): Object that handles quantization.
awq_layout (bool): AWQ layout. Defaults to True.
Returns:
WeightOnlyQLinear: A WeightOnlyQLinear object.
"""
device = linear.weight.device
w_bit = quantizer.bits
pack_num = 32 // w_bit
if awq_layout:
assert w_bit == 4
pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
else:
pack_order = torch.arange(pack_num)
group_size = quantizer.group_size
symmetry = quantizer.symmetry
in_features = linear.in_features
out_features = linear.out_features
bias = False if linear.bias is None else True
qlinear = cls(w_bit, symmetry, group_size, in_features, out_features,
bias)
qlinear.bias = linear.bias
qparams = quantizer.calculate_qparams(linear.weight)
i32_w = quantizer.quant(linear.weight, qparams, real=True)
i32_w = i32_w.t().contiguous()
pack_int_w = torch.zeros_like(qlinear.qweight).to(device)
for col in range(pack_int_w.shape[1]):
for i in range(pack_num):
pack_int_w_col = i32_w[:, col * pack_num + pack_order[i]]
pack_int_w[:, col] |= pack_int_w_col << (i * w_bit)
qlinear.qweight = pack_int_w
qlinear.scales = qparams.scales.squeeze(-1).t().contiguous()
if qparams.zero_points is not None:
zeros = qparams.zero_points.to(torch.int32).to(device)
zeros = zeros.squeeze(-1).t().contiguous()
pack_int_zeros = torch.zeros_like(qlinear.qzeros).to(device)
for col in range(pack_int_zeros.shape[1]):
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + pack_order[i]]
pack_int_zeros[:, col] |= qzero_col << (i * w_bit)
qlinear.qzeros = pack_int_zeros
qlinear.to('cpu')
return qlinear
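
To make the packing loop in `from_linear` concrete, here is a small standalone check of how eight 4-bit values fold into one int32 column with the AWQ pack order (the example values are arbitrary):

```python
import torch

w_bit = 4
pack_num = 32 // w_bit                 # 8 int4 values per int32
pack_order = [0, 2, 4, 6, 1, 3, 5, 7]  # AWQ interleaved layout

# One group of eight quantized weights for a single packed output column.
i32_w = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)

packed = torch.zeros((1, 1), dtype=torch.int32)
for i in range(pack_num):
    # Nibble i of the packed word holds the value at column pack_order[i].
    packed[:, 0] |= i32_w[:, pack_order[i]] << (i * w_bit)

unpacked = [int(packed[0, 0] >> (i * w_bit)) & 0xF for i in range(pack_num)]
assert unpacked == [int(i32_w[0, pack_order[i]]) for i in range(pack_num)]
```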