sglang / Commits / c17c5781

Commit c17c5781 (Unverified)
Simplify tokenizer manager (#1904)
Authored Nov 03, 2024 by Lianmin Zheng; committed by GitHub on Nov 03, 2024
Parent: 916b3cdd

Showing 11 changed files with 259 additions and 441 deletions (+259, -441)
Changed files:
  python/sglang/srt/managers/data_parallel_controller.py   +0   -2
  python/sglang/srt/managers/io_struct.py                   +66  -84
  python/sglang/srt/managers/scheduler.py                   +2   -5
  python/sglang/srt/managers/tokenizer_manager.py           +116 -279
  python/sglang/srt/openai_api/adapter.py                   +52  -66
  python/sglang/srt/server.py                               +2   -3
  test/srt/run_suite.py                                     +1   -1
  test/srt/test_openai_server.py                            +5   -0
  test/srt/test_skip_tokenizer_init.py                      +3   -0
  test/srt/test_srt_endpoint.py                             +11  -1
  test/srt/test_vision_openai_server.py                     +1   -0
python/sglang/srt/managers/data_parallel_controller.py (+0, -2)

@@ -24,7 +24,6 @@ import zmq
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs

@@ -152,7 +151,6 @@ class DataParallelController:
                 (
                     TokenizedGenerateReqInput,
                     TokenizedEmbeddingReqInput,
-                    TokenizedRewardReqInput,
                 ),
             ):
                 self.dispatching(recv_req)
python/sglang/srt/managers/io_struct.py (+66, -84)

@@ -56,49 +56,47 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

     # Whether it is a single request or a batch request
     is_single: bool = True

-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")

-        self.is_single = False
+        # Derive the batch size
         if self.text is not None:
             if isinstance(self.text, str):
                 self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.text)
         else:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.input_ids)

+        # Handle parallel sampling
+        # When parallel sampling is used, we always treat the input as a batch.
         if self.sampling_params is None:
             self.parallel_sample_num = 1
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
             self.parallel_sample_num = self.sampling_params[0].get("n", 1)
-            for sp in self.sampling_params:
-                # TODO cope with the case that the parallel_sample_num is different for different samples
-                assert self.parallel_sample_num == sp.get("n", 1), "The parallel_sample_num should be the same for all samples in sample params."
+            assert all(
+                self.parallel_sample_num == sampling_params.get("n", 1)
+                for sampling_params in self.sampling_params
+            ), "The parallel_sample_num should be the same for all samples in sample params."

-        if self.parallel_sample_num > 1:
-            if self.is_single:
-                self.is_single = False
-                if self.text is not None:
-                    self.text = [self.text]
-                if self.input_ids is not None:
-                    self.input_ids = [self.input_ids]
+        if self.parallel_sample_num > 1 and self.is_single:
+            self.is_single = False
+            if self.text is not None:
+                self.text = [self.text]
+            if self.input_ids is not None:
+                self.input_ids = [self.input_ids]

+        # Fill in default arguments
         if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}

@@ -114,8 +112,8 @@ class GenerateReqInput:
             if self.parallel_sample_num == 1:
                 num = self.batch_size
             else:
-                # The first bs samples are used for caching the prefix for parallel sampling
-                num = self.batch_size + self.parallel_sample_num * self.batch_size
+                # Expand parallel_sample_num
+                num = self.batch_size * self.parallel_sample_num

             if self.image_data is None:
                 self.image_data = [None] * num

@@ -128,14 +126,11 @@ class GenerateReqInput:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
-            else:
-                assert self.parallel_sample_num == 1

             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
                 assert isinstance(self.rid, list), "The rid should be a list."
-                assert self.parallel_sample_num == 1

             if self.return_logprob is None:
                 self.return_logprob = [False] * num

@@ -158,6 +153,26 @@ class GenerateReqInput:
             else:
                 assert self.parallel_sample_num == 1

+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
+
+    def __getitem__(self, i):
+        return GenerateReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            image_data=self.image_data[i],
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+            return_logprob=self.return_logprob[i],
+            logprob_start_len=self.logprob_start_len[i],
+            top_logprobs_num=self.top_logprobs_num[i],
+            return_text_in_logprobs=self.return_text_in_logprobs,
+            stream=self.stream,
+            modalities=self.modalities[i] if self.modalities else None,
+            lora_path=self.lora_path[i] if self.lora_path is not None else None,
+        )
+

 @dataclass
 class TokenizedGenerateReqInput:

@@ -195,20 +210,29 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None

     # Whether it is a single request or a batch request
     is_single: bool = True

-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")

+        # Derive the batch size
         if self.text is not None:
-            self.is_single = isinstance(self.text, str)
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.text)
         else:
-            self.is_single = isinstance(self.input_ids[0], int)
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.input_ids)

+        # Fill in default arguments
         if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex

@@ -216,73 +240,31 @@ class EmbeddingReqInput:
                 self.sampling_params = {}
             self.sampling_params["max_new_tokens"] = 1
         else:
-            # support select operation
-            self.batch_size = (
-                len(self.text) if self.text is not None else len(self.input_ids)
-            )
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
             else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
+                assert isinstance(self.rid, list), "The rid should be a list."
             if self.sampling_params is None:
                 self.sampling_params = [{}] * self.batch_size
             for i in range(self.batch_size):
                 self.sampling_params[i]["max_new_tokens"] = 1

+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
+
+    def __getitem__(self, i):
+        return EmbeddingReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+        )
-
-@dataclass
-class TokenizedEmbeddingReqInput:
-    # The request id
-    rid: str
-    # The input text
-    input_text: str
-    # The input token ids
-    input_ids: List[int]
-    # Dummy sampling params for compatibility
-    sampling_params: SamplingParams
-
-
-RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]
-
-
-@dataclass
-class RewardReqInput:
-    # The input prompt. It can be a single prompt or a batch of prompts. Can be either chat format or a string.
-    conv: RewardReqConv
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
-    # Dummy sampling params for compatibility
-    sampling_params: Union[List[Dict], Dict] = None
-    # Whether it is a single request or a batch request
-    is_single: bool = True
-
-    def post_init(self):
-        self.is_single = isinstance(self.conv[0], dict)
-
-        if self.is_single:
-            if self.rid is None:
-                self.rid = uuid.uuid4().hex
-            if self.sampling_params is None:
-                self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
-        else:
-            # support select operation
-            self.batch_size = len(self.conv)
-            if self.rid is None:
-                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
-            else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
-            if self.sampling_params is None:
-                self.sampling_params = [{}] * self.batch_size
-            for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1


 @dataclass
-class TokenizedRewardReqInput:
+class TokenizedEmbeddingReqInput:
     # The request id
     rid: str
     # The input text
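Note on the io_struct changes: the batching logic that used to live in the tokenizer manager now sits on the request objects themselves — normalize_batch_and_arguments() derives is_single/batch_size and fills defaults, __getitem__ slices out one single-prompt request, and regenerate_rid() mints a fresh request id. The following is a minimal, self-contained sketch of that pattern; MiniGenerateReq is an illustrative stand-in, not the real GenerateReqInput.

import uuid
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class MiniGenerateReq:
    text: Optional[Union[str, List[str]]] = None
    sampling_params: Optional[Union[dict, List[dict]]] = None
    rid: Optional[Union[str, List[str]]] = None
    is_single: bool = True

    def normalize_batch_and_arguments(self):
        # Derive the batch size from the input type.
        if isinstance(self.text, str):
            self.is_single, self.batch_size = True, 1
        else:
            self.is_single, self.batch_size = False, len(self.text)
        # Fill in default arguments.
        if self.sampling_params is None:
            self.sampling_params = {} if self.is_single else [{}] * self.batch_size
        if self.rid is None:
            self.rid = (
                uuid.uuid4().hex
                if self.is_single
                else [uuid.uuid4().hex for _ in range(self.batch_size)]
            )

    def regenerate_rid(self):
        # Assign a fresh request id (used when expanding parallel samples).
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        # Slice one single-prompt request out of a batch request.
        return MiniGenerateReq(
            text=self.text[i],
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )


batch = MiniGenerateReq(text=["Hello", "Bonjour"])
batch.normalize_batch_and_arguments()
singles = [batch[i] for i in range(batch.batch_size)]
print([s.rid for s in singles])

With this in place, a batch handler can simply loop over obj[i] instead of threading index arguments through every helper, which is exactly what the tokenizer manager rewrite below does.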
python/sglang/srt/managers/scheduler.py (+2, -5)

@@ -43,7 +43,6 @@ from sglang.srt.managers.io_struct import (
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )

@@ -394,9 +393,7 @@ class Scheduler:
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
-            elif isinstance(
-                recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
-            ):
+            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                 self.handle_embedding_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()

@@ -487,7 +484,7 @@ class Scheduler:
     def handle_embedding_request(
         self,
-        recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
+        recv_req: TokenizedEmbeddingReqInput,
     ):
         req = Req(
             recv_req.rid,
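With TokenizedRewardReqInput gone, the scheduler's receive loop only distinguishes generate and embedding requests; reward requests now arrive through the embedding path. A small standalone sketch of that isinstance dispatch — the request classes here are empty stand-ins, not the real io_struct types.

class TokenizedGenerateReqInput: ...
class TokenizedEmbeddingReqInput: ...
class FlushCacheReq: ...


def process_input_requests(recv_reqs):
    for recv_req in recv_reqs:
        if isinstance(recv_req, TokenizedGenerateReqInput):
            print("handle_generate_request")
        elif isinstance(recv_req, TokenizedEmbeddingReqInput):
            # Reward-model scoring reuses the embedding request path.
            print("handle_embedding_request")
        elif isinstance(recv_req, FlushCacheReq):
            print("flush_cache")
        else:
            raise ValueError(f"Invalid request: {recv_req}")


process_input_requests([TokenizedGenerateReqInput(), FlushCacheReq()])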
python/sglang/srt/managers/tokenizer_manager.py (+116, -279)

@@ -16,6 +16,7 @@ limitations under the License.
 """TokenizerManager is a process that tokenizes the text."""

 import asyncio
+import copy
 import dataclasses
 import json
 import logging

@@ -51,11 +52,8 @@ from sglang.srt.managers.io_struct import (
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
     ProfileReq,
-    RewardReqConv,
-    RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )

@@ -157,7 +155,7 @@ class TokenizerManager:
     async def generate_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:

@@ -172,122 +170,54 @@ class TokenizerManager:
                 "Please add `--is-embedding` when launching the server or try another model."
             )

-        obj.post_init()
+        obj.normalize_batch_and_arguments()
         is_single = obj.is_single

         if is_single:
-            async for response in self._handle_single_request(obj, request):
+            tokenized_obj = await self._tokenize_one_request(obj)
+            self.send_to_scheduler.send_pyobj(tokenized_obj)
+            async for response in self._wait_one_response(obj, request):
                 yield response
         else:
             async for response in self._handle_batch_request(obj, request):
                 yield response

-    async def _send_single_request(
+    async def _tokenize_one_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
     ):
-        if not is_cache_for_prefill:
-            # The normal case with a single prompt
-            if index is None:
-                rid = obj.rid
-                if isinstance(obj, RewardReqInput):
-                    input_text = self._apply_chat_template(obj.conv)
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = obj.text if obj.text is not None else None
-                    input_ids = obj.input_ids
-
-                sampling_params = self._get_sampling_params(obj.sampling_params)
-
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data, input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob
-                    logprob_start_len = obj.logprob_start_len
-                    top_logprobs_num = obj.top_logprobs_num
-            else:
-                rid = obj.rid[index]
-                if isinstance(obj, RewardReqInput):
-                    input_text = self._apply_chat_template(obj.conv[input_id_index])
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text[input_id_index]
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = (
-                        obj.text[input_id_index] if obj.text is not None else None
-                    )
-                    input_ids = obj.input_ids[input_id_index]
-
-                sampling_params = self._get_sampling_params(obj.sampling_params[index])
-
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data[index], input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob[index]
-                    logprob_start_len = obj.logprob_start_len[index]
-                    top_logprobs_num = obj.top_logprobs_num[index]
-
-            self._validate_input_length(input_ids)
-        else:
-            # A prefill request to cache the common prompt for parallel sampling
-            assert self.is_generation
-            if obj.text is not None:
-                if isinstance(obj.text, list):
-                    input_text = obj.text[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_text = obj.text
-                    rid = obj.rid[0]
-                if self.tokenizer is not None:
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    assert obj.input_ids is not None
-                    input_ids = obj.input_ids
-                if isinstance(obj.input_ids, list) and isinstance(
-                    obj.input_ids[0], list
-                ):
-                    # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_ids = obj.input_ids
-                    rid = obj.rid[0]
-            else:
-                input_text = None
-                if isinstance(obj.input_ids, list) and isinstance(
-                    obj.input_ids[0], list
-                ):
-                    # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_ids = obj.input_ids
-                    rid = obj.rid[0]
-
-            sampling_params = SamplingParams(**obj.sampling_params[0])
-            sampling_params.max_new_tokens = 0
-
-            if self.is_generation:
-                image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data[0], input_text or input_ids, obj
-                )
-                if image_inputs and "input_ids" in image_inputs:
-                    input_ids = image_inputs["input_ids"]
-                return_logprob = obj.return_logprob[0]
-                logprob_start_len = obj.logprob_start_len[0]
-                top_logprobs_num = obj.top_logprobs_num[0]
-
-        # Send to the controller
-        if self.is_generation:
+        """Tokenize one request."""
+        # Tokenize
+        input_text = obj.text
+        if obj.input_ids is None:
+            input_ids = self.tokenizer.encode(input_text)
+        else:
+            input_ids = obj.input_ids
+
+        if self.is_generation:
+            image_inputs = await self.image_processor.process_images_async(
+                obj.image_data, input_text or input_ids, obj
+            )
+            if image_inputs and "input_ids" in image_inputs:
+                input_ids = image_inputs["input_ids"]
+            return_logprob = obj.return_logprob
+            logprob_start_len = obj.logprob_start_len
+            top_logprobs_num = obj.top_logprobs_num
+
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)."
+            )
+
+        # Parse sampling parameters
+        sampling_params = SamplingParams(**obj.sampling_params)
+        sampling_params.normalize(self.tokenizer)
+        sampling_params.verify()
+
+        # Build return object
+        if isinstance(obj, GenerateReqInput):
             tokenized_obj = TokenizedGenerateReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 image_inputs,

@@ -296,219 +226,126 @@ class TokenizerManager:
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-                (
-                    obj.lora_path[input_id_index]
-                    if isinstance(obj.lora_path, list)
-                    else obj.lora_path
-                ),
+                obj.lora_path,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 sampling_params,
             )
-        else:
-            assert isinstance(obj, RewardReqInput)
-            tokenized_obj = TokenizedRewardReqInput(
-                rid,
-                input_text,
-                input_ids,
-                sampling_params,
-            )

-        self.send_to_scheduler.send_pyobj(tokenized_obj)
-        return rid, input_ids
+        return tokenized_obj

-    async def _handle_single_request(
+    async def _wait_one_response(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
     ):
-        rid, input_ids = await self._send_single_request(
-            obj,
-            index,
-            input_id_index=input_id_index,
-            is_cache_for_prefill=is_cache_for_prefill,
-        )
-
-        # Recv results
+        """Wait for the response of one request."""
         event = asyncio.Event()
         state = ReqState([], False, event)
-        self.rid_to_state[rid] = state
-
-        if not is_cache_for_prefill:
-            async for response in self._wait_for_response(state, obj, rid, request):
-                yield response
-        else:
-            await state.event.wait()
-            assert state.finished
-            del self.rid_to_state[rid]
-            yield input_ids
-
-    async def _handle_batch_request(
-        self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        request: Optional[fastapi.Request] = None,
-    ):
-        batch_size = obj.batch_size
-        if self.is_generation:
-            parallel_sample_num = obj.parallel_sample_num
-
-            if parallel_sample_num != 1:
-                # Send prefill requests to cache the common prefix
-                parallel_sample_num += 1
-                input_id_result = [] if obj.input_ids is None else None
-                for i in range(batch_size):
-                    async for input_id in self._handle_single_request(
-                        obj,
-                        request,
-                        index=i,
-                        input_id_index=i,
-                        is_cache_for_prefill=True,
-                    ):
-                        if input_id_result is not None:
-                            input_id_result.append(input_id)
-                if input_id_result is not None:
-                    obj.input_ids = input_id_result
-        else:
-            parallel_sample_num = 1
-
-        # First send out all requests
-        generators = []
-        for i in range(batch_size):
-            for j in range(parallel_sample_num):
-                if j == 0 and parallel_sample_num != 1:
-                    continue
-                index = i * parallel_sample_num + j
-                if parallel_sample_num != 1:
-                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
-                    index += batch_size - 1 - i
-                rid, _ = await self._send_single_request(
-                    obj, index, input_id_index=i, is_cache_for_prefill=False
-                )
-
-                event = asyncio.Event()
-                state = ReqState([], False, event)
-                self.rid_to_state[rid] = state
-
-                generators.append(
-                    self._wait_for_response(
-                        state,
-                        obj,
-                        rid,
-                        request,
-                        index=index,
-                        response_index=len(generators),
-                    )
-                )
-
-        # Then process the responses based on streaming option
-        is_stream = hasattr(obj, "stream") and obj.stream
-
-        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
-        output_list = [None] * len(tasks)
-
-        # Fetch results
-        while tasks:
-            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-            for task in done:
-                cur_index = tasks.index(task)
-                try:
-                    result = task.result()
-                    if is_stream:
-                        yield result
-                    else:
-                        output_list[result["index"]] = result
-                    tasks[cur_index] = asyncio.create_task(
-                        generators[cur_index].__anext__()
-                    )
-                except StopAsyncIteration:
-                    del generators[cur_index]
-                    del tasks[cur_index]
-
-        if not is_stream:
-            yield output_list
-
-    def _validate_input_length(self, input_ids: List[int]):
-        if len(input_ids) >= self.context_len:
-            raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
-            )
-
-    def _get_sampling_params(self, sampling_params_data: dict):
-        sampling_params = SamplingParams(**sampling_params_data)
-        if sampling_params.max_new_tokens != 0:
-            sampling_params.normalize(self.tokenizer)
-            sampling_params.verify()
-        return sampling_params
-
-    def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]:
-        if isinstance(conv, str):
-            return conv
-        elif isinstance(conv, list):
-            if isinstance(conv[0], str):
-                return conv
-            else:
-                return self.tokenizer.apply_chat_template(conv, tokenize=False)
-
-    async def _wait_for_response(
-        self,
-        state: ReqState,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        rid: str,
-        request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        response_index: int = 0,
-    ):
+        self.rid_to_state[obj.rid] = state
+
         while True:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
-                    for rid in [obj.rid] if obj.is_single else obj.rid:
-                        self.abort_request(rid)
-                    raise ValueError(f"Abort request {rid}")
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")
                 continue

-            if self.is_generation:
+            if isinstance(obj, GenerateReqInput):
                 out = self.convert_logprob_style(
                     state.out_list[-1],
-                    obj.return_logprob if index is None else obj.return_logprob[index],
-                    (
-                        obj.top_logprobs_num
-                        if index is None
-                        else obj.top_logprobs_num[index]
-                    ),
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
                     obj.return_text_in_logprobs,
                 )
-            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
+            else:  # isinstance(obj, (EmbeddingReqInput,))
                 out = state.out_list[-1]

-            out["index"] = response_index
-
             state.out_list = []
             if state.finished:
-                # Log requests
                 if self.server_args.log_requests:
+                    # Log requests
                     logger.info(f"in={obj}, out={out}")
-                del self.rid_to_state[rid]
+                del self.rid_to_state[obj.rid]
                 yield out
                 break

             state.event.clear()
             yield out

+    async def _handle_batch_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
+    ):
+        batch_size = obj.batch_size
+
+        generators = []
+        rids = []
+        if getattr(obj, "parallel_sample_num", 1) == 1:
+            # Send all requests
+            for i in range(batch_size):
+                tmp_obj = obj[i]
+                tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                generators.append(self._wait_one_response(tmp_obj, request))
+                rids.append(tmp_obj.rid)
+        else:
+            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+
+            # Tokenize all requests
+            objs = [obj[i] for i in range(batch_size)]
+            tokenized_objs = await asyncio.gather(
+                *(self._tokenize_one_request(obj) for obj in objs)
+            )
+
+            # Cache the common prefix for parallel sampling
+            for i in range(batch_size):
+                tmp_obj = copy.copy(objs[i])
+                tokenized_obj = copy.copy(tokenized_objs[i])
+                tokenized_obj.rid = tmp_obj.regenerate_rid()
+                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
+                tokenized_obj.sampling_params.max_new_tokens = 0
+                tokenized_obj.stream = False
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                await self._wait_one_response(tmp_obj, request).__anext__()
+
+            # Expand requests, assign new rids for them, and send them
+            for i in range(batch_size):
+                for _ in range(obj.parallel_sample_num):
+                    tmp_obj = copy.copy(objs[i])
+                    tokenized_obj = copy.copy(tokenized_objs[i])
+                    tokenized_obj.rid = tmp_obj.regenerate_rid()
+                    self.send_to_scheduler.send_pyobj(tokenized_obj)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
+
+        # Wait for all requests
+        is_stream = hasattr(obj, "stream") and obj.stream
+        if not is_stream:
+            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
+            yield outputs
+        else:
+            rid_to_index = {rid: i for i, rid in enumerate(rids)}
+            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
+            while task_map:
+                done, _ = await asyncio.wait(
+                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+                )
+                for task in done:
+                    gen = task_map.pop(task)
+                    try:
+                        result = task.result()
+                        result["index"] = rid_to_index[result["meta_info"]["id"]]
+                        yield result
+                        new_task = asyncio.create_task(gen.__anext__())
+                        task_map[new_task] = gen
+                    except StopAsyncIteration:
+                        pass
+
     def flush_cache(self):
         req = FlushCacheReq()
         self.send_to_scheduler.send_pyobj(req)
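Note on the tokenizer manager rewrite: the single-request path is now tokenize → send → wait, and the batch path tokenizes every sub-request up front, optionally sends a max_new_tokens=0 copy to cache the shared prefix for parallel sampling, and then streams results back in completion order. The core of that streaming fan-out is plain asyncio; below is a self-contained sketch under assumed names (one_response and handle_batch are illustrative stand-ins, not sglang APIs).

import asyncio
import random


async def one_response(rid: str):
    # Stand-in for _wait_one_response: emits one result chunk, then finishes.
    await asyncio.sleep(random.random() * 0.1)
    yield {"meta_info": {"id": rid}, "text": f"output for {rid}"}


async def handle_batch(rids):
    generators = [one_response(rid) for rid in rids]
    rid_to_index = {rid: i for i, rid in enumerate(rids)}
    task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}

    while task_map:
        done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen = task_map.pop(task)
            try:
                result = task.result()
                # Map the result back to its position in the original batch.
                result["index"] = rid_to_index[result["meta_info"]["id"]]
                yield result
                # Re-arm the generator to pull its next chunk (if any).
                task_map[asyncio.create_task(gen.__anext__())] = gen
            except StopAsyncIteration:
                pass


async def main():
    async for out in handle_batch(["req-0", "req-1", "req-2"]):
        print(out["index"], out["text"])


asyncio.run(main())

The design trade-off is the same one the diff makes: the non-streaming case can simply asyncio.gather all generators, while the streaming case needs the task_map so that chunks are yielded as soon as any request produces one.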
python/sglang/srt/openai_api/adapter.py (+52, -66)

@@ -71,6 +71,7 @@ from sglang.srt.openai_api.protocol import (
     TopLogprob,
     UsageInfo,
 )
+from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

@@ -314,6 +315,8 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             )
         except Exception as e:
+            logger.error(f"error: {get_exception_traceback()}")
             responses = []
             error_json = {
                 "id": f"batch_req_{uuid.uuid4()}",
                 "custom_id": request_data.get("custom_id"),

@@ -363,7 +366,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         }
     except Exception as e:
-        logger.error("error in SGLang:", e)
+        logger.error(f"error: {e}")
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"

@@ -469,80 +472,67 @@ async def v1_retrieve_file_content(file_id: str):
 def v1_generate_request(
     all_requests: List[CompletionRequest], request_ids: List[str] = None
 ):
-    if len(all_requests) > 1:
-        first_prompt_type = type(all_requests[0].prompt)
-        for request in all_requests:
-            assert (
-                type(request.prompt) is first_prompt_type
-            ), "All prompts must be of the same type in file input settings"
-            if request.n > 1:
-                raise ValueError(
-                    "Parallel sampling is not supported for completions from files"
-                )
-
     prompts = []
     sampling_params_list = []
     return_logprobs = []
     logprob_start_lens = []
     top_logprobs_nums = []

-    # NOTE: with openai API, the prompt's logprobs are always not computed
+    first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
+        assert (
+            type(request.prompt) is first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Parallel sampling is not supported for completions from files"
+            )
+
+        # NOTE: with openai API, the prompt's logprobs are always not computed
         if request.echo and request.logprobs:
             logger.warning(
                 "Echo is not compatible with logprobs. "
-                "To compute logprobs of input prompt, please use SGLang /request API."
+                "To compute logprobs of input prompt, please use the native /generate API."
             )

-    for request in all_requests:
         prompts.append(request.prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
+                "stop": request.stop,
+                "stop_token_ids": request.stop_token_ids,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
+                "regex": request.regex,
+                "json_schema": request.json_schema,
+                "n": request.n,
+                "ignore_eos": request.ignore_eos,
+                "no_stop_trim": request.no_stop_trim,
+            }
+        )
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
         logprob_start_lens.append(-1)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
-        sampling_params = []
-        if isinstance(request.no_stop_trim, list):
-            num_reqs = len(request.prompt)
-        else:
-            num_reqs = 1
-        for i in range(num_reqs):
-            sampling_params.append(
-                {
-                    "temperature": request.temperature,
-                    "max_new_tokens": request.max_tokens,
-                    "min_new_tokens": request.min_tokens,
-                    "stop": request.stop,
-                    "stop_token_ids": request.stop_token_ids,
-                    "top_p": request.top_p,
-                    "presence_penalty": request.presence_penalty,
-                    "frequency_penalty": request.frequency_penalty,
-                    "repetition_penalty": request.repetition_penalty,
-                    "regex": request.regex,
-                    "json_schema": request.json_schema,
-                    "n": request.n,
-                    "ignore_eos": request.ignore_eos,
-                    "no_stop_trim": (
-                        request.no_stop_trim
-                        if not isinstance(request.no_stop_trim, list)
-                        else request.no_stop_trim[i]
-                    ),
-                }
-            )
-        if num_reqs == 1:
-            sampling_params_list.append(sampling_params[0])
-        else:
-            sampling_params_list.append(sampling_params)

     if len(all_requests) == 1:
         prompt = prompts[0]
-        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
-            prompt_kwargs = {"text": prompts[0]}
-        else:
-            prompt_kwargs = {"input_ids": prompts[0]}
         sampling_params_list = sampling_params_list[0]
-        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
+        if isinstance(prompt, str) or isinstance(prompt[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str):
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}

@@ -558,9 +548,7 @@ def v1_generate_request(
         rid=request_ids,
     )

-    if len(all_requests) == 1:
-        return adapted_request, all_requests[0]
-    return adapted_request, all_requests
+    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


 def v1_generate_response(request, ret, tokenizer_manager, to_file=False):

@@ -595,7 +583,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
         if isinstance(request, list) and request[idx].echo:
             echo = True
             text = request[idx].prompt + text
-        if (not isinstance(request, list)) and echo:
+        if echo and not isinstance(request, list):
             prompt_index = idx // request.n
             text = prompts[prompt_index] + text

@@ -709,7 +697,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             async for content in tokenizer_manager.generate_request(
                 adapted_request, raw_request
             ):
-                index = content["index"]
+                index = content.get("index", 0)

                 stream_buffer = stream_buffers.get(index, "")
                 n_prev_token = n_prev_tokens.get(index, 0)

@@ -945,19 +933,18 @@ def v1_chat_generate_request(
             sampling_params_list.append(sampling_params)
             image_data_list.append(image_data)
-            modalities_list.extend(modalities)
+            modalities_list.append(modalities)

     if len(all_requests) == 1:
-        input_ids = input_ids[0]
-        if isinstance(input_ids, str):
-            prompt_kwargs = {"text": input_ids}
+        if isinstance(input_ids[0], str):
+            prompt_kwargs = {"text": input_ids[0]}
         else:
-            prompt_kwargs = {"input_ids": input_ids}
+            prompt_kwargs = {"input_ids": input_ids[0]}
         sampling_params_list = sampling_params_list[0]
         image_data_list = image_data_list[0]
         return_logprobs = return_logprobs[0]
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
-        modalities_list = modalities_list[:1]
+        modalities_list = modalities_list[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}

@@ -976,9 +963,8 @@ def v1_chat_generate_request(
         rid=request_ids,
         modalities=modalities_list,
     )

-    if len(all_requests) == 1:
-        return adapted_request, all_requests[0]
-    return adapted_request, all_requests
+    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


 def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):

@@ -1116,7 +1102,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             async for content in tokenizer_manager.generate_request(
                 adapted_request, raw_request
            ):
-                index = content["index"]
+                index = content.get("index", 0)

                 is_first = is_firsts.get(index, True)
                 stream_buffer = stream_buffers.get(index, "")
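Note on the adapter change: streamed chunks from a batch carry an "index" field added by the tokenizer manager, while a single request may not, hence the switch to content.get("index", 0). A toy illustration with fabricated chunk dicts (not real server output):

def accumulate(chunks):
    stream_buffers = {}
    for content in chunks:
        index = content.get("index", 0)  # was: content["index"]
        stream_buffers[index] = stream_buffers.get(index, "") + content["text"]
    return stream_buffers


single = [{"text": "Par"}, {"text": "is"}]
batch = [{"index": 0, "text": "Paris"}, {"index": 1, "text": "Berlin"}]
print(accumulate(single))  # {0: 'Paris'}
print(accumulate(batch))   # {0: 'Paris', 1: 'Berlin'}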
python/sglang/srt/server.py (+2, -3)

@@ -53,7 +53,6 @@ from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
-    RewardReqInput,
     UpdateWeightReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process

@@ -91,7 +90,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
-tokenizer_manager = None
+tokenizer_manager: TokenizerManager = None

 app.add_middleware(
     CORSMiddleware,

@@ -254,7 +253,7 @@ app.post("/encode")(encode_request)
 app.put("/encode")(encode_request)


-async def judge_request(obj: RewardReqInput, request: Request):
+async def judge_request(obj: EmbeddingReqInput, request: Request):
     """Handle a reward model request."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
test/srt/run_suite.py (+1, -1)

@@ -8,7 +8,7 @@ suites = {
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
         "models/test_lora.py",
-        "models/test_reward_models.py",
+        # "models/test_reward_models.py",
         "sampling/penaltylib",
         "test_chunked_prefill.py",
         "test_double_sparsity.py",
test/srt/test_openai_server.py (+5, -0)

+"""
+python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
+python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
+"""
+
 import json
 import time
 import unittest
test/srt/test_skip_tokenizer_init.py (+3, -0)

+"""
+python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
+"""
 import json
 import unittest
test/srt/test_srt_endpoint.py (+11, -1)

 """
 python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
+python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_parallel_sample
 """

 import json

@@ -36,11 +37,17 @@ class TestSRTEndpoint(unittest.TestCase):
         return_text=False,
         n=1,
         stream=False,
+        batch=False,
     ):
+        if batch:
+            text = ["The capital of France is"]
+        else:
+            text = "The capital of France is"
+
         response = requests.post(
             self.base_url + "/generate",
             json={
-                "text": "The capital of France is",
+                "text": text,
                 "sampling_params": {
                     "temperature": 0 if n == 1 else 0.5,
                     "max_new_tokens": 16,

@@ -67,6 +74,9 @@ class TestSRTEndpoint(unittest.TestCase):
     def test_simple_decode(self):
         self.run_decode()

+    def test_simple_decode_batch(self):
+        self.run_decode(batch=True)
+
     def test_parallel_sample(self):
         self.run_decode(n=3)
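The new batch test exercises the same /generate endpoint a client would hit. A usage sketch mirroring run_decode(batch=True) with parallel sampling; it assumes an sglang server is already listening on localhost:30000 (the port and prompt are illustrative):

import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": ["The capital of France is"],
        "sampling_params": {
            "temperature": 0.5,
            "max_new_tokens": 16,
            "n": 3,
        },
    },
)
print(response.json())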
test/srt/test_vision_openai_server.py (+1, -0)

 """
 Usage:
 python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch
+python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_multi_images_chat_completion
 """

 import base64