sglang / Commits / d2f8bfb2

Unverified commit d2f8bfb2, authored Jun 20, 2024 by Lianmin Zheng, committed by GitHub on Jun 20, 2024
Parent: b7e2f800

Follow-up fixes for flashinfer 0.0.5 (#556)

Showing 1 changed file with 18 additions and 7 deletions:
python/sglang/srt/managers/controller/model_runner.py (+18, -7)
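In short, the follow-up aligns sglang with the flashinfer 0.0.5 API: `init_flashinfer_args` now takes flashinfer's `num_qo_heads`/`num_kv_heads` parameter names, the KV head count comes from `model_config.get_num_kv_heads(tp_size)` rather than dividing `num_key_value_heads` by `tp_size` at the call site, and the decode wrapper passes `use_tensor_cores=True` whenever flashinfer has no grouped-size decode kernel compiled for the model's head configuration. When flashinfer is disabled, both wrappers are now explicitly set to `None`.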
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -67,7 +67,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
-    def init_flashinfer_args(self, num_attention_heads, num_key_value_heads, head_dim):
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
@@ -102,8 +102,8 @@ class InputMetadata:
             self.kv_indptr,
             self.kv_indices,
             self.kv_last_page_len,
-            num_attention_heads,
-            num_key_value_heads,
+            num_qo_heads,
+            num_kv_heads,
             head_dim,
             1
         )
@@ -113,8 +113,8 @@ class InputMetadata:
             self.kv_indptr,
             self.kv_indices,
             self.kv_last_page_len,
-            num_attention_heads,
-            num_key_value_heads,
+            num_qo_heads,
+            num_kv_heads,
             head_dim,
             1,
             pos_encoding_mode="NONE",
@@ -203,7 +203,7 @@ class InputMetadata:
         if global_server_args_dict.get("enable_flashinfer", False):
             ret.init_flashinfer_args(
                 model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.num_key_value_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
                 model_runner.model_config.head_dim
             )
@@ -350,6 +350,15 @@ class ModelRunner:
                 BatchPrefillWithPagedKVCacheWrapper,
                 BatchDecodeWithPagedKVCacheWrapper,
             )
+            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+            if not _grouped_size_compiled_for_decode_kernels(
+                self.model_config.num_attention_heads // self.tp_size,
+                self.model_config.get_num_kv_heads(self.tp_size)):
+                use_tensor_cores = True
+            else:
+                use_tensor_cores = False
+
             workspace_buffer = torch.empty(
                 32 * 1024 * 1024, dtype=torch.int8, device="cuda"
             )
@@ -357,8 +366,10 @@ class ModelRunner:
                 workspace_buffer, "NHD"
             )
             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
             )
+        else:
+            self.flashinfer_prefill_wrapper = self.flashinfer_decode_wrapper = None
 
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
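For context, a minimal sketch of the wrapper setup this diff converges on, assuming flashinfer 0.0.5 on a CUDA machine; the head counts below are hypothetical stand-ins for `num_attention_heads // tp_size` and `get_num_kv_heads(tp_size)`:

```python
import torch
from flashinfer import (
    BatchPrefillWithPagedKVCacheWrapper,
    BatchDecodeWithPagedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

# Hypothetical head configuration, standing in for the per-tensor-parallel-rank
# values computed in the diff above.
num_qo_heads = 32
num_kv_heads = 8

# Use the tensor-core decode path when flashinfer has no grouped-size
# decode kernel precompiled for this (qo, kv) head pair.
use_tensor_cores = not _grouped_size_compiled_for_decode_kernels(
    num_qo_heads, num_kv_heads
)

# 32 MB scratch space shared by the prefill and decode wrappers.
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")

prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
)
```

The single boolean expression here encodes the same decision as the if/else in the diff: tensor-core decode is a fallback for grouped-query head ratios that flashinfer 0.0.5 did not precompile dedicated decode kernels for.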