Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ad1dd746
"tests/vscode:/vscode.git/clone" did not exist on "70ee4b5e2833b7d60509bca9f37ff3fee8cf271c"
Unverified
Commit
ad1dd746
authored
Mar 12, 2024
by
Qubitium
Committed by
GitHub
Mar 12, 2024
Browse files
Fix flashinfer >= 0.0.3 compat (#282)
parent
eb4308c4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
2 deletions
+10
-2
python/sglang/srt/managers/router/model_runner.py
python/sglang/srt/managers/router/model_runner.py
+10
-2
No files found.
python/sglang/srt/managers/router/model_runner.py
View file @
ad1dd746
import
importlib
import
importlib
import
logging
import
logging
import
inspect
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
functools
import
lru_cache
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -124,14 +125,21 @@ class InputMetadata:
...
@@ -124,14 +125,21 @@ class InputMetadata:
self
.
prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
workspace_buffer
,
"NHD"
)
)
self
.
prefill_wrapper
.
begin_forward
(
args
=
[
self
.
qo_indptr
,
self
.
qo_indptr
,
self
.
kv_indptr
,
self
.
kv_indptr
,
self
.
kv_indices
,
self
.
kv_indices
,
self
.
kv_last_page_len
,
self
.
kv_last_page_len
,
self
.
model_runner
.
model_config
.
num_attention_heads
//
tp_size
,
self
.
model_runner
.
model_config
.
num_attention_heads
//
tp_size
,
self
.
model_runner
.
model_config
.
num_key_value_heads
//
tp_size
,
self
.
model_runner
.
model_config
.
num_key_value_heads
//
tp_size
,
)
]
# flashinfer >= 0.0.3
# FIXME: Drop this when flashinfer updates to 0.0.4
if
len
(
inspect
.
signature
(
self
.
prefill_wrapper
.
begin_forward
).
parameters
)
==
7
:
args
.
append
(
self
.
model_runner
.
model_config
.
head_dim
)
self
.
prefill_wrapper
.
begin_forward
(
*
args
)
else
:
else
:
self
.
decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
workspace_buffer
,
"NHD"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment