Unverified commit 14cb544d authored by Ying Sheng, committed by GitHub

[Fix] fix flashinfer usage for window attention (#1107)

parent e86b1ccb
@@ -120,12 +120,9 @@ class RadixAttention(nn.Module):
         # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
         prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
-        if self.sliding_window_size != -1:
-            prefill_wrapper_ragged = prefill_wrapper_ragged[0]
+        if self.sliding_window_size != -1 or self.reuse:
             prefill_wrapper_paged = prefill_wrapper_paged[0]
         else:
-            if isinstance(prefill_wrapper_ragged, list):
-                prefill_wrapper_ragged = prefill_wrapper_ragged[1]
             if isinstance(prefill_wrapper_paged, list):
                 prefill_wrapper_paged = prefill_wrapper_paged[1]
...
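A minimal sketch of the wrapper selection this hunk arrives at (the additional self.reuse condition is left out for brevity): sliding-window layers pick the first paged prefill wrapper, full-attention layers pick the second, and the ragged wrapper is no longer indexed because the model runner now keeps a single shared instance. The helper function below is illustrative, not code from the repository.

def pick_prefill_wrappers(input_metadata, sliding_window_size):
    # Illustrative helper (not repository code) restating the selection above.
    ragged = input_metadata.flashinfer_prefill_wrapper_ragged  # single shared wrapper after this commit
    paged = input_metadata.flashinfer_prefill_wrapper_paged    # list of two wrappers
    if sliding_window_size != -1:
        # Sliding-window layer: wrapper 0 is planned over the windowed sequence.
        paged = paged[0]
    elif isinstance(paged, list):
        # Full-attention layer: wrapper 1 covers the cached prefix; the ragged
        # wrapper handles the newly extended tokens.
        paged = paged[1]
    return ragged, paged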
@@ -324,9 +324,11 @@ def update_flashinfer_indices(
     else:
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

     for wrapper_id in range(2):
-        if flashinfer_use_ragged:
+        if flashinfer_use_ragged and wrapper_id == 1:
+            # full attention use ragged+paged
             paged_kernel_lens = prefix_lens
         else:
+            # window attention use paged only
             paged_kernel_lens = seq_lens

         if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
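Restating the two added comments as a hedged sketch: with the ragged kernel enabled, only wrapper 1 (full attention) limits the paged kernel to the cached prefix, while wrapper 0 (window attention) always runs the paged kernel over the whole sequence. The helper below is illustrative; the variable names come from the diff.

def pick_paged_kernel_lens(wrapper_id, flashinfer_use_ragged, prefix_lens, seq_lens):
    # Illustrative helper, not repository code.
    if flashinfer_use_ragged and wrapper_id == 1:
        # Full-attention wrapper: the ragged kernel attends over the newly
        # extended tokens, so the paged kernel only needs the cached prefix.
        return prefix_lens
    # Window-attention wrapper (wrapper_id == 0), or ragged kernel disabled:
    # the paged kernel covers the full sequence.
    return seq_lens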
@@ -374,13 +376,9 @@ def update_flashinfer_indices(
         )
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

-        if flashinfer_use_ragged:
-            model_runner.flashinfer_prefill_wrapper_ragged[
-                wrapper_id
-            ].end_forward()
-            model_runner.flashinfer_prefill_wrapper_ragged[
-                wrapper_id
-            ].begin_forward(
+        if flashinfer_use_ragged and wrapper_id == 1:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
                 qo_indptr,
                 qo_indptr,
                 num_qo_heads,
...
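A hedged sketch of the re-planning call on the now single ragged wrapper. Identifiers visible in the hunk are kept; num_kv_heads and head_dim are assumptions based on flashinfer's usual BatchPrefillWithRaggedKVCacheWrapper.begin_forward signature and may differ across flashinfer versions.

def replan_ragged_prefill(model_runner, qo_indptr, num_qo_heads, num_kv_heads, head_dim):
    # Illustrative helper (not repository code).
    ragged = model_runner.flashinfer_prefill_wrapper_ragged
    ragged.end_forward()
    ragged.begin_forward(
        qo_indptr,      # query index pointers for the extended tokens
        qo_indptr,      # for ragged extend, the KV indptr matches the query indptr
        num_qo_heads,
        num_kv_heads,   # assumed argument
        head_dim,       # assumed argument
    )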
@@ -342,15 +342,14 @@ class ModelRunner:
             dtype=torch.uint8,
             device="cuda",
         )
-        self.flashinfer_prefill_wrapper_ragged = []
+        self.flashinfer_prefill_wrapper_ragged = (
+            BatchPrefillWithRaggedKVCacheWrapper(
+                self.flashinfer_workspace_buffer, "NHD"
+            )
+        )
         self.flashinfer_prefill_wrapper_paged = []
         self.flashinfer_decode_wrapper = []
         for i in range(2):
-            self.flashinfer_prefill_wrapper_ragged.append(
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
             self.flashinfer_prefill_wrapper_paged.append(
                 BatchPrefillWithPagedKVCacheWrapper(
                     self.flashinfer_workspace_buffer, "NHD"
...
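For reference, a hedged sketch of the wrapper setup this hunk converges on: one shared ragged prefill wrapper plus two paged prefill wrappers (index 0 for window attention, index 1 for full attention). The workspace-buffer size is illustrative, and the decode wrappers, which the surrounding code also builds, are only stubbed with an empty list here because their construction is not visible in this hunk.

import torch
from flashinfer import (
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)

# Sketch only: mirrors the post-commit initialization, outside the ModelRunner class.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")
prefill_wrapper_paged = []
decode_wrapper = []  # decode wrappers are appended in the same loop (construction not shown here)
for _ in range(2):  # index 0: window attention, index 1: full attention
    prefill_wrapper_paged.append(
        BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
    )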