Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
326df4ba
Unverified
Commit
326df4ba
authored
Aug 14, 2024
by
Lianmin Zheng
Committed by
GitHub
Aug 14, 2024
Browse files
Use a single workspace for flashinfer (#1077)
parent
6767e222
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
16 additions
and
18 deletions
+16
-18
benchmark/gsm8k/bench_sglang.py
benchmark/gsm8k/bench_sglang.py
+1
-1
python/sglang/global_config.py
python/sglang/global_config.py
+1
-1
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+6
-6
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+7
-9
python/sglang/srt/server.py
python/sglang/srt/server.py
+1
-1
No files found.
benchmark/gsm8k/bench_sglang.py
View file @
326df4ba
...
...
@@ -64,7 +64,7 @@ def main(args):
@
sgl
.
function
def
few_shot_gsm8k
(
s
,
question
):
s
+=
few_shot_examples
+
question
s
+=
sgl
.
gen
(
"answer"
,
max_tokens
=
512
,
stop
=
"Question"
)
s
+=
sgl
.
gen
(
"answer"
,
max_tokens
=
512
,
stop
=
[
"Question"
,
"Assistant:"
]
)
#####################################
########## SGL Program End ##########
...
...
python/sglang/global_config.py
View file @
326df4ba
...
...
@@ -27,7 +27,7 @@ class GlobalConfig:
# Runtime constants: others
self
.
num_continue_decode_steps
=
10
self
.
retract_decode_steps
=
20
self
.
flashinfer_workspace_size
=
192
*
1024
*
1024
self
.
flashinfer_workspace_size
=
384
*
1024
*
1024
# Output tokenization configs
self
.
skip_special_tokens_in_output
=
True
...
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
326df4ba
...
...
@@ -120,13 +120,13 @@ class CudaGraphRunner:
)
if
model_runner
.
sliding_window_size
is
None
:
self
.
flashinfer_workspace_buffer
=
(
self
.
model_runner
.
flashinfer_workspace_buffer
s
[
0
]
self
.
model_runner
.
flashinfer_workspace_buffer
)
else
:
self
.
flashinfer_workspace_buffer
s
=
[
self
.
model_runner
.
flashinfer_workspace_buffer
s
[
0
],
self
.
model_runner
.
flashinfer_workspace_buffers
[
2
],
]
self
.
flashinfer_workspace_buffer
=
(
self
.
model_runner
.
flashinfer_workspace_buffer
)
self
.
flashinfer_kv_indptr
=
[
self
.
flashinfer_kv_indptr
,
self
.
flashinfer_kv_indptr
.
clone
(),
...
...
@@ -200,7 +200,7 @@ class CudaGraphRunner:
for
i
in
range
(
2
):
flashinfer_decode_wrapper
.
append
(
BatchDecodeWithPagedKVCacheWrapper
(
self
.
flashinfer_workspace_buffer
s
[
i
]
,
self
.
flashinfer_workspace_buffer
,
"NHD"
,
use_cuda_graph
=
True
,
use_tensor_cores
=
use_tensor_cores
,
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
326df4ba
...
...
@@ -318,28 +318,26 @@ class ModelRunner:
use_tensor_cores
=
False
if
self
.
sliding_window_size
is
None
:
self
.
flashinfer_workspace_buffers
=
torch
.
empty
(
2
,
self
.
flashinfer_workspace_buffer
=
torch
.
empty
(
global_config
.
flashinfer_workspace_size
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
,
)
self
.
flashinfer_prefill_wrapper_ragged
=
(
BatchPrefillWithRaggedKVCacheWrapper
(
self
.
flashinfer_workspace_buffer
s
[
0
]
,
"NHD"
self
.
flashinfer_workspace_buffer
,
"NHD"
)
)
self
.
flashinfer_prefill_wrapper_paged
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
flashinfer_workspace_buffer
s
[
1
]
,
"NHD"
self
.
flashinfer_workspace_buffer
,
"NHD"
)
self
.
flashinfer_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
flashinfer_workspace_buffer
s
[
0
]
,
self
.
flashinfer_workspace_buffer
,
"NHD"
,
use_tensor_cores
=
use_tensor_cores
,
)
else
:
self
.
flashinfer_workspace_buffers
=
torch
.
empty
(
4
,
global_config
.
flashinfer_workspace_size
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
,
...
...
@@ -350,17 +348,17 @@ class ModelRunner:
for
i
in
range
(
2
):
self
.
flashinfer_prefill_wrapper_ragged
.
append
(
BatchPrefillWithRaggedKVCacheWrapper
(
self
.
flashinfer_workspace_buffer
s
[
2
*
i
+
0
]
,
"NHD"
self
.
flashinfer_workspace_buffer
,
"NHD"
)
)
self
.
flashinfer_prefill_wrapper_paged
.
append
(
BatchPrefillWithPagedKVCacheWrapper
(
self
.
flashinfer_workspace_buffer
s
[
2
*
i
+
1
]
,
"NHD"
self
.
flashinfer_workspace_buffer
,
"NHD"
)
)
self
.
flashinfer_decode_wrapper
.
append
(
BatchDecodeWithPagedKVCacheWrapper
(
self
.
flashinfer_workspace_buffer
s
[
2
*
i
+
0
]
,
self
.
flashinfer_workspace_buffer
,
"NHD"
,
use_tensor_cores
=
use_tensor_cores
,
)
...
...
python/sglang/srt/server.py
View file @
326df4ba
...
...
@@ -381,7 +381,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if
not
server_args
.
disable_flashinfer
:
assert_pkg_version
(
"flashinfer"
,
"0.1.
4
"
,
"0.1.
5
"
,
"Please uninstall the old version and "
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html."
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment