Unverified Commit 14cb544d authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

[Fix] fix flashinfer usage for window attention (#1107)

parent e86b1ccb
......@@ -120,12 +120,9 @@ class RadixAttention(nn.Module):
# using two wrappers is unnecessary in the current PR, but are prepared for future PRs
prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
if self.sliding_window_size != -1:
prefill_wrapper_ragged = prefill_wrapper_ragged[0]
if self.sliding_window_size != -1 or self.reuse:
prefill_wrapper_paged = prefill_wrapper_paged[0]
else:
if isinstance(prefill_wrapper_ragged, list):
prefill_wrapper_ragged = prefill_wrapper_ragged[1]
if isinstance(prefill_wrapper_paged, list):
prefill_wrapper_paged = prefill_wrapper_paged[1]
......
......@@ -324,9 +324,11 @@ def update_flashinfer_indices(
else:
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
for wrapper_id in range(2):
if flashinfer_use_ragged:
if flashinfer_use_ragged and wrapper_id == 1:
# full attention use ragged+paged
paged_kernel_lens = prefix_lens
else:
# window attention use paged only
paged_kernel_lens = seq_lens
if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
......@@ -374,13 +376,9 @@ def update_flashinfer_indices(
)
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
if flashinfer_use_ragged:
model_runner.flashinfer_prefill_wrapper_ragged[
wrapper_id
].end_forward()
model_runner.flashinfer_prefill_wrapper_ragged[
wrapper_id
].begin_forward(
if flashinfer_use_ragged and wrapper_id == 1:
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
qo_indptr,
qo_indptr,
num_qo_heads,
......
......@@ -342,15 +342,14 @@ class ModelRunner:
dtype=torch.uint8,
device="cuda",
)
self.flashinfer_prefill_wrapper_ragged = []
self.flashinfer_prefill_wrapper_ragged = (
BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_prefill_wrapper_paged = []
self.flashinfer_decode_wrapper = []
for i in range(2):
self.flashinfer_prefill_wrapper_ragged.append(
BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_prefill_wrapper_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment