Unverified Commit 26c0f131 authored by Stefan He, committed by GitHub

Support Page Size > 1 for FA3 (#4832)


Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
parent f9970bd1
@@ -57,6 +57,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.page_size = model_runner.page_size
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
@@ -78,6 +79,17 @@ class FlashAttentionBackend(AttentionBackend):
         metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, : metadata.max_seq_len_k
         ]
+
+        # Precompute strided indices
+        # [0, page_size, 2 * page_size, ...]
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
         if forward_batch.forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
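
To make the page-table conversion above concrete: with page-aligned allocation, every token of a page occupies consecutive slots in the flat KV cache, so sampling one slot per page (stride `page_size`) and integer-dividing by `page_size` yields the page ids that FA3 expects. A minimal sketch with illustrative sizes (the `req_to_token` values are made up, but page-aligned like the pool guarantees):

```python
import torch

page_size = 4          # tokens per KV-cache page (illustrative)
max_seq_len_k = 12     # context length covered by the table (illustrative)

# req_to_token maps (request, token position) -> slot in the flat KV cache;
# here request 0 owns slots 0..11 and request 1 owns slots 12..23.
req_to_token = torch.arange(2 * max_seq_len_k).view(2, max_seq_len_k)

# Sample one slot per page: positions 0, page_size, 2 * page_size, ...
strided_indices = torch.arange(0, req_to_token.shape[1], page_size)
page_table = req_to_token[:, strided_indices] // page_size

print(page_table)  # tensor([[0, 1, 2],
                   #         [3, 4, 5]])
```
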
@@ -132,11 +144,21 @@ class FlashAttentionBackend(AttentionBackend):
         )
 
         kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         key_cache, value_cache = kv_cache[0], kv_cache[1]
+        key_cache = key_cache.view(
+            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+        )
+        value_cache = value_cache.view(
+            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+        )
+        page_table = metadata.page_table
+
         o = flash_attn_with_kvcache(
             q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            k_cache=key_cache.unsqueeze(1),
-            v_cache=value_cache.unsqueeze(1),
-            page_table=metadata.page_table,
+            k_cache=key_cache,
+            v_cache=value_cache,
+            page_table=page_table,
             cache_seqlens=metadata.cache_seqlens_int32,
             cu_seqlens_q=metadata.cu_seqlens_q,
             cu_seqlens_k_new=metadata.cu_seqlens_k,
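
The `unsqueeze(1)` calls could be dropped because, with `page_size == 1`, each page held exactly one token and a singleton page axis had to be faked; with real pages the flat cache is reinterpreted in place. A self-contained sketch of that zero-copy view, with illustrative sizes:

```python
import torch

page_size, num_pages = 4, 8
tp_k_head_num, head_dim = 2, 16   # illustrative head layout

# Flat cache: one row per token slot; a page spans page_size consecutive rows.
key_cache = torch.randn(num_pages * page_size, tp_k_head_num, head_dim)

# Zero-copy reinterpretation as (num_pages, page_size, heads, head_dim),
# mirroring the .view(...) calls in the diff.
paged = key_cache.view(-1, page_size, tp_k_head_num, head_dim)

# Page 3 aliases flat slots [12, 16); no data is moved.
assert torch.equal(paged[3], key_cache[3 * page_size : 4 * page_size])
assert paged.data_ptr() == key_cache.data_ptr()
```
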
@@ -175,13 +197,11 @@ class FlashAttentionBackend(AttentionBackend):
         # Get KV cache
         kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         key_cache, value_cache = kv_cache[0], kv_cache[1]
-
         # Use precomputed metadata
         metadata = self.forward_metadata
         # Pre-reshape query tensor
         q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # We don't subtract 1 from layer.sliding_window_size here because
         # model.get_attention_sliding_window_size() already does; the window is inclusive on both sides.
@@ -191,11 +211,20 @@ class FlashAttentionBackend(AttentionBackend):
             else (-1, -1)
         )
 
         # Run attention with precomputed values
+        key_cache = key_cache.view(
+            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+        )
+        value_cache = value_cache.view(
+            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+        )
+        page_table = metadata.page_table
+
         o = flash_attn_with_kvcache(
             q=q_reshaped,
-            k_cache=key_cache.unsqueeze(1),
-            v_cache=value_cache.unsqueeze(1),
-            page_table=metadata.page_table,
+            k_cache=key_cache,
+            v_cache=value_cache,
+            page_table=page_table,
             cache_seqlens=metadata.cache_seqlens_int32,
             cu_seqlens_q=metadata.cu_seqlens_q,
             cu_seqlens_k_new=metadata.cu_seqlens_k,
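
Regarding the comment that the window is inclusive on both sides: a `(left, right)` tuple of `(w, 0)` lets query `i` attend keys `j` with `i - w <= j <= i`, and `(-1, -1)` disables windowing. A hand-rolled mask illustrating that convention (a sketch of the semantics only, not the kernel's implementation):

```python
import torch

def window_mask(seq_len: int, left: int, right: int) -> torch.Tensor:
    """True where query i may attend key j; bounds are inclusive,
    and a bound of -1 means unlimited on that side."""
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
    if left >= 0:
        mask &= j >= i - left
    if right >= 0:
        mask &= j <= i + right
    return mask

# left=2, right=0: each query sees itself and the two previous keys.
print(window_mask(5, 2, 0).int())
```
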
@@ -207,7 +236,6 @@ class FlashAttentionBackend(AttentionBackend):
             k_descale=layer.k_scale,
             v_descale=layer.v_scale,
         )
-
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def init_cuda_graph_state(self, max_bs: int):
@@ -223,7 +251,13 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
-                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
             ),
         }
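
The `page_table` buffer preallocated for CUDA graphs now has one column per page instead of per token; the ceiling division guarantees a slot for a partially filled trailing page. Checking the arithmetic:

```python
page_size = 16  # illustrative
for max_context_len in (16, 17, 32, 100):
    num_pages = (max_context_len + page_size - 1) // page_size
    print(f"{max_context_len} tokens -> {num_pages} page(s)")
# 16 tokens -> 1 page(s)
# 17 tokens -> 2 page(s)
# 32 tokens -> 2 page(s)
# 100 tokens -> 7 page(s)
```
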
@@ -252,6 +286,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
                 req_pool_indices, :
             ]
+
         if forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
@@ -287,14 +322,11 @@ class FlashAttentionBackend(AttentionBackend):
                 torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
             )
-            # Only zero out the part out of max_len_k
-            metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
-            # Then do the copy
-            metadata.page_table[:, : metadata.max_seq_len_k].copy_(
-                self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
-            )
-
-        self.forward_decode_metadata = metadata
+            metadata.page_table = self.req_to_token[
+                :, self.decode_cuda_graph_metadata["strided_indices"]
+            ]
+            metadata.page_table = metadata.page_table[req_pool_indices[:bs]]
+        self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
...
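
At replay time the page table is rebuilt with two gathers: the cached `strided_indices` select one column per page from `req_to_token`, then `req_pool_indices` select the rows of the requests in the captured batch. A minimal sketch of that indexing pattern, with illustrative shapes:

```python
import torch

page_size, max_context_len, pool_size, bs = 4, 16, 8, 3
req_to_token = torch.arange(pool_size * max_context_len).view(
    pool_size, max_context_len
)
strided_indices = torch.arange(0, max_context_len, page_size)  # cached at init

# Step 1: one column per page, for every request slot in the pool.
page_table = req_to_token[:, strided_indices]
# Step 2: keep only the rows of the bs requests in this batch.
req_pool_indices = torch.tensor([5, 2, 7])
page_table = page_table[req_pool_indices[:bs]]
assert page_table.shape == (bs, max_context_len // page_size)
```
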