Commit d3625d1c authored by Casper Hansen

Set batch size for attention shapes

parent fdff74d6
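For reference, a minimal usage sketch of the two ways the kv-cache batch size can now be set. The model path is a placeholder, and passing batch_size to from_quantized is assumed from the error hint this commit adds below:

import os

# Option 1: export AWQ_BATCH_SIZE before the fused modules are built, so
# QuantAttentionFused and FalconDecoderLayer read it at construction time.
os.environ["AWQ_BATCH_SIZE"] = "8"

from awq import AutoAWQForCausalLM

# Option 2, per the new RuntimeError hint: pass the batch size when loading.
# "path/to/awq-model" is hypothetical.
model = AutoAWQForCausalLM.from_quantized("path/to/awq-model", batch_size=8)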
+import os
 import math
 import torch
 import torch.nn as nn
@@ -114,7 +115,8 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         return query, key
 
 class QuantAttentionFused(nn.Module):
-    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len, use_alibi=False, attention_shapes=None):
+    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len,
+                 use_alibi=False, attention_shapes=None):
         super().__init__()
         self.hidden_size = hidden_size
         self.n_local_heads = num_heads
@@ -123,7 +125,7 @@ class QuantAttentionFused(nn.Module):
         self.o_proj = o_proj
         self.start_pos = 0
         self.use_alibi = use_alibi
-        self.cache_batch_size = 1
+        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
         self.attention_shapes = attention_shapes if attention_shapes is not None else {
             # following fastertransformer definition
             "cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
@@ -170,6 +172,11 @@ def forward(
         self,
         hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
     ):
         bsz, seqlen, _ = hidden_states.shape
+        if bsz != self.cache_batch_size:
+            raise RuntimeError(
+                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
+                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
+            )
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
......
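To make the guard above concrete, a standalone sketch of the fixed-shape kv-cache and the batch check, assuming the fastertransformer-style layout shown in the hunk; the shape constants are illustrative:

import os
import torch

cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
n_local_heads, head_dim, max_seq_len = 32, 128, 2048

# The kv-cache is preallocated once with a fixed batch dimension, so inputs
# with any other batch size cannot reuse it.
cache_v = torch.zeros(cache_batch_size, n_local_heads, max_seq_len, head_dim)

def check_batch(hidden_states):
    # Mirrors the new RuntimeError: the input batch must match the cache's.
    bsz, seqlen, _ = hidden_states.shape
    if bsz != cache_batch_size:
        raise RuntimeError(
            f"Batch size is incorrectly set - input batch size {bsz}, "
            f"kv-cache batch size {cache_batch_size}."
        )

check_batch(torch.randn(cache_batch_size, 16, n_local_heads * head_dim))      # ok
# check_batch(torch.randn(cache_batch_size + 1, 16, n_local_heads * head_dim))  # raises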
+import os
 import torch.nn as nn
 from awq.modules.fused.attn import QuantAttentionFused
@@ -34,7 +35,7 @@ class FalconDecoderLayer(nn.Module):
         self.n_heads = n_heads
         self.hidden_size = hidden_size
         self.new_decoder_arch = new_decoder_arch
-        attention_shapes = self._get_attention_shapes(1, n_heads, max_seq_len, self.hidden_size // n_heads, new_decoder_arch)
+        attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads, new_decoder_arch)
 
         # TODO: Falcon has ALiBi implemented but which model uses it?
         self.attn = QuantAttentionFused(
@@ -51,7 +52,9 @@ class FalconDecoderLayer(nn.Module):
         self.mlp = mlp
 
-    def _get_attention_shapes(self, batch_size, n_heads, max_seq_len, head_dim, new_decoder_arch):
+    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim, new_decoder_arch):
+        batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
+
         if new_decoder_arch:
             kv_heads = 8
......
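A hedged sketch of the revised helper's behavior: batch_size now comes from AWQ_BATCH_SIZE rather than being hard-coded to 1 by the caller. kv_heads = 8 for the new decoder arch matches the diff; the single-kv-head fallback and the returned key are assumptions for illustration:

import os

def get_attention_shapes(n_heads, max_seq_len, head_dim, new_decoder_arch):
    # Reads the batch size from the environment, as the diff now does.
    batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
    kv_heads = 8 if new_decoder_arch else 1  # assumption: MQA-style fallback
    return {
        # Illustrative key; mirrors the cache_v layout used in attn.py above.
        "cache_v": (batch_size, kv_heads, max_seq_len, head_dim),
    }

print(get_attention_shapes(n_heads=128, max_seq_len=2048, head_dim=64, new_decoder_arch=True))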