Unverified commit 619bb6dd authored by Liangsheng Yin, committed by GitHub

Dispatch flashinfer wrappers (#1550)

parent b88ea90d
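Note for readers skimming the diff below: the change replaces the old `sliding_window_size is None` branching, where the prefill/decode wrapper attributes were either a single object or a hard-coded list of two, with wrapper lists that always hold `self.num_wrappers` entries plus a `_get_wrapper_idx(layer)` helper that picks the entry for a given layer. A minimal toy sketch of that dispatch pattern follows; `DummyWrapper`, `DummyLayer`, and `ToyBackend` are made-up names for illustration, and only `num_wrappers`, `_get_wrapper_idx`, and `sliding_window_size` mirror the real code.

```python
# Toy illustration of the wrapper-dispatch pattern in this PR -- not the real backend.
class DummyWrapper:
    def __init__(self, name):
        self.name = name


class DummyLayer:
    def __init__(self, sliding_window_size):
        # -1 means full attention; any other value means sliding-window attention
        self.sliding_window_size = sliding_window_size


class ToyBackend:
    def __init__(self, sliding_window_size):
        # One wrapper normally, two when the model mixes sliding-window and full attention.
        self.num_wrappers = 2 if sliding_window_size is not None else 1
        self.decode_wrappers = [DummyWrapper(f"decode_{i}") for i in range(self.num_wrappers)]

    def _get_wrapper_idx(self, layer):
        if self.num_wrappers == 1:
            return 0
        # bool is an int subclass: sliding-window layers -> 0, full-attention layers -> 1
        return layer.sliding_window_size == -1

    def forward_decode(self, layer):
        return self.decode_wrappers[self._get_wrapper_idx(layer)]


backend = ToyBackend(sliding_window_size=4096)
print(backend.forward_decode(DummyLayer(4096)).name)  # decode_0 (sliding-window wrapper)
print(backend.forward_decode(DummyLayer(-1)).name)    # decode_1 (full-attention wrapper)
```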
@@ -53,39 +53,44 @@ class FlashInferAttnBackend(AttentionBackend):
             device="cuda",
         )
 
-        if model_runner.sliding_window_size is None:
-            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                self.workspace_buffer, "NHD"
-            )
-            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.workspace_buffer, "NHD"
-            )
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.workspace_buffer,
-                "NHD",
-                use_tensor_cores=self.decode_use_tensor_cores,
-            )
+        if model_runner.sliding_window_size is not None:
+            self.num_wrappers = 2
         else:
-            # Two wrappers: one for sliding window attention and one for full attention.
-            # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
-            self.prefill_wrapper_ragged = None
-            self.prefill_wrapper_paged = []
-            self.decode_wrapper = []
-            for _ in range(2):
-                self.prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-                )
-                self.decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                    )
-                )
+            self.num_wrappers = 1
+
+        # NOTE: we do not use ragged attention when there are multiple wrappers
+        self.prefill_wrapper_ragged = (
+            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
+            if self.num_wrappers == 1
+            else None
+        )
+
+        # Two wrappers: one for sliding window attention and one for full attention.
+        # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
+        self.prefill_wrappers_paged = []
+        self.decode_wrappers = []
+        for _ in range(self.num_wrappers):
+            self.prefill_wrappers_paged.append(
+                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+            )
+            self.decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                    use_tensor_cores=self.decode_use_tensor_cores,
+                )
+            )
 
         self.forward_metadata = None
         self.cuda_graph_metadata = {}
 
+    def _get_wrapper_idx(self, layer: nn.Module):
+        if self.num_wrappers == 1:
+            return 0
+
+        # TODO: make sure the idx is related to sliding window size
+        return layer.sliding_window_size == -1
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
             prefix_lens = None
@@ -99,7 +104,7 @@ class FlashInferAttnBackend(AttentionBackend):
             use_ragged = False
             if (
                 torch.sum(forward_batch.seq_lens).item() >= 4096
-                and self.model_runner.sliding_window_size is None
+                and self.num_wrappers == 1
             ):
                 use_ragged = True
@@ -119,7 +124,7 @@ class FlashInferAttnBackend(AttentionBackend):
             use_ragged,
             extend_no_prefix,
             total_num_tokens,
-            self.decode_wrapper,
+            self.decode_wrappers,
         )
 
     def init_cuda_graph_state(self, max_bs: int):
@@ -135,45 +140,30 @@ class FlashInferAttnBackend(AttentionBackend):
             (max_bs,), dtype=torch.int32, device="cuda"
         )
 
-        if self.model_runner.sliding_window_size is not None:
-            self.cuda_graph_kv_indptr = [
-                self.cuda_graph_kv_indptr,
-                self.cuda_graph_kv_indptr.clone(),
-            ]
-            self.cuda_graph_kv_indices = [
-                self.cuda_graph_kv_indices,
-                self.cuda_graph_kv_indices.clone(),
-            ]
+        # NOTE: the buffers are always in the form of list
+        self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
+            self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
+        ]
+        self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
+            self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
+        ]
 
     def init_forward_metadata_capture_cuda_graph(
         self, bs: int, req_pool_indices, seq_lens
     ):
-        if self.model_runner.sliding_window_size is None:
-            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.workspace_buffer,
-                "NHD",
-                use_cuda_graph=True,
-                use_tensor_cores=self.decode_use_tensor_cores,
-                paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
-                paged_kv_indices_buffer=self.cuda_graph_kv_indices,
-                paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
-            )
-        else:
-            decode_wrapper = []
-            for i in range(2):
-                decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                        paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
-                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
-                            :bs
-                        ],
-                    )
-                )
+        decode_wrappers = []
+        for i in range(self.num_wrappers):
+            decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
+                    use_cuda_graph=True,
+                    use_tensor_cores=self.decode_use_tensor_cores,
+                    paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
+                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                    paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
+                )
+            )
 
         update_flashinfer_indices(
             ForwardMode.DECODE,
@@ -181,12 +171,12 @@ class FlashInferAttnBackend(AttentionBackend):
             req_pool_indices,
             seq_lens,
             None,
-            decode_wrapper,
+            decode_wrappers,
         )
 
-        self.cuda_graph_metadata[bs] = decode_wrapper
-        self.forward_metadata = (False, False, None, decode_wrapper)
+        self.cuda_graph_metadata[bs] = decode_wrappers
+        self.forward_metadata = (False, False, None, decode_wrappers)
 
     def init_forward_metadata_replay_cuda_graph(
         self, bs: int, req_pool_indices, seq_lens
@@ -204,17 +194,11 @@ class FlashInferAttnBackend(AttentionBackend):
         return 0
 
     def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        if not isinstance(self.prefill_wrapper_paged, list):
-            prefill_wrapper_paged = self.prefill_wrapper_paged
-        else:
-            if layer.sliding_window_size != -1:
-                prefill_wrapper_paged = self.prefill_wrapper_paged[0]
-            else:
-                prefill_wrapper_paged = self.prefill_wrapper_paged[1]
+        prefill_wrapper_paged = self.prefill_wrappers_paged[
+            self._get_wrapper_idx(layer)
+        ]
 
-        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
-            self.forward_metadata
-        )
+        use_ragged, extend_no_prefix, _, _ = self.forward_metadata
 
         if not use_ragged:
             if k is not None:
@@ -260,15 +244,7 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
-            self.forward_metadata
-        )
-
-        if isinstance(decode_wrapper, list):
-            if layer.sliding_window_size != -1:
-                decode_wrapper = decode_wrapper[0]
-            else:
-                decode_wrapper = decode_wrapper[1]
+        decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]
 
         if k is not None:
             assert v is not None
...
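One detail worth calling out in the new `_get_wrapper_idx` above: `return layer.sliding_window_size == -1` returns a `bool` that is later used directly as a list index. Because `bool` is a subclass of `int`, sliding-window layers select index 0 and full-attention layers select index 1, matching the explicit `[0]`/`[1]` dispatch the old `forward_extend` and `forward_decode` code performed; `forward_decode` now reads the wrapper list as the last element of `forward_metadata` (`self.forward_metadata[-1]`). A quick, self-contained demonstration of the indexing (the string labels are placeholders, not real wrapper objects):

```python
# bool comparison results index a list like 0/1
decode_wrappers = ["sliding-window wrapper", "full-attention wrapper"]

for sliding_window_size in (4096, -1):
    idx = sliding_window_size == -1  # False -> 0, True -> 1
    print(sliding_window_size, "->", decode_wrappers[idx])
```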
@@ -47,7 +47,7 @@ class FlashinferUpdater:
         req_pool_indices,
         seq_lens,
         prefix_lens,
-        decode_wrapper=None,
+        decode_wrappers=None,
         use_ragged=False,
     ):
         self.forward_mode = forward_mode
@@ -66,14 +66,14 @@ class FlashinferUpdater:
         self.head_dim = model_runner.model_config.head_dim
         self.batch_size = len(req_pool_indices)
 
-        self.decode_wrapper = (
-            decode_wrapper or self.model_runner.attn_backend.decode_wrapper
+        self.decode_wrappers = (
+            decode_wrappers or self.model_runner.attn_backend.decode_wrappers
         )
         self.prefill_wrapper_ragged = (
             self.model_runner.attn_backend.prefill_wrapper_ragged
         )
-        self.prefill_wrapper_paged = (
-            self.model_runner.attn_backend.prefill_wrapper_paged
+        self.prefill_wrappers_paged = (
+            self.model_runner.attn_backend.prefill_wrappers_paged
         )
 
         self.kv_last_page_len = torch.ones(
@@ -142,6 +142,7 @@ class FlashinferUpdater:
         )
 
     def _update_decode_indices(self, decode_wrapper):
+        assert not isinstance(decode_wrapper, list)
         decode_wrapper.end_forward()
         decode_wrapper.begin_forward(
             self.kv_indptr,
@@ -156,6 +157,9 @@ class FlashinferUpdater:
         )
 
     def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
+        assert not isinstance(paged_wrapper, list)
+        assert not isinstance(ragged_wrapper, list)
+
         # extend part
         qo_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
@@ -189,11 +193,11 @@ class FlashinferUpdater:
         self._init_indices_no_sliding_window()
 
         if self.forward_mode.is_decode():
-            self._update_decode_indices(self.decode_wrapper)
+            self._update_decode_indices(self.decode_wrappers[0])
         else:
             self._update_extend_indices(
                 self.prefill_wrapper_ragged,
-                self.prefill_wrapper_paged,
+                self.prefill_wrappers_paged[0],
             )
 
     def update_indices_sliding_window(self):
@@ -202,11 +206,11 @@ class FlashinferUpdater:
         for wrapper_id in range(2):
             self._init_indices_sliding_window(wrapper_id)
 
             if self.forward_mode.is_decode():
-                self._update_decode_indices(self.decode_wrapper[wrapper_id])
+                self._update_decode_indices(self.decode_wrappers[wrapper_id])
             else:
                 self._update_extend_indices(
                     None,
-                    self.prefill_wrapper_paged[wrapper_id],
+                    self.prefill_wrappers_paged[wrapper_id],
                 )
@@ -216,7 +220,7 @@ def update_flashinfer_indices(
     req_pool_indices,
     seq_lens,
     prefix_lens,
-    decode_wrapper=None,
+    decode_wrappers=None,
     use_ragged=False,
 ):
     updater = FlashinferUpdater(
@@ -225,7 +229,7 @@ def update_flashinfer_indices(
         req_pool_indices,
         seq_lens,
         prefix_lens,
-        decode_wrapper,
+        decode_wrappers,
         use_ragged,
     )
...
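For the CUDA-graph path, the `@@ -135,45 +140,30 @@` hunk above keeps the originally allocated `cuda_graph_kv_indptr` / `cuda_graph_kv_indices` tensors as the first list entries and clones them once per additional wrapper, so every captured decode wrapper gets its own stable buffer. A generic sketch of that list-of-buffers pattern (plain PyTorch on CPU; the shapes and sizes are made up, nothing here is flashinfer-specific):

```python
import torch

num_wrappers = 2  # e.g. one sliding-window wrapper plus one full-attention wrapper
max_bs = 8        # made-up maximum capture batch size

# Allocate one buffer, then clone it for every additional wrapper, so the list
# always holds `num_wrappers` independent tensors (same shape, separate storage).
kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32)
kv_indptr = [kv_indptr] + [kv_indptr.clone() for _ in range(num_wrappers - 1)]

assert len(kv_indptr) == num_wrappers
assert kv_indptr[0].data_ptr() != kv_indptr[1].data_ptr()  # distinct buffer per wrapper

# Wrapper i is then constructed against its own slice, e.g. kv_indptr[i][: bs + 1],
# as init_forward_metadata_capture_cuda_graph does in the diff above.
```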