Unverified commit af4e7910, authored by Lianmin Zheng and committed by GitHub

Clean up the usage of flashinfer (#610)

parent 519e20cf
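At a high level, this commit collapses the separate "prefill" and "extend" attention entry points into a single extend/decode pair, normalizes logit_cap handling in one place, drops the other_kv_index workaround from the Triton decode kernel, and lets InputMetadata.create pull the flashinfer wrappers and tp_size from model_runner instead of taking them as arguments. A minimal sketch of the resulting dispatch, with hypothetical placeholder backend methods so it is self-contained (the real implementations live in RadixAttention and call flashinfer or the Triton kernels shown below):

import torch
from torch import nn

# Stand-in for SGLang's global server-args dict; the name comes from the diff,
# the contents here are illustrative only.
global_server_args_dict = {"disable_flashinfer": False}


class RadixAttentionSketch(nn.Module):
    """Sketch of the post-cleanup dispatch: one extend/decode pair per
    backend and a single normalized logit_cap (0 means "no cap")."""

    def __init__(self, logit_cap=None):
        super().__init__()
        if not global_server_args_dict.get("disable_flashinfer", False):
            self.extend_forward = self.extend_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        else:
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton
        # Same normalization for both backends after this commit.
        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

    # Placeholder backends so the sketch constructs and runs.
    def extend_forward_flashinfer(self, q, k, v, input_metadata):
        raise NotImplementedError

    def decode_forward_flashinfer(self, q, k, v, input_metadata):
        raise NotImplementedError

    def extend_forward_triton(self, q, k, v, input_metadata):
        raise NotImplementedError

    def decode_forward_triton(self, q, k, v, input_metadata):
        raise NotImplementedError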
@@ -31,21 +31,13 @@ class RadixAttention(nn.Module):
         self.layer_id = layer_id
 
         if not global_server_args_dict.get("disable_flashinfer", False):
-            self.prefill_forward = self.prefill_forward_flashinfer
-            self.extend_forward = self.prefill_forward_flashinfer
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-            # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap if logit_cap is not None else 0
 
-    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
-        # See the extend_forward_xxx functions.
-        raise NotImplementedError()
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -86,7 +78,6 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
-            input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
@@ -94,7 +85,7 @@ class RadixAttention(nn.Module):
         return o
 
-    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
......
@@ -107,7 +107,6 @@ def _fwd_kernel_stage2(
     stride_obs,
     stride_oh,
     stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -138,7 +137,7 @@ def _fwd_kernel_stage2(
             + cur_batch_req_idx * stride_req_to_token_b
             + (start_n + offs_n),
             mask=(start_n + offs_n) < cur_batch_seq_len,
-            other=other_kv_index,
+            other=0,
         )
         qk = tl.load(
@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
-    other_kv_index,
 ):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
             o.stride(0),
             o.stride(1),
             req_to_tokens.stride(0),
-            other_kv_index,
         )
         return
@@ -295,7 +292,6 @@ def _token_softmax_reducev_fwd(
             o.stride(0),
             o.stride(1),
             req_to_tokens.stride(0),
-            other_kv_index,
             kv_group_num=kv_group_num,
             BLOCK_DMODEL=v_buffer.shape[-1],
             BLOCK_N=BLOCK,
@@ -315,7 +311,6 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
-    other_kv_index,
     total_num_tokens,
     sm_scale=None,
     logit_cap=-1,
@@ -347,5 +342,4 @@ def token_attention_fwd(
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        other_kv_index,
     )
@@ -729,7 +729,6 @@ class InputMetadata:
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
-    other_kv_index: torch.Tensor = None
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
@@ -743,24 +742,19 @@ class InputMetadata:
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
     def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if (
-            self.forward_mode == ForwardMode.EXTEND
-        ):
+        if self.forward_mode == ForwardMode.DECODE:
+            paged_kernel_lens = self.seq_lens
+        else:
             paged_kernel_lens = self.prefix_lens
             self.no_prefix = torch.all(self.prefix_lens == 0)
-        else:
-            paged_kernel_lens = self.seq_lens
 
-        self.kv_indptr = torch.zeros(
+        kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        self.kv_indices = torch.cat(
+        kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
                     req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
@@ -769,18 +763,34 @@ class InputMetadata:
             ],
             dim=0,
         ).contiguous()
+        kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
 
-        if self.forward_mode == ForwardMode.EXTEND:
+        if self.forward_mode == ForwardMode.DECODE:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
+            )
+        else:
             # extend part
-            self.qo_indptr = torch.zeros(
+            qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
+            qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
 
             self.flashinfer_prefill_wrapper_ragged.end_forward()
             self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                self.qo_indptr,
-                self.qo_indptr.clone(),
+                qo_indptr,
+                qo_indptr,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
@@ -789,28 +799,15 @@ class InputMetadata:
             # cached part
             self.flashinfer_prefill_wrapper_paged.end_forward()
             self.flashinfer_prefill_wrapper_paged.begin_forward(
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
                 1,
             )
-        else:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
 
     def init_extend_args(self):
         self.extend_seq_lens = self.seq_lens - self.prefix_lens
@@ -822,7 +819,6 @@ class InputMetadata:
     def create(
         cls,
         model_runner,
-        tp_size,
         forward_mode,
         req_pool_indices,
         seq_lens,
@@ -833,9 +829,6 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
-        flashinfer_prefill_wrapper_ragged=None,
-        flashinfer_prefill_wrapper_paged=None,
-        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -845,9 +838,6 @@ class InputMetadata:
         if forward_mode == ForwardMode.DECODE:
             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
         else:
             seq_lens_cpu = seq_lens.cpu().numpy()
             prefix_lens_cpu = prefix_lens.cpu().numpy()
@@ -865,7 +855,6 @@ class InputMetadata:
                 ),
                 device="cuda",
             )
-            other_kv_index = None
 
         ret = cls(
             forward_mode=forward_mode,
@@ -882,12 +871,11 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
         )
 
         if forward_mode == ForwardMode.EXTEND:
@@ -895,8 +883,8 @@ class InputMetadata:
         if not global_server_args_dict.get("disable_flashinfer", False):
             ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.num_attention_heads // model_runner.tp_size,
+                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
                 model_runner.model_config.head_dim,
             )
......
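The reworked init_flashinfer_args above builds flashinfer's paged-KV metadata with a page size of 1. A small CPU-only sketch of that bookkeeping, using made-up example lengths and a fake token-slot table in place of req_to_token_pool (the real code keeps these tensors on CUDA):

import torch

# Example per-request KV lengths (prefix lengths for extend, full seq lengths for decode).
paged_kernel_lens = torch.tensor([5, 2, 7], dtype=torch.int32)
batch_size = paged_kernel_lens.numel()

# Request i owns kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
print(kv_indptr.tolist())  # [0, 5, 7, 14]

# With page size 1, every request's last "page" holds exactly one token.
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32)

# kv_indices concatenates each request's token slots; here a fake 4x16 table
# stands in for req_to_token_pool.req_to_token.
fake_req_to_token = torch.arange(64).reshape(4, 16)
kv_indices = torch.cat(
    [fake_req_to_token[i, : paged_kernel_lens[i]] for i in range(batch_size)],
    dim=0,
).contiguous()
print(kv_indices.shape)  # torch.Size([14])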
@@ -221,7 +221,6 @@ class ModelRunner:
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
@@ -229,9 +228,6 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -242,7 +238,6 @@ class ModelRunner:
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.DECODE,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
@@ -252,9 +247,6 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
......
@@ -53,6 +53,7 @@ class ServerArgs:
     disable_flashinfer: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
+    disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
@@ -294,6 +295,11 @@ class ServerArgs:
             action="store_true",
             help="Disable regex jump-forward",
         )
+        parser.add_argument(
+            "--disable-cuda-graph",
+            action="store_true",
+            help="Disable cuda graph.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
......
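The last hunk adds a --disable-cuda-graph server flag; its consumer is not part of this diff. A minimal sketch of how such a store_true flag maps onto the ServerArgs field, using only the names that appear above (the dataclass/parser wiring around them is assumed):

import argparse
from dataclasses import dataclass


@dataclass
class ServerArgs:
    disable_flashinfer: bool = False
    disable_cuda_graph: bool = False  # new field added by this commit


parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-cuda-graph",
    action="store_true",
    help="Disable cuda graph.",
)

# Passing the flag flips the default False to True.
args = ServerArgs(**vars(parser.parse_args(["--disable-cuda-graph"])))
print(args.disable_cuda_graph)  # True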