Unverified Commit d2f8bfb2 authored by Lianmin Zheng, committed by GitHub

Follow-up fixes for flashinfer 0.0.5 (#556)

parent b7e2f800
@@ -67,7 +67,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
-    def init_flashinfer_args(self, num_attention_heads, num_key_value_heads, head_dim):
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -102,8 +102,8 @@ class InputMetadata:
             self.kv_indptr,
             self.kv_indices,
             self.kv_last_page_len,
-            num_attention_heads,
-            num_key_value_heads,
+            num_qo_heads,
+            num_kv_heads,
             head_dim,
             1
         )
@@ -113,8 +113,8 @@ class InputMetadata:
             self.kv_indptr,
             self.kv_indices,
             self.kv_last_page_len,
-            num_attention_heads,
-            num_key_value_heads,
+            num_qo_heads,
+            num_kv_heads,
             head_dim,
             1,
             pos_encoding_mode="NONE",
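The hunks above are a pure rename: the head-count arguments are now called num_qo_heads (query/output heads) and num_kv_heads, matching flashinfer 0.0.5's terminology for grouped-query attention. For reference, an annotated version of the decode begin_forward call shown above; the comments are interpretive and not part of the commit:

self.flashinfer_decode_wrapper.begin_forward(
    self.kv_indptr,         # int32 [batch_size + 1]: per-request offsets into kv_indices
    self.kv_indices,        # int32 [num_pages]: physical page ids in the paged KV cache
    self.kv_last_page_len,  # int32 [batch_size]: valid tokens in each request's last page
    num_qo_heads,           # query/output heads on this tensor-parallel rank
    num_kv_heads,           # key/value heads on this tensor-parallel rank
    head_dim,
    1,                      # page_size: sglang's token-level KV pool uses one token per page
    pos_encoding_mode="NONE",
)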
@@ -203,7 +203,7 @@ class InputMetadata:
         if global_server_args_dict.get("enable_flashinfer", False):
             ret.init_flashinfer_args(
                 model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.num_key_value_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
                 model_runner.model_config.head_dim
             )
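Replacing num_key_value_heads // tp_size with get_num_kv_heads(tp_size) matters for GQA/MQA checkpoints whose KV-head count is smaller than the tensor-parallel degree: plain integer division would report zero KV heads to flashinfer. A minimal sketch of what such a helper presumably does, assuming KV heads are replicated across ranks when they cannot be sharded (the actual implementation lives in sglang's ModelConfig and may differ):

def get_num_kv_heads(self, tp_size: int) -> int:
    # Hedged sketch, not the authoritative sglang code: shard KV heads across
    # tensor-parallel ranks, but never report fewer than one head per rank.
    return max(1, self.num_key_value_heads // tp_size)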
@@ -350,6 +350,15 @@ class ModelRunner:
                 BatchPrefillWithPagedKVCacheWrapper,
                 BatchDecodeWithPagedKVCacheWrapper,
             )
+            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+            if not _grouped_size_compiled_for_decode_kernels(
+                self.model_config.num_attention_heads // self.tp_size,
+                self.model_config.get_num_kv_heads(self.tp_size)):
+                use_tensor_cores = True
+            else:
+                use_tensor_cores = False
+
             workspace_buffer = torch.empty(
                 32 * 1024 * 1024, dtype=torch.int8, device="cuda"
             )
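Context for the new branch: flashinfer ships CUDA-core decode kernels only for a fixed set of query-to-KV head ratios (the GQA group size), and _grouped_size_compiled_for_decode_kernels reports whether the current ratio is covered. When it is not, the decode wrapper is asked to fall back to the tensor-core (prefill-style) kernels via use_tensor_cores=True. A self-contained approximation of that decision follows; the set of compiled group sizes is an assumption, the real check lives inside flashinfer:

def decode_kernel_available(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Hypothetical stand-in for flashinfer.decode._grouped_size_compiled_for_decode_kernels.
    compiled_group_sizes = (1, 2, 4, 8)  # assumed, not the authoritative set
    return (num_qo_heads // num_kv_heads) in compiled_group_sizes

# Mirrors the logic added above, e.g. for a 32-head model with 4 KV heads:
use_tensor_cores = not decode_kernel_available(num_qo_heads=32, num_kv_heads=4)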
@@ -357,8 +366,10 @@ class ModelRunner:
                 workspace_buffer, "NHD"
             )
             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
             )
+        else:
+            self.flashinfer_prefill_wrapper = self.flashinfer_decode_wrapper = None
 
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
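Putting the pieces together, wrapper construction now looks roughly like the standalone sketch below when enable_flashinfer is set. Shapes are illustrative, and the per-batch calls are shown as comments because they follow the flashinfer 0.0.5-era begin_forward/forward/end_forward API assumed here:

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128  # illustrative model shapes
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, "NHD", use_tensor_cores=True  # tensor-core fallback enabled
)

# Per-batch lifecycle, matching the arguments prepared by init_flashinfer_args above:
# decode_wrapper.begin_forward(kv_indptr, kv_indices, kv_last_page_len,
#                              num_qo_heads, num_kv_heads, head_dim, 1,
#                              pos_encoding_mode="NONE")
# out = decode_wrapper.forward(q, kv_data)  # q: [num_tokens, num_qo_heads, head_dim]
# decode_wrapper.end_forward()

When flashinfer is disabled, both wrappers are set to None so downstream code can branch on their presence.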