Unverified commit 0ab278ca, authored by Antoni Baum and committed by GitHub
Browse files

[Core] Remove unnecessary copies in flash attn backend (#5138)

parent 7a64d24a
...@@ -6,4 +6,4 @@ ray >= 2.9 ...@@ -6,4 +6,4 @@ ray >= 2.9
nvidia-ml-py # for pynvml package nvidia-ml-py # for pynvml package
torch == 2.3.0 torch == 2.3.0
xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0 vllm-flash-attn == 2.5.9 # Requires PyTorch 2.3.0
...@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention # normal attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
out = flash_attn_varlen_func( flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -329,14 +329,13 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -329,14 +329,13 @@ class FlashAttentionImpl(AttentionImpl):
causal=True, causal=True,
window_size=self.sliding_window, window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
out=output[:num_prefill_tokens],
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
# prefix-enabled attention # prefix-enabled attention
assert prefill_meta.seq_lens is not None assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens) max_seq_len = max(prefill_meta.seq_lens)
output[:num_prefill_tokens] = flash_attn_varlen_func( flash_attn_varlen_func(
q=query, q=query,
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
...@@ -348,11 +347,12 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -348,11 +347,12 @@ class FlashAttentionImpl(AttentionImpl):
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables, block_table=prefill_meta.block_tables,
out=output[:num_prefill_tokens],
) )
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
output[num_prefill_tokens:] = flash_attn_with_kvcache( flash_attn_with_kvcache(
decode_query.unsqueeze(1), decode_query.unsqueeze(1),
key_cache, key_cache,
value_cache, value_cache,
...@@ -361,7 +361,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -361,7 +361,8 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
).squeeze(1) out=output[num_prefill_tokens:].unsqueeze(1),
)
# Reshape the output tensor. # Reshape the output tensor.
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment