Unverified commit af4e7910, authored by Lianmin Zheng, committed by GitHub

Clean up the usage of flashinfer (#610)

parent 519e20cf
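
At a glance, the commit collapses the former prefill/extend/decode triple into a single extend/decode pair chosen once at construction time, and drops the prefill_forward and other_kv_index plumbing. Below is a minimal, self-contained sketch of that dispatch pattern; the names mirror the new side of the first hunk, but the stand-in global dict and the placeholder backend methods are illustrations, not sglang code.

    # Sketch of the dispatch pattern this commit settles on (not the real class).
    global_server_args_dict = {"disable_flashinfer": False}  # stand-in for sglang's global dict

    class RadixAttentionSketch:
        def __init__(self, logit_cap=None):
            # Backend selection happens once, in __init__.
            if not global_server_args_dict.get("disable_flashinfer", False):
                self.extend_forward = self.extend_forward_flashinfer
                self.decode_forward = self.decode_forward_flashinfer
            else:
                self.extend_forward = self.extend_forward_triton
                self.decode_forward = self.decode_forward_triton
            # logit_cap is normalized once, identically for both backends.
            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

        # Placeholder backends; the real ones take (q, k, v, input_metadata).
        def extend_forward_flashinfer(self, *args): return "extend/flashinfer"
        def decode_forward_flashinfer(self, *args): return "decode/flashinfer"
        def extend_forward_triton(self, *args): return "extend/triton"
        def decode_forward_triton(self, *args): return "decode/triton"

    print(RadixAttentionSketch().extend_forward())  # -> "extend/flashinfer"
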
@@ -31,21 +31,13 @@ class RadixAttention(nn.Module):
         self.layer_id = layer_id

         if not global_server_args_dict.get("disable_flashinfer", False):
-            self.prefill_forward = self.prefill_forward_flashinfer
-            self.extend_forward = self.prefill_forward_flashinfer
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-            # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap if logit_cap is not None else 0

-    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
-        # See the extend_forward_xxx functions.
-        raise NotImplementedError()
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -86,7 +78,6 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
-            input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
@@ -94,7 +85,7 @@ class RadixAttention(nn.Module):
         return o

-    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
...
@@ -107,7 +107,6 @@ def _fwd_kernel_stage2(
    stride_obs,
    stride_oh,
    stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
@@ -138,7 +137,7 @@ def _fwd_kernel_stage2(
            + cur_batch_req_idx * stride_req_to_token_b
            + (start_n + offs_n),
            mask=(start_n + offs_n) < cur_batch_seq_len,
-            other=other_kv_index,
+            other=0,
        )
        qk = tl.load(
@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
    b_req_idx,
    b_start_loc,
    b_seq_len,
-    other_kv_index,
):
    BLOCK = 64
    batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
            o.stride(0),
            o.stride(1),
            req_to_tokens.stride(0),
-            other_kv_index,
        )
        return
@@ -295,7 +292,6 @@ def _token_softmax_reducev_fwd(
        o.stride(0),
        o.stride(1),
        req_to_tokens.stride(0),
-        other_kv_index,
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=v_buffer.shape[-1],
        BLOCK_N=BLOCK,
@@ -315,7 +311,6 @@ def token_attention_fwd(
    b_start_loc,
    b_seq_len,
    max_len_in_batch,
-    other_kv_index,
    total_num_tokens,
    sm_scale=None,
    logit_cap=-1,
@@ -347,5 +342,4 @@ def token_attention_fwd(
        b_req_idx,
        b_start_loc,
        b_seq_len,
-        other_kv_index,
    )
...
@@ -729,7 +729,6 @@ class InputMetadata:
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None
-    other_kv_index: torch.Tensor = None
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None
@@ -743,24 +742,19 @@ class InputMetadata:
    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if (
-            self.forward_mode == ForwardMode.EXTEND
-        ):
+        if self.forward_mode == ForwardMode.DECODE:
+            paged_kernel_lens = self.seq_lens
+        else:
            paged_kernel_lens = self.prefix_lens
            self.no_prefix = torch.all(self.prefix_lens == 0)
-        else:
-            paged_kernel_lens = self.seq_lens

-        self.kv_indptr = torch.zeros(
+        kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        self.kv_indices = torch.cat(
+        kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
@@ -769,18 +763,34 @@ class InputMetadata:
            ],
            dim=0,
        ).contiguous()
+        kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )

-        if self.forward_mode == ForwardMode.EXTEND:
+        if self.forward_mode == ForwardMode.DECODE:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
+            )
+        else:
            # extend part
-            self.qo_indptr = torch.zeros(
+            qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
+            qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)

            self.flashinfer_prefill_wrapper_ragged.end_forward()
            self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                self.qo_indptr,
-                self.qo_indptr.clone(),
+                qo_indptr,
+                qo_indptr,
                num_qo_heads,
                num_kv_heads,
                head_dim,
@@ -789,28 +799,15 @@ class InputMetadata:
            # cached part
            self.flashinfer_prefill_wrapper_paged.end_forward()
            self.flashinfer_prefill_wrapper_paged.begin_forward(
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
                num_qo_heads,
                num_kv_heads,
                head_dim,
                1,
            )
-        else:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )

    def init_extend_args(self):
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
@@ -822,7 +819,6 @@ class InputMetadata:
    def create(
        cls,
        model_runner,
-        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
@@ -833,9 +829,6 @@ class InputMetadata:
        out_cache_cont_end=None,
        top_logprobs_nums=None,
        return_logprob=False,
-        flashinfer_prefill_wrapper_ragged=None,
-        flashinfer_prefill_wrapper_paged=None,
-        flashinfer_decode_wrapper=None,
    ):
        batch_size = len(req_pool_indices)
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -845,9 +838,6 @@ class InputMetadata:
        if forward_mode == ForwardMode.DECODE:
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
        else:
            seq_lens_cpu = seq_lens.cpu().numpy()
            prefix_lens_cpu = prefix_lens.cpu().numpy()
@@ -865,7 +855,6 @@ class InputMetadata:
                ),
                device="cuda",
            )
-            other_kv_index = None

        ret = cls(
            forward_mode=forward_mode,
@@ -882,12 +871,11 @@ class InputMetadata:
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
-            other_kv_index=other_kv_index,
            return_logprob=return_logprob,
            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
        )

        if forward_mode == ForwardMode.EXTEND:
@@ -895,8 +883,8 @@ class InputMetadata:
        if not global_server_args_dict.get("disable_flashinfer", False):
            ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.num_attention_heads // model_runner.tp_size,
+                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
                model_runner.model_config.head_dim,
            )
...
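
As an aside on the init_flashinfer_args hunk above: kv_indptr is simply a prefix sum over the per-request KV lengths, so request i's slots live at kv_indices[kv_indptr[i]:kv_indptr[i+1]]. A tiny standalone illustration of that construction (the lengths below are made up):

    import torch

    # Made-up per-request KV lengths (paged_kernel_lens in the diff above).
    paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
    batch_size = paged_kernel_lens.numel()

    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    print(kv_indptr.tolist())  # [0, 3, 8, 10]; request i owns kv_indices[kv_indptr[i]:kv_indptr[i+1]]
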
@@ -221,7 +221,6 @@ class ModelRunner:
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
@@ -229,9 +228,6 @@ class ModelRunner:
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
@@ -242,7 +238,6 @@ class ModelRunner:
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.DECODE,
-            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
@@ -252,9 +247,6 @@ class ModelRunner:
            out_cache_cont_end=batch.out_cache_cont_end,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
...
@@ -53,6 +53,7 @@ class ServerArgs:
    disable_flashinfer: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
+    disable_cuda_graph: bool = False
    disable_disk_cache: bool = False
    attention_reduce_in_fp32: bool = False
    enable_p2p_check: bool = False
@@ -294,6 +295,11 @@ class ServerArgs:
            action="store_true",
            help="Disable regex jump-forward",
        )
+        parser.add_argument(
+            "--disable-cuda-graph",
+            action="store_true",
+            help="Disable cuda graph.",
+        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
...
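
The last hunk adds a disable_cuda_graph switch alongside the existing disable_* flags. A hedged usage sketch follows; only the flag and field name come from this diff, while the import path, the launch entry point, and the model path are assumptions about sglang's layout at the time.

    # Sketch: toggling the new flag programmatically (import path is an assumption).
    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="meta-llama/Llama-2-7b-hf",  # placeholder model
        disable_cuda_graph=True,                # new field introduced in the hunk above
    )
    print(args.disable_cuda_graph)  # True

    # Assumed CLI equivalent (flag name taken from the parser.add_argument call above):
    #   python -m sglang.launch_server --model-path <model> --disable-cuda-graph
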