"tests/vscode:/vscode.git/clone" did not exist on "70ee4b5e2833b7d60509bca9f37ff3fee8cf271c"
Unverified commit ad1dd746, authored by Qubitium, committed by GitHub
Browse files

Fix flashinfer >= 0.0.3 compat (#282)

parent eb4308c4
import importlib import importlib
import logging import logging
import inspect
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
...@@ -124,14 +125,21 @@ class InputMetadata: ...@@ -124,14 +125,21 @@ class InputMetadata:
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD" workspace_buffer, "NHD"
) )
self.prefill_wrapper.begin_forward( args = [
self.qo_indptr, self.qo_indptr,
self.kv_indptr, self.kv_indptr,
self.kv_indices, self.kv_indices,
self.kv_last_page_len, self.kv_last_page_len,
self.model_runner.model_config.num_attention_heads // tp_size, self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size, self.model_runner.model_config.num_key_value_heads // tp_size,
) ]
# flashinfer >= 0.0.3
# FIXME: Drop this when flashinfer updates to 0.0.4
if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
args.append(self.model_runner.model_config.head_dim)
self.prefill_wrapper.begin_forward(*args)
else: else:
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD" workspace_buffer, "NHD"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment