Unverified commit ad1dd746, authored by Qubitium, committed by GitHub
Browse files

Fix flashinfer >= 0.0.3 compat (#282)

parent eb4308c4
import importlib
import logging
import inspect
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
......@@ -124,14 +125,21 @@ class InputMetadata:
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
)
self.prefill_wrapper.begin_forward(
args = [
self.qo_indptr,
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
)
]
# flashinfer >= 0.0.3
# FIXME: Drop this when flashinfer updates to 0.0.4
if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
args.append(self.model_runner.model_config.head_dim)
self.prefill_wrapper.begin_forward(*args)
else:
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment