Unverified commit d2f8bfb2, authored by Lianmin Zheng, committed by GitHub

Follow-up fixes for flashinfer 0.0.5 (#556)

parent b7e2f800
@@ -67,7 +67,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
-    def init_flashinfer_args(self, num_attention_heads, num_key_value_heads, head_dim):
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -102,8 +102,8 @@ class InputMetadata:
             self.kv_indptr,
             self.kv_indices,
             self.kv_last_page_len,
-            num_attention_heads,
-            num_key_value_heads,
+            num_qo_heads,
+            num_kv_heads,
             head_dim,
             1
         )
@@ -113,8 +113,8 @@ class InputMetadata:
             self.kv_indptr,
             self.kv_indices,
             self.kv_last_page_len,
-            num_attention_heads,
-            num_key_value_heads,
+            num_qo_heads,
+            num_kv_heads,
             head_dim,
             1,
             pos_encoding_mode="NONE",
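For context, the argument lists rewritten in the two hunks above feed the flashinfer wrappers' `begin_forward` calls; the receiver objects sit on elided lines just above each hunk. Below is a minimal, hypothetical sketch of the decode-side call with the renamed parameters. All tensor values are made-up single-page examples, and only calls that appear in this diff are used.

```python
# Hypothetical sketch of the decode-side begin_forward call configured
# above; tensor contents are made-up single-page examples.
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128  # example head config

workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

# One request holding a single KV page with one valid token in it.
kv_indptr = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
kv_indices = torch.tensor([0], dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([1], dtype=torch.int32, device="cuda")

wrapper.begin_forward(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    1,  # page size
    pos_encoding_mode="NONE",
)
```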
@@ -203,7 +203,7 @@ class InputMetadata:
         if global_server_args_dict.get("enable_flashinfer", False):
             ret.init_flashinfer_args(
                 model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.num_key_value_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
                 model_runner.model_config.head_dim
             )
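Switching from `num_key_value_heads // tp_size` to `get_num_kv_heads(tp_size)` matters for GQA/MQA checkpoints where the KV-head count can be smaller than the tensor-parallel size, so plain integer division would yield zero. The method body is not part of this diff; here is a hedged sketch of what such a helper plausibly does:

```python
# Hypothetical sketch of a get_num_kv_heads-style helper; the real method
# lives on sglang's ModelConfig and is not shown in this diff.
def get_num_kv_heads(total_num_kv_heads: int, tp_size: int) -> int:
    if tp_size >= total_num_kv_heads:
        # Fewer KV heads than ranks: each rank holds a replicated copy.
        return 1
    # Otherwise shard the KV heads evenly across ranks.
    return total_num_kv_heads // tp_size
```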
@@ -350,6 +350,15 @@ class ModelRunner:
                 BatchPrefillWithPagedKVCacheWrapper,
                 BatchDecodeWithPagedKVCacheWrapper,
             )
+            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+            if not _grouped_size_compiled_for_decode_kernels(
+                self.model_config.num_attention_heads // self.tp_size,
+                self.model_config.get_num_kv_heads(self.tp_size)):
+                use_tensor_cores = True
+            else:
+                use_tensor_cores = False
+
             workspace_buffer = torch.empty(
                 32 * 1024 * 1024, dtype=torch.int8, device="cuda"
             )
@@ -357,8 +366,10 @@ class ModelRunner:
                 workspace_buffer, "NHD"
             )
             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
             )
+        else:
+            self.flashinfer_prefill_wrapper = self.flashinfer_decode_wrapper = None
 
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
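Putting the ModelRunner changes together: flashinfer 0.0.5 ships precompiled grouped decode kernels only for certain qo/kv head ratios, so the commit probes `_grouped_size_compiled_for_decode_kernels` and falls back to the tensor-core decode path when the ratio is not covered. A condensed sketch of that decision (head counts are example values, as if already divided by the tensor-parallel size; only calls appearing in this diff are used):

```python
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

# Example head counts for a GQA model, already sharded by tp_size.
num_qo_heads, num_kv_heads = 32, 8

# Use the tensor-core decode path when no precompiled grouped decode
# kernel exists for this qo/kv head ratio.
use_tensor_cores = not _grouped_size_compiled_for_decode_kernels(
    num_qo_heads, num_kv_heads
)

workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
)
```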