Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f4e885b7
Unverified
Commit
f4e885b7
authored
Jul 07, 2024
by
Mingyi
Committed by
GitHub
Jul 07, 2024
Browse files
Reduce number of workspaces (#601)
parent
0877f1e7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
6 deletions
+6
-6
python/sglang/srt/layers/radix_attention.py
python/sglang/srt/layers/radix_attention.py
+4
-4
python/sglang/srt/managers/controller/model_runner.py
python/sglang/srt/managers/controller/model_runner.py
+2
-2
No files found.
python/sglang/srt/layers/radix_attention.py
View file @
f4e885b7
...
...
@@ -4,6 +4,8 @@ import numpy as np
import
torch
from
torch
import
nn
from
flashinfer.cascade
import
merge_state
from
sglang.global_config
import
global_config
from
sglang.srt.layers.extend_attention
import
extend_attention_fwd
from
sglang.srt.layers.token_attention
import
token_attention_fwd
...
...
@@ -95,8 +97,6 @@ class RadixAttention(nn.Module):
return
o
def
prefill_forward_flashinfer
(
self
,
q
,
k
,
v
,
input_metadata
:
InputMetadata
):
self
.
store_kv_cache
(
k
,
v
,
input_metadata
)
o1
,
s1
=
input_metadata
.
flashinfer_prefill_wrapper_ragged
.
forward_return_lse
(
q
.
contiguous
().
view
(
-
1
,
self
.
tp_q_head_num
,
self
.
head_dim
),
k
.
contiguous
().
view
(
-
1
,
self
.
tp_k_head_num
,
self
.
head_dim
),
...
...
@@ -117,10 +117,10 @@ class RadixAttention(nn.Module):
logits_soft_cap
=
self
.
logit_cap
,
)
from
flashinfer.cascade
import
merge_state
o
,
_
=
merge_state
(
o1
,
s1
,
o2
,
s2
)
self
.
store_kv_cache
(
k
,
v
,
input_metadata
)
if
input_metadata
.
total_num_tokens
>=
global_config
.
layer_sync_threshold
:
torch
.
cuda
.
synchronize
()
...
...
python/sglang/srt/managers/controller/model_runner.py
View file @
f4e885b7
...
...
@@ -408,7 +408,7 @@ class ModelRunner:
use_tensor_cores
=
False
workspace_buffers
=
torch
.
empty
(
3
,
96
*
1024
*
1024
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
2
,
96
*
1024
*
1024
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
self
.
flashinfer_prefill_wrapper_ragged
=
(
BatchPrefillWithRaggedKVCacheWrapper
(
workspace_buffers
[
0
],
"NHD"
)
...
...
@@ -417,7 +417,7 @@ class ModelRunner:
workspace_buffers
[
1
],
"NHD"
)
self
.
flashinfer_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffers
[
2
],
"NHD"
,
use_tensor_cores
=
use_tensor_cores
workspace_buffers
[
0
],
"NHD"
,
use_tensor_cores
=
use_tensor_cores
)
else
:
self
.
flashinfer_prefill_wrapper_ragged
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment