Commit 42655843 authored by Casper Hansen

Remove CustomQuantAttention class

parent bbe1d46a
@@ -185,8 +185,8 @@ if __name__ == '__main__':
         run_eval(args.model_path, args.quant_file, args.device,
                  args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
     elif args.entry_type == 'speed':
-        # if args.batch_size > 1 and not args.disable_fused_layers:
-        #     raise Exception('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
+        if args.batch_size > 1 and not args.disable_fused_layers:
+            raise Exception('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
         run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
     else:
@@ -71,7 +71,7 @@ from awq.quantize.qmodule import WQLinear
 from awq.utils.utils import set_module_name
 from awq.modules.fused_mlp import QuantLlamaMLP
 from awq.modules.fused_norm import FTLlamaRMSNorm
-from awq.modules.fused_attn import QuantLlamaAttention, CustomQuantLlamaAttention
+from awq.modules.fused_attn import QuantLlamaAttention
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
 
 class LlamaFuser:
@@ -96,7 +96,7 @@ class LlamaFuser:
     def fuse_attention(self):
         for name, module in self.attention_modules:
             qkv_layer: WQLinear = self._fuse_qkv(module)
-            attn = CustomQuantLlamaAttention(
+            attn = QuantLlamaAttention(
                 module.hidden_size,
                 module.num_heads,
                 qkv_layer,
@@ -121,85 +121,3 @@ class QuantLlamaAttention(nn.Module):
         attn_output = self.o_proj(attn_output)
 
         return attn_output, None, past_key_value
-
-
-class CustomQuantLlamaAttention(nn.Module):
-    def __init__(
-        self,
-        hidden_size,
-        num_heads,
-        qkv_proj,
-        o_proj,
-        dev,
-        max_new_tokens
-    ):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.num_heads = num_heads
-        self.head_dim = hidden_size // num_heads
-
-        if (self.head_dim * num_heads) != self.hidden_size:
-            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                             f" and `num_heads`: {num_heads}).")
-
-        self.qkv_proj = qkv_proj
-        self.o_proj = o_proj
-        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask,
-        position_ids,
-        past_key_value,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-    ):
-        # qkv proj
-        qkv_states = self.qkv_proj(hidden_states)
-
-        # extract q,k,v
-        bsz, q_len, _ = hidden_states.size()
-        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        # rotary embedding
-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
-
-        # cache ops
-        is_causal = past_key_value is None
-
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        if use_cache:
-            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
-            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-
-        past_key_value = (key_states, value_states) if use_cache else None
-
-        # multi-head masked attention
-        attn_output = F.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=None if is_causal else attention_mask,
-            is_causal=is_causal
-        )
-
-        # reshape output
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-        # out projection
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
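
A note on the `contiguous()` calls in the removed forward pass above: `torch.split` returns views, so caching key_states/value_states directly would keep the entire fused qkv_states tensor alive for the lifetime of the KV cache. Below is a minimal standalone sketch of that behaviour, assuming only PyTorch; the shapes are made up for illustration and are not taken from this repository.

import torch

bsz, q_len, num_heads, head_dim = 1, 8, 4, 16                 # illustrative shapes only
hidden_size = num_heads * head_dim

qkv = torch.randn(bsz, q_len, 3 * hidden_size)                # stand-in for the fused qkv_proj output
_, k, _ = torch.split(qkv, hidden_size, dim=2)                # split returns views, no copy is made
k = k.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)   # still a view into qkv

qkv.zero_()                                                   # mutate the base tensor in place...
print(k.abs().sum().item())                                   # ...prints 0.0, because k shares qkv's storage

k = k.contiguous()                                            # materialises an independent copy, so a KV cache
                                                              # holding k no longer pins the full qkv tensor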