Commit 2082197d authored by Casper Hansen

Refactor Llama quant attention

parent 560fbe59
awq/models/llama.py:

 from .base import BaseAWQForCausalLM
-from awq.modules import make_quant_norm, make_quant_attn, make_fused_mlp
+from awq.modules import make_quant_norm, make_fused_mlp
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
@@ -7,10 +7,11 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_position_embeddings"

     @staticmethod
-    def fuse_layers(awq_model):
-        make_quant_attn(awq_model, awq_model.device)
-        make_quant_norm(awq_model)
-        make_fused_mlp(awq_model)
+    def fuse_layers(awq_model: BaseAWQForCausalLM):
+        fuser = LlamaFuser(awq_model)
+        fuser.fuse_attention()
+        make_quant_norm(awq_model)  # fuser.fuse_rmsnorm()
+        make_fused_mlp(awq_model)  # fuser.fuse_mlp()

     @staticmethod
     def get_model_layers(model: LlamaForCausalLM):
@@ -63,4 +64,68 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
             inp=input_feat['mlp.down_proj'],
         ))
         return layers
+
+import torch
+from typing import List, Tuple
+from awq.quantize.qmodule import WQLinear
+from awq.utils.utils import set_module_name
+from awq.modules.fused_attn import QuantLlamaAttention
+# LlamaAttention and LlamaRMSNorm are referenced by the isinstance checks and
+# type hints below, so they must be imported here as well.
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
+
+class LlamaFuser:
+    def __init__(self, awq_model: BaseAWQForCausalLM):
+        self.awq_model = awq_model
+        self.model = awq_model.model
+
+        self.attention_modules: List[Tuple[str, LlamaAttention]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaAttention)
+        ]
+
+        self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaRMSNorm)
+        ]
+
+    def fuse_attention(self):
+        for name, module in self.attention_modules:
+            # Fuse q/k/v into a single WQLinear, then swap the whole attention
+            # module for the fused kernel implementation.
+            qkv_layer: WQLinear = self._fuse_qkv(module)
+            attn = QuantLlamaAttention(
+                module.hidden_size,
+                module.num_heads,
+                qkv_layer,
+                module.o_proj,
+                qkv_layer.qweight.device,
+                self.awq_model.model.config.max_new_tokens
+            )
+            set_module_name(self.model, name, attn)
+
+    def _fuse_qkv(self, module: LlamaAttention):
+        # get qkv and bias
+        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
+        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+        # create module
+        qkv_layer = WQLinear(
+            q_proj.w_bit,
+            q_proj.group_size,
+            q_proj.in_features,
+            q_proj.out_features + k_proj.out_features + v_proj.out_features,
+            q_proj.bias is not None,
+            q_proj.qweight.device
+        )
+
+        # replace buffers with real weights
+        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+        qkv_layer.bias = bias
+
+        return qkv_layer
+
+    def fuse_rmsnorm(self):
+        pass
+
+    def fuse_mlp(self):
+        pass
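The _fuse_qkv helper above relies on the fact that the q, k and v projections share the same input, so their weights can be stacked along the output dimension and evaluated as a single matmul; WQLinear keeps its packed qweight/qzeros/scales with the output dimension last, which is why the concatenation happens along dim=1. A minimal float-only sketch of the same idea with plain nn.Linear layers (illustration only, not the packed WQLinear layout):

import torch
import torch.nn as nn

hidden_size = 8  # toy size; real Llama uses hidden_size = num_heads * head_dim

q = nn.Linear(hidden_size, hidden_size, bias=False)
k = nn.Linear(hidden_size, hidden_size, bias=False)
v = nn.Linear(hidden_size, hidden_size, bias=False)

# One linear whose output features are the three projections stacked.
qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
with torch.no_grad():
    # nn.Linear stores weight as [out_features, in_features], so stack along dim=0
    # (WQLinear stores the output dimension last, hence dim=1 in _fuse_qkv).
    qkv.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))

x = torch.randn(2, 4, hidden_size)
q_out, k_out, v_out = qkv(x).split(hidden_size, dim=-1)

assert torch.allclose(q_out, q(x), atol=1e-6)
assert torch.allclose(k_out, k(x), atol=1e-6)
assert torch.allclose(v_out, v(x), atol=1e-6)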
awq/modules/fused_attn.py:

-import math
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb
-from awq.quantize.qmodule import WQLinear
 import awq_inference_engine
-from torch.nn import functional as F

 class QuantLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -64,7 +59,8 @@ class QuantLlamaAttention(nn.Module):
         num_heads,
         qkv_proj,
         o_proj,
-        dev
+        dev,
+        max_new_tokens
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -76,7 +72,7 @@ class QuantLlamaAttention(nn.Module):
                 f" and `num_heads`: {num_heads}).")
         self.qkv_proj = qkv_proj
         self.o_proj = o_proj
-        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=2048, device = dev)
+        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)

     def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
         """Input shape: Batch x Time x Channel"""
@@ -101,7 +97,7 @@ class QuantLlamaAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]

-        value_states = value_states.to("cuda:0")
+        value_states = value_states.to(key_states.device)

         if past_key_value is not None:
             # reuse k, v, self_attention
@@ -125,43 +121,3 @@ class QuantLlamaAttention(nn.Module):
         attn_output = self.o_proj(attn_output)

         return attn_output, None, past_key_value
-def make_quant_attn(model, dev):
-    """
-    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
-    """
-    for name, m in model.named_modules():
-        if not isinstance(m, LlamaAttention):
-            continue
-
-        q_proj = m.q_proj
-        k_proj = m.k_proj
-        v_proj = m.v_proj
-
-        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
-        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
-        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
-        g_idx = None
-        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
-
-        qkv_layer = WQLinear(q_proj.w_bit, q_proj.group_size, q_proj.in_features, q_proj.out_features + k_proj.out_features + v_proj.out_features, q_proj.bias is not None, q_proj.qweight.device)
-        qkv_layer.qweight = qweights
-        qkv_layer.qzeros = qzeros
-        qkv_layer.scales = scales
-        qkv_layer.bias = bias
-
-        attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, dev)
-
-        if '.' in name:
-            parent_name = name.rsplit('.', 1)[0]
-            child_name = name[len(parent_name) + 1:]
-            parent = model.get_submodule(parent_name)
-        else:
-            parent_name = ''
-            parent = model
-            child_name = name
-
-        setattr(parent, child_name, attn)
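Threading max_new_tokens through to QuantLlamaAttention matters because the rotary embedding precomputes its cos/sin tables up to max_position_embeddings; with that value hard-coded to 2048, a model configured for a longer context would index past the cache. Below is a generic sketch of that caching pattern in the style of the standard Llama rotary embedding; it illustrates why the cache length must come from the model config, and is not necessarily the exact layout QuantLlamaRotaryEmbedding feeds to the fused kernel:

import torch
import torch.nn as nn

class RotaryCache(nn.Module):
    """Generic rotary cache: precompute cos/sin up to max_position_embeddings."""
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # One inverse frequency per pair of channels.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        t = torch.arange(max_position_embeddings, device=device).float()
        freqs = torch.outer(t, inv_freq)           # [max_pos, dim // 2]
        emb = torch.cat((freqs, freqs), dim=-1)    # [max_pos, dim]
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, position_ids):
        # Any position_id >= max_position_embeddings falls off the end of the cache,
        # which is why the cache size is now driven by the model config instead of
        # a hard-coded 2048.
        return self.cos_cached[position_ids], self.sin_cached[position_ids]

rope = RotaryCache(dim=128, max_position_embeddings=4096)
cos, sin = rope(torch.arange(3000))  # works; would raise IndexError with the old 2048 default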
awq/utils/utils.py:

@@ -41,3 +41,15 @@ def simple_dispatch_model(model, device_map):
     model.hf_device_map = device_map

     return model
+
+def set_module_name(model, name, value):
+    # Resolve the parent module from a dotted name (e.g. "model.layers.0.self_attn")
+    # and assign the new value as its child attribute.
+    if '.' in name:
+        parent_name = name.rsplit('.', 1)[0]
+        child_name = name[len(parent_name) + 1:]
+        parent = model.get_submodule(parent_name)
+    else:
+        parent_name = ''
+        parent = model
+        child_name = name
+
+    setattr(parent, child_name, value)
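set_module_name is what lets the fuser swap a fused module in by the dotted name reported by named_modules(), without knowing anything about the surrounding architecture. A small self-contained usage sketch (toy modules, assuming this file is importable as awq.utils.utils, as the fuser's imports above indicate):

import torch.nn as nn
from awq.utils.utils import set_module_name

model = nn.Sequential()
model.add_module("block", nn.Sequential())
model.block.add_module("proj", nn.Linear(4, 4))

# Replace the nested child by its dotted named_modules() name, mirroring how
# LlamaFuser.fuse_attention swaps each LlamaAttention for a QuantLlamaAttention.
set_module_name(model, "block.proj", nn.Identity())
assert isinstance(model.block.proj, nn.Identity)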