sglang commit f06e90c2: Optimize retract (#440)

Unverified commit, authored May 26, 2024 by Liangsheng Yin, committed by GitHub on May 26, 2024. Parent commit: 2cea6146.
Changes: 7 files, 299 additions and 114 deletions (+299, -114).

- examples/usage/json_logprobs.py (+104, -0)
- python/sglang/global_config.py (+6, -0)
- python/sglang/srt/layers/logits_processor.py (+11, -6)
- python/sglang/srt/managers/detokenizer_manager.py (+6, -8)
- python/sglang/srt/managers/io_struct.py (+1, -1)
- python/sglang/srt/managers/router/infer_batch.py (+87, -47)
- python/sglang/srt/managers/router/model_rpc.py (+84, -52)
examples/usage/json_logprobs.py (new file, 0 → 100644)

```python
# NOTE: Currently this can only be run through HTTP requests.
import json
from concurrent.futures import ThreadPoolExecutor

from json_decode import character_regex

from sglang.utils import http_request

character_names = ["Hermione Granger", "Ron Weasley", "Harry Potter"]

base_url = "http://localhost:30000"

prompt = "is a character in Harry Potter. Please fill in the following information about this character.\n"


def openai_api_request(name):
    data = {
        "model": "",
        "prompt": name + prompt,
        "temperature": 0,
        "max_tokens": 128,
        "regex": character_regex,
        "logprobs": 3,
    }
    res = http_request(base_url + "/v1/completions", json=data).json()

    # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
    #     fout.write(json.dumps(res, indent=4))

    logprobs = res["choices"][0]["logprobs"]
    usage = res["usage"]
    assert len(logprobs["token_logprobs"]) == len(logprobs["tokens"])
    assert len(logprobs["token_logprobs"]) == len(logprobs["top_logprobs"])
    assert len(logprobs["token_logprobs"]) == usage["completion_tokens"] - 1

    return res


def srt_api_request(name):
    data = {
        "text": name + prompt,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 128,
            "regex": character_regex,
        },
        "return_logprob": True,
        "logprob_start_len": 0,
        "top_logprobs_num": 3,
        "return_text_in_logprobs": True,
    }

    res = http_request(base_url + "/generate", json=data).json()

    # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
    #     fout.write(json.dumps(res, indent=4))

    meta_info = res["meta_info"]
    assert len(meta_info["prefill_token_logprobs"]) == len(
        meta_info["prefill_top_logprobs"]
    )
    assert len(meta_info["decode_token_logprobs"]) == len(
        meta_info["decode_top_logprobs"]
    )
    assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"]
    assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1

    return res


def pretty_print(res):
    meta_info = res["meta_info"]

    print("\n\n", "=" * 30, "Prefill", "=" * 30)
    for i in range(len(meta_info["prefill_token_logprobs"])):
        print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="")
        top_ks = (
            [str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]]
            if meta_info["prefill_top_logprobs"][i]
            else []
        )
        for top_k in top_ks:
            print(f"{top_k: <15}", end="")
        print()

    print("\n\n", "=" * 30, "Decode", "=" * 30)
    for i in range(len(meta_info["decode_token_logprobs"])):
        print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="")
        top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]]
        for top_k in top_ks:
            print(f"{top_k: <15}", end="")
        print()

    print(res["text"])


if __name__ == "__main__":
    with ThreadPoolExecutor() as executor:
        ress = executor.map(srt_api_request, character_names)

    for res in ress:
        pretty_print(res)

    openai_api_request("Hermione Granger")
```
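For context, here is a minimal standalone sketch (not part of the commit) of the `/generate` logprob call that the example's assertions exercise. It assumes an sglang server is already listening on `http://localhost:30000`, uses the plain `requests` package instead of `sglang.utils.http_request`, and the prompt and token counts are purely illustrative.

```python
# Minimal sketch, not part of the commit: query /generate with logprobs enabled and
# inspect the meta_info shape that json_logprobs.py asserts on.
# Assumes a server is already running at http://localhost:30000.
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Hermione Granger is a character in Harry Potter.",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_logprob": True,
        "logprob_start_len": 0,
        "top_logprobs_num": 3,
        "return_text_in_logprobs": True,
    },
).json()

meta = resp["meta_info"]
# One (logprob, token_id, token_text) entry per prompt token ...
print(len(meta["prefill_token_logprobs"]), meta["prompt_tokens"])
# ... and one fewer decode entry than completion_tokens, matching the example's assert.
print(len(meta["decode_token_logprobs"]), meta["completion_tokens"] - 1)
```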
python/sglang/global_config.py

```diff
@@ -28,5 +28,11 @@ class GlobalConfig:
         # Request dependency time due to network delay
         self.request_dependency_time = 0.03
 
+        # New generation token ratio estimation
+        self.base_new_token_ratio = 0.4
+        self.base_min_new_token_ratio = 0.2
+        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_recovery = 0.05
+
 
 global_config = GlobalConfig()
```
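These four knobs replace constants that were hard-coded in model_rpc.py (see its diff further down). A rough, self-contained sketch of how the scheduler consumes them, assuming `schedule_conservativeness` is 1.0 (the values mirror the config above; the function name is illustrative, not from the commit):

```python
# Rough sketch (not part of the commit) of the new-token-ratio schedule.
BASE_NEW_TOKEN_RATIO = 0.4        # global_config.base_new_token_ratio
BASE_MIN_NEW_TOKEN_RATIO = 0.2    # global_config.base_min_new_token_ratio
NEW_TOKEN_RATIO_DECAY = 0.0001    # global_config.new_token_ratio_decay
NEW_TOKEN_RATIO_RECOVERY = 0.05   # global_config.new_token_ratio_recovery

schedule_conservativeness = 1.0   # server argument, assumed 1.0 here
new_token_ratio = min(BASE_NEW_TOKEN_RATIO * schedule_conservativeness, 1.0)
min_new_token_ratio = min(BASE_MIN_NEW_TOKEN_RATIO * schedule_conservativeness, 1.0)


def on_decode_step(ran_out_of_memory: bool) -> float:
    """One step of the estimate: bump it after a retract, otherwise decay it."""
    global new_token_ratio
    if ran_out_of_memory:
        # a retract happened, so estimate future tokens more conservatively again
        new_token_ratio = min(new_token_ratio + NEW_TOKEN_RATIO_RECOVERY, 1.0)
    else:
        # decay back toward the optimistic floor
        new_token_ratio = max(new_token_ratio - NEW_TOKEN_RATIO_DECAY, min_new_token_ratio)
    return new_token_ratio
```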
python/sglang/srt/layers/logits_processor.py

```diff
@@ -50,21 +50,22 @@ class LogitsProcessor(nn.Module):
         prefill_top_logprobs, decode_top_logprobs = [], []
         pt = 0
-        # NOTE: the GPU-CPU overhead can be reduced
-        extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
-        for i in range(len(extend_seq_lens_cpu)):
-            if extend_seq_lens_cpu[i] == 0:
+        extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
+        for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+            if extend_seq_len == 0:
                 prefill_top_logprobs.append([])
                 decode_top_logprobs.append([])
                 continue
             k = input_metadata.top_logprobs_nums[i]
-            t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
+            t = all_logprobs[pt : pt + extend_seq_len].topk(k)
             vs_cpu = t.values.tolist()
             ps_cpu = t.indices.tolist()
             prefill_top_logprobs.append(
                 [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
             )
             decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
-            pt += extend_seq_lens_cpu[i]
+            pt += extend_seq_len
 
         return prefill_top_logprobs, decode_top_logprobs
 
     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
@@ -145,7 +146,7 @@ class LogitsProcessor(nn.Module):
         )
 
 
-if __name__ == "__main__":
+def test():
     all_logprobs = torch.tensor(
         # s                     s              s
         [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
@@ -173,3 +174,7 @@ if __name__ == "__main__":
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)
+
+
+if __name__ == "__main__":
+    test()
```
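For reference, a toy sketch (random values, vocabulary size of 8; not part of the commit) of the top-k extraction pattern this hunk keeps: per position, the top-k values and indices are zipped into (logprob, token_id) pairs, and the last position is treated as the decode token.

```python
import torch

all_logprobs = torch.log_softmax(torch.randn(4, 8), dim=-1)  # 4 positions, toy vocab of 8
k = 3
t = all_logprobs.topk(k)                        # top-k along the vocab dimension
vs_cpu, ps_cpu = t.values.tolist(), t.indices.tolist()
prefill_top = [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
decode_top = list(zip(vs_cpu[-1], ps_cpu[-1]))  # last position = the decode token
print(len(prefill_top), len(decode_top))        # 3 prefill positions, 3 decode pairs
```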
python/sglang/srt/managers/detokenizer_manager.py

```diff
@@ -51,11 +51,6 @@ class DetokenizerManager:
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             for i in range(len(output_strs)):
-                if recv_obj.hit_stop_str[i] is not None:
-                    pos = output_strs[i].find(recv_obj.hit_stop_str[i])
-                    if pos != -1:
-                        output_strs[i] = output_strs[i][:pos]
-
                 if len(output_tokens[i]) > 0:
                     first_token = self.tokenizer.convert_ids_to_tokens(
                         int(output_tokens[i][0])
@@ -65,9 +60,12 @@ class DetokenizerManager:
                     if first_token.startswith("▁"):
                         output_strs[i] = " " + output_strs[i]
 
-                output_strs[i] = (
-                    recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
-                )
+                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+
+                if recv_obj.hit_stop_str[i] is not None:
+                    pos = output_strs[i].find(recv_obj.hit_stop_str[i])
+                    if pos != -1:
+                        output_strs[i] = output_strs[i][:pos]
 
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
```
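In the new flow the detokenizer prepends the text already finalized by the router (`prev_output_strs[i]`) and only then trims at the stop string, so the trim runs on the combined text rather than on the incremental piece alone. A toy illustration (strings invented, not from the commit):

```python
# Toy illustration: trim the stop string after prepending the previously
# finalized text, so a match that involves the earlier text is still handled.
prev_output_str = "The answer is 4"     # text finalized before the retract/jump-forward
incremental_str = "2.\nDone"            # newly detokenized piece for this batch
hit_stop_str = "\nDone"

out = prev_output_str + incremental_str
pos = out.find(hit_stop_str)
if pos != -1:
    out = out[:pos]
assert out == "The answer is 42."
```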
python/sglang/srt/managers/io_struct.py

```diff
@@ -106,8 +106,8 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
+    prev_output_strs: List[str]
     output_tokens: List[List[int]]
-    output_and_jump_forward_strs: List[str]
     hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
```
python/sglang/srt/managers/router/infer_batch.py

```diff
@@ -36,15 +36,15 @@ class FinishReason(IntEnum):
 class Req:
-    def __init__(self, rid, input_text, input_ids):
+    def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
-        self.input_text = input_text
-        self.input_ids = input_ids
+        self.origin_input_text = origin_input_text
+        self.origin_input_ids = origin_input_ids
+        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
+        self.prev_output_str = ""
+        self.prev_output_ids = []
         self.output_ids = []
 
-        # Since jump forward may retokenize the prompt with partial outputs,
-        # we maintain the original prompt length to report the correct usage.
-        self.prompt_tokens = len(input_ids)
+        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
 
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
@@ -76,15 +76,24 @@ class Req:
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
         self.prefill_token_logprobs = None
-        self.decode_token_logprobs = None
+        self.decode_token_logprobs = []
         self.prefill_top_logprobs = None
-        self.decode_top_logprobs = None
+        self.decode_top_logprobs = []
+        # The tokens is prefilled but need to be considered as decode tokens
+        # and should be updated for the decode logprobs
+        self.last_update_decode_tokens = 0
 
         # Constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.jump_forward_map = None
-        self.output_and_jump_forward_str = ""
+
+    def partial_decode(self, ids):
+        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
+        first_token = (
+            first_token.decode() if isinstance(first_token, bytes) else first_token
+        )
+        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
 
     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
@@ -93,7 +102,10 @@ class Req:
         if self.finished:
             return
 
-        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+        if (
+            len(self.prev_output_ids) + len(self.output_ids)
+            >= self.sampling_params.max_new_tokens
+        ):
             self.finished = True
             self.finish_reason = FinishReason.LENGTH
             return
@@ -112,60 +124,66 @@ class Req:
         )
 
         for stop_str in self.sampling_params.stop_strs:
-            if stop_str in tail_str:
+            # FIXME: (minor) try incremental match in prev_output_str
+            if stop_str in tail_str or stop_str in self.prev_output_str:
                 self.finished = True
                 self.finish_reason = FinishReason.STOP_STR
                 self.hit_stop_str = stop_str
                 return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        old_output_str = self.tokenizer.decode(self.output_ids)
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
-        )
-        if first_token.startswith("▁"):
-            old_output_str = " " + old_output_str
-        if self.input_text is None:
-            # TODO(lmzheng): This can be wrong. Check with Liangsheng.
-            self.input_text = self.tokenizer.decode(self.input_ids)
-        new_input_string = (
-            self.input_text
-            + self.output_and_jump_forward_str
-            + old_output_str
+        cur_output_str = self.partial_decode(self.output_ids)
+
+        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
+        if self.origin_input_text is None:
+            # Recovering text can only use unpadded ids
+            self.origin_input_text = self.tokenizer.decode(
+                self.origin_input_ids_unpadded
+            )
+
+        all_text = (
+            self.origin_input_text
+            + self.prev_output_str
+            + cur_output_str
             + jump_forward_str
         )
-        new_input_ids = self.tokenizer.encode(new_input_string)
-        if self.pixel_values is not None:
-            # NOTE: This is a hack because the old input_ids contains the image padding
-            jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
-        else:
-            jump_forward_tokens_len = (
-                len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
-            )
+
+        all_ids = self.tokenizer.encode(all_text)
+        prompt_tokens = len(self.origin_input_ids_unpadded)
+        self.origin_input_ids = all_ids[:prompt_tokens]
+        self.origin_input_ids_unpadded = self.origin_input_ids
+        # NOTE: the output ids may not strictly correspond to the output text
+        old_prev_output_ids = self.prev_output_ids
+        self.prev_output_ids = all_ids[prompt_tokens:]
+        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
+        self.output_ids = []
+
+        self.regex_fsm_state = next_state
+
+        if self.return_logprob:
+            # For fast-forward part's logprobs
+            k = 0
+            for i, old_id in enumerate(old_prev_output_ids):
+                if old_id == self.prev_output_ids[i]:
+                    k = k + 1
+                else:
+                    break
+            self.decode_token_logprobs = self.decode_token_logprobs[:k]
+            self.decode_top_logprobs = self.decode_top_logprobs[:k]
+            self.logprob_start_len = prompt_tokens + k
+            self.last_update_decode_tokens = len(self.prev_output_ids) - k
 
         # print("=" * 100)
         # print(f"Catch jump forward:\n{jump_forward_str}")
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
-
-        self.input_ids = new_input_ids
-        self.output_ids = []
-        self.sampling_params.max_new_tokens = max(
-            self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
-        )
-        self.regex_fsm_state = next_state
-        self.output_and_jump_forward_str = (
-            self.output_and_jump_forward_str + old_output_str + jump_forward_str
-        )
-
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
         # print("*" * 100)
 
     def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
+        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
 
 
 @dataclass
@@ -336,6 +354,7 @@ class Batch:
     def retract_decode(self):
         sorted_indices = [i for i in range(len(self.reqs))]
+        # TODO(lsyin): improve the priority of retraction
        sorted_indices.sort(
            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
            reverse=True,
@@ -356,18 +375,27 @@ class Batch:
             ][last_uncached_pos : seq_lens_cpu[idx]]
             self.token_to_kv_pool.dec_refs(token_indices)
 
+            # release the last node
+            self.tree_cache.dec_lock_ref(req.last_node)
+
+            cur_output_str = req.partial_decode(req.output_ids)
+            req.prev_output_str = req.prev_output_str + cur_output_str
+            req.prev_output_ids.extend(req.output_ids)
+
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
             req.output_ids = []
             req.regex_fsm_state = 0
+
+            # For incremental logprobs
+            req.last_update_decode_tokens = 0
+            req.logprob_start_len = 10**9
 
         self.filter_batch(sorted_indices)
 
         return retracted_reqs
 
-    def check_for_jump_forward(self):
+    def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
@@ -397,6 +425,18 @@ class Batch:
                 # jump-forward
                 req.jump_forward_and_retokenize(jump_forward_str, next_state)
 
+                # re-applying image padding
+                if req.pixel_values is not None:
+                    (
+                        req.origin_input_ids,
+                        req.image_offset,
+                    ) = model_runner.model.pad_input_ids(
+                        req.origin_input_ids_unpadded,
+                        req.pad_value,
+                        req.pixel_values.shape,
+                        req.image_size,
+                    )
+
                 jump_forward_reqs.append(req)
                 filter_indices.remove(i)
```
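A toy sketch (token ids invented, not part of the commit) of the retract bookkeeping these hunks introduce: a retracted request folds its partial output into `prev_output_str` / `prev_output_ids`, and the next prefill rebuilds `input_ids` from the original prompt plus that previous output (see the model_rpc.py hunk below that sets `req.input_ids = req.origin_input_ids + req.prev_output_ids`).

```python
# Toy retract bookkeeping, ids invented.
origin_input_ids = [101, 7592]          # original (unpadded) prompt tokens
prev_output_ids: list[int] = []         # output already folded in by earlier retracts
output_ids = [2023, 2003]               # tokens decoded since the last prefill

# retract_decode(): move the current output into the "previous output" buffers
prev_output_ids.extend(output_ids)
output_ids = []

# next prefill: re-extend from the full prefix
input_ids = origin_input_ids + prev_output_ids
assert input_ids == [101, 7592, 2023, 2003]
```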
python/sglang/srt/managers/router/model_rpc.py

```diff
@@ -4,7 +4,7 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import List, Optional
 
 import rpyc
 import torch
@@ -16,6 +16,7 @@ try:
 except ImportError:
     from vllm.logger import logger as vllm_default_logger
 
+from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -106,7 +107,8 @@ class ModelRpcServer:
         set_random_seed(server_args.random_seed)
 
         # Print info
-        logger.info(f"[rank={self.tp_rank}] "
+        logger.info(
+            f"[rank={self.tp_rank}] "
             f"max_total_num_token={self.max_total_num_token}, "
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
@@ -151,9 +153,20 @@ class ModelRpcServer:
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
-        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
-        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
-        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.new_token_ratio = min(
+            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
@@ -256,8 +269,13 @@ class ModelRpcServer:
                     (recv_req.image_hash >> 64) % self.model_config.vocab_size,
                 ]
                 req.image_size = recv_req.image_size
-                req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
-                    req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
+                req.origin_input_ids, req.image_offset = (
+                    self.model_runner.model.pad_input_ids(
+                        req.origin_input_ids_unpadded,
+                        req.pad_value,
+                        req.pixel_values.shape,
+                        req.image_size,
+                    )
                 )
             req.sampling_params = recv_req.sampling_params
             req.return_logprob = recv_req.return_logprob
@@ -275,11 +293,11 @@ class ModelRpcServer:
             )
 
             # Truncate prompts that are too long
-            req.input_ids = req.input_ids[: self.model_config.context_len - 1]
+            req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
             req.sampling_params.max_new_tokens = min(
                 req.sampling_params.max_new_tokens,
-                self.model_config.context_len - 1 - len(req.input_ids),
-                self.max_total_num_token - 128 - len(req.input_ids),
+                self.model_config.context_len - 1 - len(req.origin_input_ids),
+                self.max_total_num_token - 128 - len(req.origin_input_ids),
             )
             self.forward_queue.append(req)
@@ -292,6 +310,10 @@ class ModelRpcServer:
         # Compute matched prefix length
         for req in self.forward_queue:
+            assert (
+                len(req.output_ids) == 0
+            ), "The output ids should be empty when prefilling"
+            req.input_ids = req.origin_input_ids + req.prev_output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -319,7 +341,7 @@ class ModelRpcServer:
         )
 
         for req in self.forward_queue:
-            if req.return_logprob:
+            if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
                     delta = 2 - req.extend_input_len
@@ -441,28 +463,53 @@ class ModelRpcServer:
             req.check_finished()
 
             if req.return_logprob:
-                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
-
-                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                req.prefill_token_logprobs = list(
-                    zip(
-                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                        req.input_ids[-req.extend_input_len + 1 :],
-                    )
-                )
-                if req.logprob_start_len == 0:
-                    req.prefill_token_logprobs = [
-                        (None, req.input_ids[0])
-                    ] + req.prefill_token_logprobs
-                req.decode_token_logprobs = [
-                    (last_token_logprobs[i], next_token_ids[i])
-                ]
+                if req.normalized_prompt_logprob is None:
+                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+
+                if req.prefill_token_logprobs is None:
+                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+                    req.prefill_token_logprobs = list(
+                        zip(
+                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                            req.input_ids[-req.extend_input_len + 1 :],
+                        )
+                    )
+                    if req.logprob_start_len == 0:
+                        req.prefill_token_logprobs = [
+                            (None, req.input_ids[0])
+                        ] + req.prefill_token_logprobs
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_token_logprobs.extend(
+                        list(
+                            zip(
+                                prefill_token_logprobs[
+                                    pt
+                                    + req.extend_input_len
+                                    - req.last_update_decode_tokens : pt
+                                    + req.extend_input_len
+                                    - 1
+                                ],
+                                req.input_ids[-req.last_update_decode_tokens + 1 :],
+                            )
+                        )
+                    )
+
+                req.decode_token_logprobs.append(
+                    (last_token_logprobs[i], next_token_ids[i])
+                )
 
             if req.top_logprobs_num > 0:
-                req.prefill_top_logprobs = prefill_top_logprobs[i]
-                if req.logprob_start_len == 0:
-                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-                req.decode_top_logprobs = [decode_top_logprobs[i]]
+                if req.prefill_top_logprobs is None:
+                    req.prefill_top_logprobs = prefill_top_logprobs[i]
+                    if req.logprob_start_len == 0:
+                        req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_top_logprobs.extend(
+                        prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                    )
+                req.decode_top_logprobs.append(decode_top_logprobs[i])
 
             pt += req.extend_input_len
@@ -484,7 +531,7 @@ class ModelRpcServer:
         # check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
+            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
 
             retracted_reqs = batch.retract_decode()
             logger.info(
@@ -495,26 +542,13 @@ class ModelRpcServer:
             self.forward_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_step[0],
+                self.new_token_ratio - self.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )
 
         if not self.disable_regex_jump_forward:
             # check for jump-forward
-            jump_forward_reqs = batch.check_for_jump_forward()
-
-            # check for image jump-forward
-            for req in jump_forward_reqs:
-                if req.pixel_values is not None:
-                    (
-                        req.input_ids,
-                        req.image_offset,
-                    ) = self.model_runner.model.pad_input_ids(
-                        req.input_ids,
-                        req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
-                    )
-
+            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
             self.forward_queue.extend(jump_forward_reqs)
 
             if batch.is_empty():
@@ -557,8 +591,8 @@ class ModelRpcServer:
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
+        prev_output_strs = []
         output_tokens = []
-        output_and_jump_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
@@ -582,8 +616,8 @@ class ModelRpcServer:
             ):
                 output_rids.append(req.rid)
+                prev_output_strs.append(req.prev_output_str)
                 output_tokens.append(req.output_ids)
-                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                 output_hit_stop_str.append(req.hit_stop_str)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
@@ -593,10 +627,8 @@ class ModelRpcServer:
                 )
 
                 meta_info = {
-                    "prompt_tokens": req.prompt_tokens,
-                    "completion_tokens": len(req.input_ids)
-                    + len(req.output_ids)
-                    - req.prompt_tokens,
+                    "prompt_tokens": len(req.origin_input_ids),
+                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                     "finish_reason": FinishReason.to_str(req.finish_reason),
                     "hit_stop_str": req.hit_stop_str,
@@ -623,8 +655,8 @@ class ModelRpcServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
+                    prev_output_strs,
                    output_tokens,
-                    output_and_jump_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
```
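The usage accounting in `handle_finished_requests` follows the same bookkeeping. A toy example (ids invented, not part of the commit) of the new meta_info fields, which count completion tokens as everything in `prev_output_ids` plus the current `output_ids` instead of subtracting a stored `prompt_tokens`:

```python
# Toy accounting example for the new usage reporting after a retract or jump-forward.
origin_input_ids = [101, 2054, 2003]   # the (unpadded) prompt
prev_output_ids = [1996, 3437]         # tokens decoded before the retract
output_ids = [2003, 102]               # tokens decoded since the re-prefill

meta_info = {
    "prompt_tokens": len(origin_input_ids),
    "completion_tokens": len(prev_output_ids) + len(output_ids),
}
assert meta_info == {"prompt_tokens": 3, "completion_tokens": 4}
```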