Unverified Commit c57da6b8 authored by Casper, committed by GitHub

Merge pull request #75 from casper-hansen/fix_runtime

Fix KV cache shapes error
parents 8eb26eb2 cba9a28c
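
For context on the fix: QuantAttentionFused pre-allocates its key/value caches with fixed shapes derived from the batch size (the AWQ_BATCH_SIZE environment variable), the number of KV heads, the head dimension, and max_seq_len. The sketch below is not part of the diff; it only illustrates the cache layout with made-up sizes, and the packed key-cache shape is inferred from the dimensions used by the cache-rolling code further down (the fastertransformer layout packs 8 fp16 values along the last axis).

import os
import torch

# Illustrative sizes only; in AutoAWQ these come from the model config.
batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
n_kv_heads, head_dim, max_seq_len = 8, 128, 2048

# Value cache: (batch, kv_heads, seq_len, head_dim) -- sequence axis at dim 2.
cache_v = torch.zeros(batch_size, n_kv_heads, max_seq_len, head_dim, dtype=torch.float16)

# Key cache packs 8 fp16 values along the last axis (fastertransformer layout),
# which is why the sequence axis ends up at dim 3 in the rolling code below.
cache_k = torch.zeros(batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8, dtype=torch.float16)

print(cache_v.shape)  # torch.Size([1, 8, 2048, 128])
print(cache_k.shape)  # torch.Size([1, 8, 16, 2048, 8])
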
@@ -80,12 +80,32 @@ class QuantAttentionFused(nn.Module):
         self.start_pos = 0
         self.use_alibi = use_alibi
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
+        self.max_seq_len = max_seq_len
+        self.attention_shapes = self._get_attention_shapes(attention_shapes, max_seq_len)
+        self.cache_v = ( torch.zeros(self.attention_shapes["cache_v"]).to(dev).half() )
+        self.cache_k = ( torch.zeros(self.attention_shapes["cache_k"]).to(dev).half() )
+
+        if use_alibi:
+            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
+            self.alibi_slopes = alibi_slopes.float().to(dev)
+            self.alibi_bias = alibi_bias.float().to(dev)
+            self.rotary_dim = 0
+            self.is_neox = False
+        else:
+            self.freqs_cis = precompute_freqs_cis(
+                hidden_size // n_heads,
+                max_seq_len * 2,
+            ).to(dev)
+            self.rotary_dim = self.head_dim
+            self.alibi_slopes = None
+            self.is_neox = True
+
+    def _get_attention_shapes(self, attention_shapes, max_seq_len):
         if attention_shapes is not None:
-            self.attention_shapes = attention_shapes
+            attention_shapes = attention_shapes
         elif self.n_kv_heads == 0:
-            self.attention_shapes = {
+            attention_shapes = {
                 # following fastertransformer definition
                 "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
                 # 8: pack 8 fp16 in FT, if fp32 then use 4
@@ -104,7 +124,7 @@ class QuantAttentionFused(nn.Module):
             }
         else:
-            self.attention_shapes = {
+            attention_shapes = {
                 # following fastertransformer definition
                 "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
                 # 8: pack 8 fp16 in FT, if fp32 then use 4
@@ -122,32 +142,11 @@ class QuantAttentionFused(nn.Module):
                 "single_xv_view": (self.n_kv_heads, self.head_dim)
             }
-        self.cache_v = (
-            torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
-        )
-        self.cache_k = (
-            torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()
-        )
-
-        if use_alibi:
-            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
-            self.alibi_slopes = alibi_slopes.float().to(dev)
-            self.alibi_bias = alibi_bias.float().to(dev)
-            self.rotary_dim = 0
-            self.is_neox = False
-        else:
-            self.freqs_cis = precompute_freqs_cis(
-                hidden_size // n_heads,
-                max_seq_len * 2,
-            ).to(dev)
-            self.rotary_dim = self.head_dim
-            self.alibi_slopes = None
-            self.is_neox = True
+        return attention_shapes
 
     def forward(
         self,
-        hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
+        hidden_states:torch.Tensor, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
     ):
         bsz, seqlen, _ = hidden_states.shape
         if bsz != self.cache_batch_size:
@@ -155,6 +154,17 @@ class QuantAttentionFused(nn.Module):
                 f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
                 f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
             )
+
+        if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
+            # Roll cache to the left
+            roll_len = self.start_pos
+            self.cache_v = torch.roll(self.cache_v, shifts=-roll_len, dims=2)
+            self.cache_k = torch.roll(self.cache_k, shifts=-roll_len, dims=3)
+            # Zero out the new part
+            self.cache_v[:, :, -roll_len:, :] = 0
+            self.cache_k[:, :, :, -roll_len:, :] = 0
+            self.start_pos = 0
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
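
The block added to forward() above keeps decoding alive once start_pos would run past max_seq_len: both caches are rotated left along their sequence axis, the wrapped-around tail is zeroed, and start_pos is reset to 0. Below is a standalone sketch of that shift-and-zero pattern on a toy value cache; the sizes and roll_len are made up for illustration and this is not code from the repository.

import torch

# Toy value cache shaped (batch, kv_heads, seq_len, head_dim), filled with
# position indices so the shift is easy to see.
max_seq_len = 8
cache_v = torch.arange(max_seq_len).float().view(1, 1, max_seq_len, 1)

# Drop the oldest `roll_len` positions: rotate left along the sequence axis,
# then zero the slots that wrapped around to the end.
roll_len = 3
cache_v = torch.roll(cache_v, shifts=-roll_len, dims=2)
cache_v[:, :, -roll_len:, :] = 0

print(cache_v.flatten().tolist())  # [3.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0]
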
@@ -60,3 +60,8 @@ def clear_memory(weight=None):
         del weight
     gc.collect()
     torch.cuda.empty_cache()
+
+def compute_memory_used_pct(device):
+    memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
+    memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
+    return memory_pct
\ No newline at end of file
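
The new compute_memory_used_pct helper reports peak allocated CUDA memory as a percentage of the device's total memory. A small usage sketch follows; it is not part of the PR, it assumes a CUDA device is present, and the tensor allocation is only there so the peak is non-zero.

import torch

def compute_memory_used_pct(device):
    # Peak allocation so far, in GiB, relative to the device's total memory in GiB.
    memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
    memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
    return memory_pct

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    x = torch.empty(4096, 4096, device=device)  # ~64 MiB fp32 allocation
    print(f"peak GPU memory used: {compute_memory_used_pct(device):.2f}%")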