Unverified Commit 26c34941 authored by Cody Yu, committed by GitHub

[Submodule] Change FlashInfer to import (#156)

parent cb8e1982
[submodule "3rdparty/flashinfer"]
path = 3rdparty/flashinfer
url = https://github.com/flashinfer-ai/flashinfer.git
Subproject commit 88b9496e1a726ddb353eb42887cfc0ab32c99460
@@ -5,13 +5,15 @@ It can be used in SGLang runtime to accelerate attention computation.
 ### Install flashinfer
-Note: The compilation can take a very long time.
+You can install flashinfer via pip as follows for CUDA 12.1.
 ```bash
-git submodule update --init --recursive
-pip install 3rdparty/flashinfer/python
+pip install flashinfer -i https://flashinfer.ai/whl/cu121/
 ```
+You can look for other CUDA versions at https://github.com/flashinfer-ai/flashinfer?tab=readme-ov-file#installation. If there is no desired version for your environment,
+please build it from source (the compilation takes a long time).
 
 ### Run a Server With Flashinfer Mode
 Add `--model-mode flashinfer` argument to enable flashinfer when launching a server.
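With the submodule gone, a quick way to verify the pip wheel is to import the two wrapper classes this change relies on and construct them the same way the runtime does further down in this diff. This is only a hedged sanity-check sketch, not part of the commit; it assumes a CUDA-capable machine with the wheel installed, and it only mirrors calls that already appear in the diff.

```python
# Sanity check for the pip-installed flashinfer wheel (assumes CUDA is available).
# Mirrors the construction used later in this diff: a 32 MB int8 workspace buffer
# and wrappers created with the "NHD" paged-KV layout.
import torch
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
)

workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
print("flashinfer wrappers constructed:",
      type(prefill_wrapper).__name__, type(decode_wrapper).__name__)
```

If the import fails, double-check that the wheel index URL matches your CUDA version.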
@@ -98,12 +98,7 @@ class RadixAttention(nn.Module):
         o = input_metadata.prefill_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.qo_indptr,
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
-            input_metadata.kv_indptr,
-            input_metadata.kv_indices,
-            input_metadata.kv_last_page_len,
-            allow_fp16_qk_reduction=True,
         )
         return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -114,9 +109,6 @@ class RadixAttention(nn.Module):
         o = input_metadata.decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
-            input_metadata.kv_indptr,
-            input_metadata.kv_indices,
-            input_metadata.kv_last_page_len,
         )
         return o.view(-1, self.tp_q_head_num * self.head_dim)
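Why the `forward()` calls shrink: the page-table arguments (`qo_indptr`, `kv_indptr`, `kv_indices`, `kv_last_page_len`) and the `allow_fp16_qk_reduction` flag move out of the per-layer call because the new flashinfer wrappers receive that batch-level metadata once per step via `begin_forward()` (see the `InputMetadata` hunks below), leaving each layer to pass only the query and that layer's paged KV data. Below is a deliberately hypothetical, pure-Python mock of that calling pattern (none of the names in it are real flashinfer APIs).

```python
# Hypothetical mock (NOT the flashinfer API) of the "bind metadata once,
# call per layer" pattern the new code relies on.
class MockPagedAttentionWrapper:
    def begin_forward(self, kv_indptr, kv_indices, kv_last_page_len):
        # The batch's page-table description is stored on the wrapper once per step.
        self._page_table = (kv_indptr, kv_indices, kv_last_page_len)

    def forward(self, q, kv_data):
        # Per-layer calls therefore only pass the query and that layer's
        # paged KV data; the stored page table is reused.
        kv_indptr, kv_indices, kv_last_page_len = self._page_table
        return f"attn({q}, {kv_data}, pages={len(kv_indices)})"


wrapper = MockPagedAttentionWrapper()
wrapper.begin_forward(kv_indptr=[0, 2, 3], kv_indices=[5, 9, 1], kv_last_page_len=[4, 7])
for layer_id in range(2):
    print(wrapper.forward(q=f"q_layer{layer_id}", kv_data=f"kv_layer{layer_id}"))
```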
@@ -90,6 +90,11 @@ class InputMetadata:
     decode_wrapper = None
 
     def init_flashinfer_args(self, tp_size):
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+        )
+
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -107,11 +112,7 @@ class InputMetadata:
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
 
-        from flashinfer.ops import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
-
+        workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
@@ -120,19 +121,21 @@ class InputMetadata:
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
+            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
             self.prefill_wrapper.begin_forward(
                 self.qo_indptr,
-                self.batch_size,
+                self.kv_indptr,
+                self.kv_indices,
+                self.kv_last_page_len,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
             )
         else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
+            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
             self.decode_wrapper.begin_forward(
                 self.kv_indptr,
+                self.kv_indices,
                 self.kv_last_page_len,
-                self.batch_size,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
                 self.model_runner.model_config.head_dim,
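For readers new to flashinfer's paged-KV metadata: the `*_indptr` tensors wired into `begin_forward()` above are CSR-style offset arrays over a ragged batch, built with the same `cumsum` trick applied to `extend_seq_lens`. A small CPU-only illustration with made-up lengths follows (the real code builds these as int32 tensors on CUDA).

```python
# CSR-style "indptr" bookkeeping, as used for qo_indptr / kv_indptr above.
# Made-up lengths on CPU; the runtime does the same thing on the GPU.
import torch

extend_seq_lens = torch.tensor([3, 1, 4], dtype=torch.int32)  # tokens per request
batch_size = extend_seq_lens.numel()

qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)
print(qo_indptr)  # -> [0, 3, 4, 8]; request i owns rows qo_indptr[i]:qo_indptr[i+1]

# kv_indptr plays the same role for KV-cache pages: kv_indices lists each
# request's page ids back-to-back, and kv_last_page_len records how many
# slots of each request's final page are in use.
```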