Unverified Commit a523a3c1 authored by Mingyi, committed by GitHub

Reduce hardcoded logic of kernel usage (#707)

parent 9f94728f
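The change removes the hardcoded `total_num_tokens <= 4096` checks from the attention layer and the FlashInfer setup code and replaces them with a single `use_ragged` flag, computed once in `InputMetadata.create()` and then passed down to `init_flashinfer_args()` and `RadixAttention`. A minimal, self-contained sketch of that decision rule (using a stand-in `ForwardMode` enum rather than sglang's own class):

from enum import Enum, auto

import torch


class ForwardMode(Enum):
    # Stand-in for sglang's ForwardMode; only DECODE matters for this rule.
    EXTEND = auto()
    DECODE = auto()


def choose_use_ragged(forward_mode: ForwardMode, seq_lens: torch.Tensor) -> bool:
    # Mirrors the condition added in InputMetadata.create(): the ragged prefill
    # kernel is selected only for non-decode batches whose total token count
    # exceeds 4096; decode batches and small prefills keep the paged kernel.
    return forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096


# Example: an extend batch with 3000 + 2000 tokens crosses the 4096 threshold.
print(choose_use_ragged(ForwardMode.EXTEND, torch.tensor([3000, 2000])))  # True
print(choose_use_ragged(ForwardMode.DECODE, torch.tensor([3000, 2000])))  # False
print(choose_use_ragged(ForwardMode.EXTEND, torch.tensor([1024, 512])))   # False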
@@ -85,9 +85,9 @@ class RadixAttention(nn.Module):
         return o

     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
-
-        if input_metadata.total_num_tokens <= 4096:
+        if not input_metadata.use_ragged:
+            self.store_kv_cache(k, v, input_metadata)
+
             o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
@@ -122,6 +122,8 @@ class RadixAttention(nn.Module):
             o, _ = merge_state(o1, s1, o2, s2)

+            self.store_kv_cache(k, v, input_metadata)
+
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
             torch.cuda.synchronize()
...
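When `use_ragged` is set, the new tokens are attended with the ragged prefill wrapper while the paged wrapper covers only the cached prefix, and the two partial results are combined by `merge_state`; that is why `store_kv_cache` can move to after the merge. As a rough illustration of what such a merge computes, here is a pure-PyTorch sketch of the standard log-sum-exp combination over disjoint KV segments (an assumption about the semantics, not flashinfer's documented contract, and ignoring any internal scaling or log-base conventions):

import torch


def merge_state_reference(o1, s1, o2, s2):
    # o1, o2: partial attention outputs, shape [num_tokens, num_heads, head_dim]
    # s1, s2: per-token, per-head log-sum-exp of attention scores, [num_tokens, num_heads]
    # Combine the two partial softmax results as if both KV segments had been
    # attended jointly, in a numerically stable way.
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max).unsqueeze(-1)
    w2 = torch.exp(s2 - s_max).unsqueeze(-1)
    o = (o1 * w1 + o2 * w2) / (w1 + w2)
    s = s_max + torch.log((w1 + w2).squeeze(-1))
    return o, s


# Tiny smoke test with random partial states.
o1, s1 = torch.randn(4, 2, 8), torch.randn(4, 2)
o2, s2 = torch.randn(4, 2, 8), torch.randn(4, 2)
o, s = merge_state_reference(o1, s1, o2, s2)
print(o.shape, s.shape)  # torch.Size([4, 2, 8]) torch.Size([4, 2])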
@@ -726,6 +726,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    use_ragged: bool = False

     @classmethod
     def create(
@@ -741,7 +742,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
+        use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -749,6 +753,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
+                use_ragged,
             )

         batch_size = len(req_pool_indices)
@@ -803,6 +808,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            use_ragged=use_ragged,
         )

         if model_runner.server_args.disable_flashinfer:
@@ -823,6 +829,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
+    use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -831,10 +838,10 @@ def init_flashinfer_args(
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))

-    if forward_mode == ForwardMode.DECODE or total_num_tokens <= 4096:
-        paged_kernel_lens = seq_lens
-    else:
+    if use_ragged:
         paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens

     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -867,14 +874,15 @@ def init_flashinfer_args(
         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

-        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-            qo_indptr,
-            qo_indptr,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-        )
+        if use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )

         # cached part
         model_runner.flashinfer_prefill_wrapper_paged.end_forward()
...
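The `paged_kernel_lens` switch is the core of the change on this side: with `use_ragged`, the paged prefill wrapper only needs to index the cached prefix of each request, since the ragged wrapper handles the newly extended tokens. A small CPU-only sketch of the resulting `kv_indptr` layout, using the same cumulative-sum construction as `init_flashinfer_args` and made-up example lengths:

import torch


def build_kv_indptr(paged_kernel_lens: torch.Tensor) -> torch.Tensor:
    # Same construction as in init_flashinfer_args, but on CPU for illustration
    # (the real code allocates the tensor on "cuda").
    batch_size = paged_kernel_lens.numel()
    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    return kv_indptr


# Hypothetical 2-request extend batch: 3000/2500 total tokens, 1000/400 of them cached.
seq_lens = torch.tensor([3000, 2500], dtype=torch.int32)
prefix_lens = torch.tensor([1000, 400], dtype=torch.int32)

print(build_kv_indptr(prefix_lens).tolist())  # use_ragged=True  -> [0, 1000, 1400]
print(build_kv_indptr(seq_lens).tolist())     # use_ragged=False -> [0, 3000, 5500]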