Commit d3625d1c authored by Casper Hansen

Set batch size for attention shapes

parent fdff74d6
+import os
 import math
 import torch
 import torch.nn as nn
@@ -114,7 +115,8 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         return query, key

 class QuantAttentionFused(nn.Module):
-    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len, use_alibi=False, attention_shapes=None):
+    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len,
+                 use_alibi=False, attention_shapes=None):
         super().__init__()
         self.hidden_size = hidden_size
         self.n_local_heads = num_heads
@@ -123,7 +125,7 @@ class QuantAttentionFused(nn.Module):
         self.o_proj = o_proj
         self.start_pos = 0
         self.use_alibi = use_alibi
-        self.cache_batch_size = 1
+        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
         self.attention_shapes = attention_shapes if attention_shapes is not None else {
             # following fastertransformer definition
             "cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
@@ -170,6 +172,11 @@ class QuantAttentionFused(nn.Module):
         hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
     ):
         bsz, seqlen, _ = hidden_states.shape
+        if bsz != self.cache_batch_size:
+            raise RuntimeError(
+                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
+                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
+            )

         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
...
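The first file now sizes the fused attention kv-cache from the AWQ_BATCH_SIZE environment variable (default "1") instead of a hard-coded batch size of 1, and forward() fails fast with a RuntimeError when the incoming batch does not match the cache allocation. A minimal usage sketch, not part of this commit: it assumes AutoAWQ is installed, uses a placeholder model path, and treats the batch_size argument named in the error message as whatever the loader wires through to this environment variable.

import os

# AWQ_BATCH_SIZE must be set before the fused modules are constructed,
# because QuantAttentionFused.__init__ reads it once to allocate the kv-cache.
os.environ["AWQ_BATCH_SIZE"] = "8"

from awq import AutoAWQForCausalLM

# Placeholder path; fuse_layers=True builds the fused attention modules
# whose cache shapes now depend on AWQ_BATCH_SIZE.
model = AutoAWQForCausalLM.from_quantized("path/to/quantized-model", fuse_layers=True)

# Running a forward pass with any input batch size other than 8 now raises:
#   RuntimeError: Batch size is incorrectly set - input batch size ..., kv-cache batch size 8. ...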
+import os
 import torch.nn as nn
 from awq.modules.fused.attn import QuantAttentionFused
@@ -34,7 +35,7 @@ class FalconDecoderLayer(nn.Module):
         self.n_heads = n_heads
         self.hidden_size = hidden_size
         self.new_decoder_arch = new_decoder_arch
-        attention_shapes = self._get_attention_shapes(1, n_heads, max_seq_len, self.hidden_size // n_heads, new_decoder_arch)
+        attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads, new_decoder_arch)

         # TODO: Falcon has ALiBi implemented but which model uses it?
         self.attn = QuantAttentionFused(
@@ -51,7 +52,9 @@ class FalconDecoderLayer(nn.Module):
         self.mlp = mlp

-    def _get_attention_shapes(self, batch_size, n_heads, max_seq_len, head_dim, new_decoder_arch):
+    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim, new_decoder_arch):
+        batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
+
         if new_decoder_arch:
             kv_heads = 8
...
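The second file applies the same convention to FalconDecoderLayer: the hard-coded batch size of 1 is dropped from the _get_attention_shapes call, and the helper now reads AWQ_BATCH_SIZE itself, keeping the Falcon cache shapes consistent with the ones allocated in QuantAttentionFused. The sketch below shows how the value feeds into the cache shape; it mirrors the "cache_v" tuple from the first file, with example head and sequence sizes that are not part of the diff.

import os

# Illustrative values only; n_local_heads, max_seq_len and head_dim come from
# the model config in the real modules.
batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
n_local_heads, max_seq_len, head_dim = 32, 2048, 128

# Mirrors the "cache_v" shape defined in QuantAttentionFused above.
cache_v_shape = (batch_size, n_local_heads, max_seq_len, head_dim)
print(cache_v_shape)  # (1, 32, 2048, 128) when AWQ_BATCH_SIZE is unset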