Unverified Commit 5db86ec5 authored by Younes Belkada, committed by GitHub

New logic for passing past_key_value (#177)

parent 63e3fd83
@@ -215,5 +215,7 @@ class QuantAttentionFused(nn.Module):
         self.start_pos += seqlen
-        # past_key_value is replaced with cache_v, cache_k, returning empty data
-        past_key_value = [torch.Tensor([ [ [[0]], [[0]], [[0]] ] ])]
+        # we pass a dummy past kv cache for transformers to be able to retrieve the correct info
+        # about past key length
+        past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]
         return attn_output, attention_weight, past_key_value
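
For context, transformers typically infers the past sequence length from dim 2 of the first cached tensor, following the usual [batch, num_heads, seq_len, head_dim] key/value layout. A minimal standalone sketch of why the new dummy shape works, assuming that convention (start_pos below is a stand-in for self.start_pos):

import torch

start_pos = 12  # stand-in for self.start_pos after a few decoding steps

# Old dummy: a fixed nested tensor; its shape carries no length information.
old_dummy = [torch.Tensor([[[[0]], [[0]], [[0]]]])]
print(old_dummy[0].shape)     # torch.Size([1, 3, 1, 1]) -- dim 2 is always 1

# New dummy: zeros shaped so that dim 2 equals the true past key length.
new_dummy = [torch.zeros(1, 1, start_pos, 1)]
print(new_dummy[0].shape[2])  # 12 -- the past key length transformers reads back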