Commit 40e6952a authored by Casper Hansen

Remove past_key_value (save 2GB VRAM)

parent eccb8f9c
@@ -206,7 +206,6 @@ class QuantAttentionFused(nn.Module):
             keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
             values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)
-            past_key_value = (xk, xv) if use_cache else None
             xq = xq.transpose(1, 2)
             keys = keys.transpose(1, 2)
             values = values.transpose(1, 2)
@@ -222,14 +221,10 @@ class QuantAttentionFused(nn.Module):
             output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
             attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
         else:
-            # xq = xq[:, 0, :, :]
-            # xk = xk[:, 0, :, :]
-            # xv = xv[:, 0, :, :]
             xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
             xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
             xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])
-            past_key_value = (xk, xv) if use_cache else None
             attention_weight = ft_inference_engine.single_query_attention(
                 xq, # query
                 xk, # key
@@ -252,4 +247,5 @@ class QuantAttentionFused(nn.Module):
         else:
             self.start_pos = 0
-        return attn_output, attention_weight, past_key_value
+        # past_key_value is replaced with cache_v, cache_k, returning None
+        return attn_output, attention_weight, None
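
The change drops the per-step (xk, xv) tuple from the module's return value. The fused attention already writes keys and values into its own pre-allocated cache buffers, so also returning them lets the caller accumulate a second copy of the KV history and waste VRAM. Below is a minimal sketch of that idea; the class and parameter names (ToyFusedAttention, max_seq_len, etc.) are hypothetical and not AutoAWQ's actual API.

# Sketch only: shows why a module with an internal, pre-allocated KV cache
# can return None for past_key_value instead of handing (xk, xv) back to
# the caller, which would keep a duplicate copy of the cached history alive.
import torch
import torch.nn as nn


class ToyFusedAttention(nn.Module):
    def __init__(self, max_seq_len, n_heads, head_dim, dtype=torch.float16):
        super().__init__()
        # Pre-allocated rolling KV cache, written in place on every forward.
        self.register_buffer(
            "cache_k", torch.zeros(1, max_seq_len, n_heads, head_dim, dtype=dtype)
        )
        self.register_buffer(
            "cache_v", torch.zeros(1, max_seq_len, n_heads, head_dim, dtype=dtype)
        )
        self.start_pos = 0

    def forward(self, xk, xv, use_cache=True):
        seqlen = xk.shape[1]
        # The new keys/values land in the internal buffers ...
        self.cache_k[:, self.start_pos : self.start_pos + seqlen] = xk
        self.cache_v[:, self.start_pos : self.start_pos + seqlen] = xv
        self.start_pos += seqlen
        # ... so there is nothing extra to return; a caller that stored a
        # returned (xk, xv) per layer would hold a second full KV copy.
        return None  # instead of: (xk, xv) if use_cache else None


attn = ToyFusedAttention(max_seq_len=2048, n_heads=32, head_dim=128)
step_k = torch.zeros(1, 1, 32, 128, dtype=torch.float16)
step_v = torch.zeros(1, 1, 32, 128, dtype=torch.float16)
assert attn(step_k, step_v) is None  # the cache lives only inside the module

In the commit itself this corresponds to returning None in place of past_key_value, since the real module keeps the history in cache_k/cache_v.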