sglang · Commit 43fbb6d9

Fix `input_ids` && rename to `fill_ids` (#1021)

Authored Aug 10, 2024 by Liangsheng Yin; committed by GitHub on Aug 10, 2024
Parent: 54fb1c80
Showing 7 changed files with 28 additions and 27 deletions (+28 −27)
    python/sglang/bench_latency.py                           +3 −3
    python/sglang/srt/managers/policy_scheduler.py           +2 −2
    python/sglang/srt/managers/schedule_batch.py             +10 −9
    python/sglang/srt/managers/tp_worker.py                  +4 −4
    python/sglang/srt/mem_cache/chunk_cache.py               +2 −2
    python/sglang/srt/mem_cache/radix_cache.py               +2 −2
    python/sglang/srt/model_executor/forward_batch_info.py   +5 −5
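The rename follows the comment introduced in schedule_batch.py below: `fill_ids = origin_input_ids + output_ids`, i.e. the full token sequence to run the next prefill/extend over, with `extend_input_len = len(fill_ids) - len(prefix_indices)` giving the tokens that still need a forward pass, while `origin_input_ids` stays the untouched prompt. A minimal sketch of that bookkeeping, using a toy stand-in class rather than the real sglang `Req`:

# Toy illustration of the fill_ids bookkeeping shown in this commit
# (hypothetical stand-in class, not the real sglang Req).
from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyReq:
    origin_input_ids: List[int]                                # immutable prompt tokens
    output_ids: List[int] = field(default_factory=list)        # tokens decoded so far
    prefix_indices: List[int] = field(default_factory=list)    # KV slots already cached
    fill_ids: List[int] = field(default_factory=list)
    extend_input_len: int = 0

    def init_next_round_input(self):
        # Mirrors Req.init_next_round_input after the rename:
        # fill_ids = origin_input_ids + output_ids
        self.fill_ids = self.origin_input_ids + self.output_ids
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

req = ToyReq(origin_input_ids=[1, 2, 3, 4], output_ids=[5, 6])
req.prefix_indices = [0, 1, 2]   # pretend 3 tokens are already cached
req.init_next_round_input()
print(req.fill_ids)              # [1, 2, 3, 4, 5, 6]
print(req.extend_input_len)      # 3 -> tokens [4, 5, 6] still need a forward pass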
python/sglang/bench_latency.py

@@ -152,7 +152,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
         req.prefix_indices = []
         req.sampling_params = sampling_params
-        req.input_ids = req.origin_input_ids
+        req.fill_ids = req.origin_input_ids
         reqs.append(req)
     return input_ids, reqs

@@ -163,7 +163,7 @@ def prepare_extend_inputs_for_correctness_test(
 ):
     for i in range(len(reqs)):
         req = reqs[i]
-        req.input_ids += input_ids[i][bench_args.cut_len :]
+        req.fill_ids += input_ids[i][bench_args.cut_len :]
         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
             i, : bench_args.cut_len
         ]

@@ -182,7 +182,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
         req.prefix_indices = []
         req.sampling_params = sampling_params
-        req.input_ids = req.origin_input_ids
+        req.fill_ids = req.origin_input_ids
         reqs.append(req)
     return reqs
python/sglang/srt/managers/policy_scheduler.py

@@ -138,7 +138,7 @@ class PrefillAdder:
     def add_inflight_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
-        req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len]
+        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)

         self._prefill_one_req(

@@ -193,7 +193,7 @@ class PrefillAdder:
                 return False

             req.extend_input_len = trunc_len
-            req.input_ids = req.input_ids[: len(req.prefix_indices) + trunc_len]
+            req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
             self.can_run_list.append(req)
             self.new_inflight_req = req
             self.tree_cache.inc_lock_ref(req.last_node)
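Both `PrefillAdder` call sites truncate `fill_ids` the same way: keep the cached prefix plus at most the remaining chunk budget of new tokens. A hedged sketch of that slicing with made-up numbers (the real budget comes from the scheduler's `rem_chunk_tokens`):

# Sketch of the chunked-prefill truncation done in PrefillAdder (illustrative values).
prefix_indices = list(range(4))        # 4 tokens already cached for this request
fill_ids = list(range(100, 116))       # 16 tokens total (prompt + outputs so far)
rem_chunk_tokens = 8                   # budget left in the current prefill chunk

extend_input_len = len(fill_ids) - len(prefix_indices)    # 12 tokens would be new
truncated = extend_input_len > rem_chunk_tokens            # True -> request gets chunked
extend_input_len = min(extend_input_len, rem_chunk_tokens)

# Same slice as req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
fill_ids = fill_ids[: len(prefix_indices) + extend_input_len]
print(truncated, extend_input_len, len(fill_ids))          # True 8 12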
python/sglang/srt/managers/schedule_batch.py

@@ -99,7 +99,7 @@ class Req:
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
-        self.input_ids = None  # input_ids = origin_input_ids + output_ids
+        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids

         # Memory info
         self.req_pool_idx = None

@@ -165,12 +165,12 @@ class Req:
         return self.finished_reason is not None

     def init_next_round_input(self):
-        self.input_ids = self.origin_input_ids + self.output_ids
-        self.extend_input_len = len(self.input_ids) - len(self.prefix_indices)
+        self.fill_ids = self.origin_input_ids + self.output_ids
+        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):
-        self.input_ids = self.origin_input_ids + self.output_ids
-        input_len = len(self.input_ids)
+        self.fill_ids = self.origin_input_ids + self.output_ids
+        input_len = len(self.fill_ids)
         max_prefix_len = input_len

         if self.sampling_params.max_new_tokens > 0:

@@ -184,7 +184,7 @@ class Req:
             # Need at least two tokens to compute normalized logprob
             max_prefix_len = min(max_prefix_len, input_len - 2)

-        return self.input_ids[:max_prefix_len]
+        return self.fill_ids[:max_prefix_len]

     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):

@@ -427,7 +427,7 @@ class ScheduleBatch:
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         bs = self.batch_size()
         reqs = self.reqs
-        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
+        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []

@@ -438,7 +438,7 @@ class ScheduleBatch:
         pt = 0
         for i, req in enumerate(reqs):
             req.req_pool_idx = req_pool_indices_cpu[i]
-            pre_len, seq_len = len(req.prefix_indices), len(req.input_ids)
+            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             ext_len = seq_len - pre_len
             seq_lens.append(seq_len)

@@ -632,7 +632,8 @@ class ScheduleBatch:
     def prepare_for_decode(self, input_ids=None):
         if input_ids is None:
             input_ids = [
-                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
+                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
+                for r in self.reqs
             ]
         else:
             self.penalizer_orchestrator.cumulate_input_tokens(input_ids)
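The `prepare_for_decode` hunk changes the fallback for a request that has not produced any output yet: the next decode token is now taken from `origin_input_ids[-1]` rather than the old `input_ids[-1]`, presumably so the choice no longer depends on a `fill_ids` list that chunked prefill may have truncated. A toy sketch of the new selection (illustrative objects, not the real ScheduleBatch):

# Sketch of the decode-input selection after the change.
class ToyReq:
    def __init__(self, origin_input_ids, output_ids):
        self.origin_input_ids = origin_input_ids
        self.output_ids = output_ids

reqs = [
    ToyReq([1, 2, 3], []),      # has not decoded anything yet
    ToyReq([4, 5], [6, 7]),     # already produced two tokens
]

# New behavior: fall back to the last prompt token, never to fill_ids.
input_ids = [
    r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
    for r in reqs
]
print(input_ids)   # [3, 7]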
python/sglang/srt/managers/tp_worker.py

@@ -515,7 +515,7 @@ class ModelTpServer:
         self.handle_finished_requests(batch)

-    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+    def add_logprob_return_values(self, i, req: Req, pt, next_token_ids, output):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

@@ -524,12 +524,12 @@ class ModelTpServer:
             req.input_token_logprobs = list(
                 zip(
                     output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
-                    req.input_ids[-req.extend_input_len + 1 :],
+                    req.fill_ids[-req.extend_input_len + 1 :],
                 )
             )
             if req.logprob_start_len == 0:
                 req.input_token_logprobs = [
-                    (None, req.input_ids[0])
+                    (None, req.fill_ids[0])
                 ] + req.input_token_logprobs

         if req.last_update_decode_tokens != 0:

@@ -543,7 +543,7 @@ class ModelTpServer:
                         + req.extend_input_len
                         - 1
                     ],
-                    req.input_ids[-req.last_update_decode_tokens + 1 :],
+                    req.fill_ids[-req.last_update_decode_tokens + 1 :],
                 )
             )
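The `tp_worker.py` hunks only swap the token source used to label prompt logprobs. A toy alignment sketch, assuming the usual next-token pairing (the logprob at position i is attached to the token at position i + 1); the logprob values below are made up:

# Toy illustration of pairing prompt logprobs with the tokens they predict.
fill_ids = [11, 12, 13, 14, 15]            # tokens of the current extend segment
extend_input_len = len(fill_ids)
pt = 0
input_token_logprobs = [-0.5, -1.2, -0.3, -2.0]   # one fake logprob per predicted token

pairs = list(
    zip(
        input_token_logprobs[pt : pt + extend_input_len - 1],
        fill_ids[-extend_input_len + 1 :],
    )
)
# The first token has no preceding context, so it gets a None logprob,
# matching the (None, req.fill_ids[0]) prepended in the diff above.
pairs = [(None, fill_ids[0])] + pairs
print(pairs)
# [(None, 11), (-0.5, 12), (-1.2, 13), (-0.3, 14), (-2.0, 15)]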
python/sglang/srt/mem_cache/chunk_cache.py

@@ -34,7 +34,7 @@ class ChunkCache(BasePrefixCache):
     def cache_finished_req(self, req: "Req", token_ids=None):
         if token_ids is None:
-            token_ids = (req.input_ids + req.output_ids)[:-1]
+            token_ids = (req.origin_input_ids + req.output_ids)[:-1]

         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)

@@ -45,7 +45,7 @@ class ChunkCache(BasePrefixCache):
     def cache_unfinished_req(self, req: "Req", token_ids=None):
         if token_ids is None:
-            token_ids = req.input_ids
+            token_ids = req.fill_ids

         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
python/sglang/srt/mem_cache/radix_cache.py

@@ -92,7 +92,7 @@ class RadixCache(BasePrefixCache):
     def cache_finished_req(self, req: "Req", token_ids=None):
         """Cache request when it finishes."""
         if token_ids is None:
-            token_ids = (req.input_ids + req.output_ids)[:-1]
+            token_ids = (req.origin_input_ids + req.output_ids)[:-1]

         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
         ]

@@ -116,7 +116,7 @@ class RadixCache(BasePrefixCache):
             return

         if token_ids is None:
-            token_ids = req.input_ids
+            token_ids = req.fill_ids

         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
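In both caches the change separates the two default paths: a finished request caches the full `origin_input_ids + output_ids` sequence minus the last token, while an unfinished (still-chunked) request caches only what has actually been filled, `fill_ids`. A minimal sketch of that selection logic, using a hypothetical helper name rather than the sglang API:

# Hypothetical helper mirroring the token_ids defaults in ChunkCache/RadixCache
# after this commit (illustration only, not sglang code).
def default_token_ids(origin_input_ids, output_ids, fill_ids, finished):
    if finished:
        # Full sequence minus the last generated token, as in cache_finished_req.
        return (origin_input_ids + output_ids)[:-1]
    # Only the tokens actually filled so far, as in cache_unfinished_req.
    return fill_ids

prompt, outputs = [1, 2, 3, 4], [5, 6]
fill_ids = prompt[:3]   # e.g. a chunked prefill has filled only part of the prompt

print(default_token_ids(prompt, outputs, fill_ids, finished=True))    # [1, 2, 3, 4, 5]
print(default_token_ids(prompt, outputs, fill_ids, finished=False))   # [1, 2, 3]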
python/sglang/srt/model_executor/forward_batch_info.py

@@ -109,7 +109,7 @@ class InputMetadata:
             self.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(len(req.prefix_indices), len(req.input_ids))
+                        np.arange(len(req.prefix_indices), len(req.fill_ids))
                         for req in batch.reqs
                     ],
                     axis=0,

@@ -124,7 +124,7 @@ class InputMetadata:
                     [
                         np.arange(
                             len(req.prefix_indices) + position_ids_offsets_cpu[i],
-                            len(req.input_ids) + position_ids_offsets_cpu[i],
+                            len(req.fill_ids) + position_ids_offsets_cpu[i],
                         )
                         for i, req in enumerate(batch.reqs)
                     ],

@@ -141,7 +141,7 @@ class InputMetadata:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
         else:
             prefix_lens_cpu = [
-                len(r.input_ids) - len(r.prefix_indices) for r in batch.reqs
+                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
             ]
             self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)

@@ -149,7 +149,7 @@ class InputMetadata:
             self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)

     def init_total_num_tokens(self, batch: ScheduleBatch):
-        self.total_num_tokens = sum(len(req.input_ids) for req in batch.reqs)
+        self.total_num_tokens = sum(len(req.fill_ids) for req in batch.reqs)

     @classmethod
     def from_schedule_batch(

@@ -203,7 +203,7 @@ class InputMetadata:
     def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
         """Init auxiliary variables for triton attention backend."""
-        self.triton_max_seq_len = max(len(r.input_ids) for r in batch.reqs)
+        self.triton_max_seq_len = max(len(r.fill_ids) for r in batch.reqs)
         self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
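The `InputMetadata` hunks all size per-request tensors from `fill_ids` instead of the old `input_ids`; for example, extend-mode positions run from the cached prefix length up to `len(fill_ids)`. A small NumPy sketch of that positions computation, with invented per-request lengths:

import numpy as np

# Invented per-request lengths: (cached prefix length, len(fill_ids)).
reqs = [(0, 5), (3, 7), (2, 2)]

# Mirrors the np.concatenate([np.arange(len(prefix_indices), len(fill_ids)) ...])
# expression in InputMetadata: each request contributes positions for the tokens
# it extends this step; a fully cached request contributes nothing.
positions = np.concatenate(
    [np.arange(prefix_len, fill_len) for prefix_len, fill_len in reqs],
    axis=0,
)
print(positions)   # [0 1 2 3 4 3 4 5 6]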