change / sglang · Commits · 5949b1ca

Fix memory pool index error (#616)

Unverified commit 5949b1ca, authored Jul 13, 2024 by Ying Sheng, committed by GitHub on Jul 13, 2024.
Parent: 0feca02d

Showing 4 changed files with 9 additions and 11 deletions (+9, -11):

    python/sglang/srt/layers/radix_attention.py                   +2  -5
    python/sglang/srt/managers/controller/cuda_graph_runner.py    +3  -2
    python/sglang/srt/managers/controller/tp_worker.py            +2  -2
    python/sglang/srt/memory_pool.py                              +2  -2
python/sglang/srt/layers/radix_attention.py (view file @ 5949b1ca)

```diff
@@ -137,9 +137,6 @@ class RadixAttention(nn.Module):
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
         key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        key_buffer[input_metadata.out_cache_loc] = cache_k
         value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        if input_metadata.out_cache_loc is not None:
-            key_buffer[input_metadata.out_cache_loc] = cache_k
-            value_buffer[input_metadata.out_cache_loc] = cache_v
-        else:
-            raise RuntimeError()
+        value_buffer[input_metadata.out_cache_loc] = cache_v
```
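The `None` guard and the `RuntimeError` branch are gone: with the dummy slot reserved in `TokenToKVPool` (last file below), `out_cache_loc` is always a valid index tensor, so the key/value writes can be unconditional. A minimal sketch of the same scatter-write pattern, with every shape and name here invented for illustration:

```python
import torch

# Hypothetical sizes, for illustration only.
pool_size, head_num, head_dim = 8, 2, 4
# One extra row: index 0 is the reserved dummy slot (see memory_pool.py below).
key_buffer = torch.zeros((pool_size + 1, head_num, head_dim))

# One destination slot per input token; padded tokens all point at slot 0,
# so this scatter write is always in range and needs no None check.
out_cache_loc = torch.tensor([3, 4, 0, 0])      # last two entries are padding
cache_k = torch.randn((4, head_num, head_dim))  # new keys for this step
key_buffer[out_cache_loc] = cache_k             # same indexing as store_kv_cache
```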
python/sglang/srt/managers/controller/cuda_graph_runner.py (view file @ 5949b1ca)

```diff
@@ -132,7 +132,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.zero_()
             self.position_ids_offsets.fill_(1)
+            self.out_cache_loc.zero_()

         # Common inputs
@@ -168,4 +169,4 @@ class CudaGraphRunner:
             prefill_top_logprobs=None,
             decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
         )
-        return output
\ No newline at end of file
+        return output
```
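`bisect.bisect_left` rounds the live batch size up to the nearest captured graph size, and the rows beyond `raw_bs` are padding. The fix zeroes `seq_lens` and `out_cache_loc` up front (the real entries are presumably copied in afterwards under "Common inputs"), so padding rows end up with sequence length 0 and cache location 0, the pool's reserved dummy slot, instead of length 1 and whatever stale locations the buffer held. A small sketch of the bucket lookup, with the capture sizes assumed:

```python
import bisect

batch_size_list = [1, 2, 4, 8, 16]  # assumed capture sizes

def pick_graph_bs(raw_bs: int) -> int:
    # Smallest captured batch size that can hold the raw batch.
    index = bisect.bisect_left(batch_size_list, raw_bs)
    return batch_size_list[index]

assert pick_graph_bs(3) == 4  # one padded row; its seq_len / out_cache_loc are zeroed
assert pick_graph_bs(8) == 8  # exact match: no padding needed
```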
python/sglang/srt/managers/controller/tp_worker.py (view file @ 5949b1ca)

```diff
@@ -315,7 +315,7 @@ class ModelTpServer:
     def get_new_fill_batch(self) -> Optional[Batch]:
         running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
-        if running_bs > self.max_running_requests:
+        if running_bs >= self.max_running_requests:
             return

         # Compute matched prefix length
@@ -393,7 +393,7 @@ class ModelTpServer:
             else:
                 break

-        if running_bs + len(can_run_list) > self.max_running_requests:
+        if running_bs + len(can_run_list) >= self.max_running_requests:
             break

         if len(can_run_list) == 0:
```
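Both scheduler checks move from `>` to `>=`, closing an off-by-one: with `>`, a server already running exactly `max_running_requests` requests (or an admission list that had just reached the cap) would still let more work through. The boundary case, with the cap value assumed:

```python
max_running_requests = 16  # assumed cap, for illustration

running_bs = 16  # already at the cap
assert not (running_bs > max_running_requests)  # old check: falls through, over-admits
assert running_bs >= max_running_requests       # new check: returns / breaks early
```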
python/sglang/srt/memory_pool.py (view file @ 5949b1ca)

```diff
@@ -46,7 +46,7 @@ class TokenToKVPool:
         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
-            torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]
@@ -127,4 +127,4 @@ class TokenToKVPool:
         self.total_ref_ct = 0

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.add_refs(torch.tensor([0], dtype=torch.int32))
\ No newline at end of file
+        self.add_refs(torch.tensor([0], dtype=torch.int32))
```
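Each layer's KV buffer gains one row (`size + 1`), matching the existing comment in `clear()`: slot 0 carries a permanent reference, so the allocator presumably never hands it out, leaving it free to absorb dummy writes from padded tokens while `size` real slots remain usable. A toy, self-contained sketch of that reservation idea (the real `TokenToKVPool` keeps per-layer CUDA buffers and more bookkeeping):

```python
import torch

class ToyTokenPool:
    """Toy sketch of a ref-counted slot pool with a reserved dummy slot 0."""

    def __init__(self, size: int):
        self.size = size
        # size + 1 slots: index 0 is the extra dummy slot, real tokens use 1..size.
        self.mem_state = torch.zeros((size + 1,), dtype=torch.int32)
        self.clear()

    def clear(self):
        self.mem_state.zero_()
        # Pin slot 0 so alloc() never returns it; padded tokens write there.
        self.add_refs(torch.tensor([0], dtype=torch.int32))

    def add_refs(self, indices: torch.Tensor):
        self.mem_state[indices.long()] += 1

    def alloc(self, need: int) -> torch.Tensor:
        free = torch.nonzero(self.mem_state == 0).flatten()
        assert len(free) >= need, "out of KV cache slots"
        selected = free[:need]
        self.add_refs(selected)
        return selected

pool = ToyTokenPool(size=4)
assert 0 not in pool.alloc(4).tolist()  # the dummy slot is never handed out
```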