sglang / Commits / 0736b270

[Minor] Improve the code style in TokenizerManager (#767)

Unverified commit 0736b270, authored Jul 27, 2024 by Lianmin Zheng; committed by GitHub on Jul 27, 2024.
Parent: 3fdab919
Showing 3 changed files with 60 additions and 37 deletions:

  python/sglang/srt/managers/controller/infer_batch.py  +1  -1
  python/sglang/srt/managers/tokenizer_manager.py       +54 -35
  python/sglang/test/test_programs.py                   +5  -1
python/sglang/srt/managers/controller/infer_batch.py

@@ -376,7 +376,7 @@ class Batch:
                     logit_bias = torch.zeros(
                         (bs, vocab_size), dtype=torch.float32, device=device
                     )
-                logit_bias[i] = int_token_logit_bias
+                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias

         # Set fields
         self.input_ids = torch.tensor(
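Context for the one-line fix above: int_token_logit_bias can be shorter than the vocabulary size, in which case assigning it to a whole row of logit_bias raises a shape-mismatch error. A minimal sketch of the before/after behavior (toy sizes; only the names from the diff are real, the values are illustrative):

import torch

bs, vocab_size = 2, 8
# Bias vector covering only the first 5 token ids (toy values).
int_token_logit_bias = torch.full((5,), -1e5)

logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32)

# Old: logit_bias[0] = int_token_logit_bias
# -> RuntimeError: the row has 8 slots but the bias has only 5 values.

# New: write only the prefix the bias vector covers.
logit_bias[0][: len(int_token_logit_bias)] = int_token_logit_bias
print(logit_bias[0])  # first 5 entries biased, last 3 remain zero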
python/sglang/srt/managers/tokenizer_manager.py
@@ -133,24 +133,10 @@ class TokenizerManager:
             async for response in self._handle_batch_request(obj, request):
                 yield response

-    async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
-        if is_prefill:
-            if isinstance(obj.text, list):
-                input_text = obj.text[index]
-                rid = obj.rid[index]
-            else:
-                input_text = obj.text
-                rid = obj.rid[0]
-            input_ids = self.tokenizer.encode(input_text)
-            sampling_params = SamplingParams(**obj.sampling_params[0])
-            sampling_params.max_new_tokens = 0
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data[0]
-            )
-            return_logprob = obj.return_logprob[0]
-            logprob_start_len = obj.logprob_start_len[0]
-            top_logprobs_num = obj.top_logprobs_num[0]
-        else:
+    async def _handle_single_request(
+        self, obj, request, index=None, is_cache_for_prefill=False
+    ):
+        if not is_cache_for_prefill:
+            rid = obj.rid if index is None else obj.rid[index]
+            input_text = obj.text if index is None else obj.text[index]
+            input_ids = (
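The rewrite collapses the old isinstance branching into "index is None" ternaries: a single request carries scalar fields, a batched request carries lists plus an index. A standalone sketch of that selection pattern, with a hypothetical Req standing in for GenerateReqInput:

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class Req:  # hypothetical stand-in for GenerateReqInput
    rid: Union[str, List[str]]
    text: Union[str, List[str]]

def pick(obj: Req, index: Optional[int]):
    # index is None -> single request with scalar fields;
    # otherwise -> batched request, select the index-th element.
    rid = obj.rid if index is None else obj.rid[index]
    text = obj.text if index is None else obj.text[index]
    return rid, text

print(pick(Req("r0", "hello"), None))          # ('r0', 'hello')
print(pick(Req(["r0", "r1"], ["a", "b"]), 1))  # ('r1', 'b')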
@@ -177,6 +163,22 @@ class TokenizerManager:
             top_logprobs_num = (
                 obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
             )
+        else:
+            if isinstance(obj.text, list):
+                input_text = obj.text[index]
+                rid = obj.rid[index]
+            else:
+                input_text = obj.text
+                rid = obj.rid[0]
+            input_ids = self.tokenizer.encode(input_text)
+            sampling_params = SamplingParams(**obj.sampling_params[0])
+            sampling_params.max_new_tokens = 0
+            pixel_values, image_hash, image_size = await self._get_pixel_values(
+                obj.image_data[0]
+            )
+            return_logprob = obj.return_logprob[0]
+            logprob_start_len = obj.logprob_start_len[0]
+            top_logprobs_num = obj.top_logprobs_num[0]

         tokenized_obj = TokenizedGenerateReqInput(
             rid,
@@ -196,26 +198,26 @@ class TokenizerManager:
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state

-        if is_prefill:
-            await self._wait_for_prefill_response(event, state, obj, request, rid)
-            yield input_ids
-        else:
-            async for response in self._wait_for_response(event, state, obj, rid, request):
-                yield response
+        if not is_cache_for_prefill:
+            async for response in self._wait_for_response(event, state, obj, rid, request):
+                yield response
+        else:
+            await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
+            yield input_ids

-    async def _handle_batch_request(self, obj, request):
+    async def _handle_batch_request(self, obj: GenerateReqInput, request):
         batch_size = obj.batch_size
         parallel_sample_num = obj.sampling_params[0].get("n", 1)

         if parallel_sample_num != 1:
-            # send prefill requests
+            # Send prefill requests to cache the common input
             parallel_sample_num += 1
             input_id_result = [] if obj.input_ids is None else None
             for i in range(batch_size):
                 async for input_id in self._handle_single_request(
-                    obj, request, index=i, is_prefill=True
+                    obj, request, index=i, is_cache_for_prefill=True
                 ):
                     if input_id_result is not None:
                         input_id_result.append(input_id)
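As the renamed flag and updated comment make explicit, the batch path performs an extra pass when parallel sampling is requested (n != 1): each prompt is first sent as a prefill-only request so the common input is cached, and only then are the real sampling requests sent. A self-contained sketch of that control flow, with hypothetical helpers standing in for the actual request plumbing:

import asyncio
from typing import List

async def prefill_only(prompt: str) -> None:
    # Hypothetical stand-in for a request sent with max_new_tokens=0,
    # whose only effect is to populate the shared prompt cache.
    await asyncio.sleep(0)
    print(f"cached prefill for {prompt!r}")

async def sample(prompt: str, j: int) -> None:
    # Hypothetical stand-in for one of the n parallel samples.
    await asyncio.sleep(0)
    print(f"sample {j} for {prompt!r}")

async def handle_batch(prompts: List[str], n: int) -> None:
    if n != 1:
        # Warm the cache once per prompt before fanning out n samples,
        # mirroring the is_cache_for_prefill=True pass in the diff.
        for p in prompts:
            await prefill_only(p)
    for p in prompts:
        for j in range(n):
            await sample(p, j)

asyncio.run(handle_batch(["a", "b"], n=2))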
@@ -224,6 +226,7 @@ class TokenizerManager:
                 obj.input_ids = input_id_result
             elif input_id_result is not None:
                 obj.input_ids = input_id_result[0]

+        # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
@@ -308,17 +311,15 @@ class TokenizerManager:
             yield output_list

-    def _validate_input_length(self, input_ids):
+    def _validate_input_length(self, input_ids: List[int]):
         if len(input_ids) >= self.context_len:
             raise ValueError(
                 f"The input ({len(input_ids)} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )

-    def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
+    def _get_sampling_params(self, sampling_params_data: dict):
         sampling_params = SamplingParams(**sampling_params_data)
-        if max_new_tokens is not None:
-            sampling_params.max_new_tokens = max_new_tokens
         if sampling_params.max_new_tokens != 0:
             sampling_params.normalize(self.tokenizer)
             sampling_params.verify()
@@ -332,7 +333,14 @@ class TokenizerManager:
         else:
             return None, None, None

-    async def _wait_for_response(self, event, state, obj, rid, request):
+    async def _wait_for_response(
+        self,
+        event: asyncio.Event,
+        state: ReqState,
+        obj: GenerateReqInput,
+        rid: str,
+        request,
+    ):
         while True:
             try:
                 await asyncio.wait_for(event.wait(), timeout=4)
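Both wait helpers poll the event with a 4-second timeout rather than awaiting it unconditionally; each timeout returns control to the coroutine so it can react to an aborted or disconnected request. A minimal, self-contained sketch of that pattern (is_disconnected is a hypothetical stand-in for the server's disconnect check):

import asyncio

async def wait_with_heartbeat(event: asyncio.Event, is_disconnected) -> bool:
    # Wake up every few seconds even if no result arrived, so we can
    # abort cleanly when the caller has gone away.
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=4)
        except asyncio.TimeoutError:
            if is_disconnected():
                return False  # caller vanished; stop waiting
            continue
        return True  # event fired: a result is ready

async def main():
    event = asyncio.Event()
    asyncio.get_running_loop().call_later(0.1, event.set)
    print(await wait_with_heartbeat(event, lambda: False))  # True

asyncio.run(main())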
@@ -361,7 +369,14 @@ class TokenizerManager:
             event.clear()
             yield out

-    async def _wait_for_prefill_response(self, event, state, obj, request, rid):
+    async def _wait_for_cache_prefill_response(
+        self,
+        event: asyncio.Event,
+        state: ReqState,
+        obj: GenerateReqInput,
+        rid: str,
+        request,
+    ):
         while True:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
@@ -380,7 +395,7 @@ class TokenizerManager:
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)

-    def abort_request(self, rid):
+    def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
         del self.rid_to_state[rid]
@@ -426,7 +441,11 @@ class TokenizerManager:
             state.event.set()

     def convert_logprob_style(
-        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
+        self,
+        ret: dict,
+        return_logprob: bool,
+        top_logprobs_num: int,
+        return_text_in_logprobs: bool,
     ):
         if return_logprob:
             ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
@@ -450,7 +469,7 @@ class TokenizerManager:
             )
         return ret

-    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool):
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
@@ -461,7 +480,7 @@ class TokenizerManager:
             for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
         ]

-    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
         for i, t in enumerate(top_logprobs):
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
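For orientation, the entries these helpers process are (logprob, token_id) pairs that become (logprob, token_id, token_text) triples when decode_to_text is set. A standalone sketch of that transformation, with decode_fn as an illustrative stand-in for the tokenizer's batch decode:

from typing import List, Optional, Tuple

TokenLogprob = Tuple[float, int]                   # (logprob, token_id)
DecodedLogprob = Tuple[float, int, Optional[str]]  # (+ token_text)

def detokenize(
    token_logprobs: List[TokenLogprob],
    decode_to_text: bool,
    decode_fn=lambda ids: [f"<tok{i}>" for i in ids],  # hypothetical default
) -> List[DecodedLogprob]:
    if not decode_to_text:
        return [(lp, tid, None) for lp, tid in token_logprobs]
    texts = decode_fn([tid for _, tid in token_logprobs])
    return [(lp, tid, txt) for (lp, tid), txt in zip(token_logprobs, texts)]

print(detokenize([(-0.1, 42), (-2.3, 7)], decode_to_text=True))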
python/sglang/test/test_programs.py
@@ -118,7 +118,11 @@ def test_decode_json_regex():
     s += "}"

     ret = decode_json.run()
-    js_obj = json.loads(ret["json_output"])
+    try:
+        js_obj = json.loads(ret["json_output"])
+    except json.decoder.JSONDecodeError:
+        print(ret["json_output"])
+        raise
     assert isinstance(js_obj["name"], str)
     assert isinstance(js_obj["population"], int)
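The wrapper changes nothing when the output parses; when it does not, the offending payload is printed before the exception propagates, so a failing test run shows the raw model output rather than only a traceback. The same pattern in isolation:

import json

def parse_or_dump(raw: str) -> dict:
    # On bad JSON, show the payload before re-raising so the test log
    # contains the actual model output, not just the stack trace.
    try:
        return json.loads(raw)
    except json.decoder.JSONDecodeError:
        print(raw)
        raise

print(parse_or_dump('{"name": "Paris", "population": 2100000}'))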