change / sglang / Commits / c17c5781
Unverified commit c17c5781, authored Nov 03, 2024 by Lianmin Zheng, committed by GitHub on Nov 03, 2024.

Simplify tokenizer manager (#1904)

parent 916b3cdd
Showing 11 changed files with 259 additions and 441 deletions (+259, -441):

- python/sglang/srt/managers/data_parallel_controller.py (+0, -2)
- python/sglang/srt/managers/io_struct.py (+66, -84)
- python/sglang/srt/managers/scheduler.py (+2, -5)
- python/sglang/srt/managers/tokenizer_manager.py (+116, -279)
- python/sglang/srt/openai_api/adapter.py (+52, -66)
- python/sglang/srt/server.py (+2, -3)
- test/srt/run_suite.py (+1, -1)
- test/srt/test_openai_server.py (+5, -0)
- test/srt/test_skip_tokenizer_init.py (+3, -0)
- test/srt/test_srt_endpoint.py (+11, -1)
- test/srt/test_vision_openai_server.py (+1, -0)
python/sglang/srt/managers/data_parallel_controller.py (+0, -2)

@@ -24,7 +24,6 @@ import zmq
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -152,7 +151,6 @@ class DataParallelController:
                 (
                     TokenizedGenerateReqInput,
                     TokenizedEmbeddingReqInput,
-                    TokenizedRewardReqInput,
                 ),
             ):
                 self.dispatching(recv_req)
python/sglang/srt/managers/io_struct.py (+66, -84)

@@ -56,49 +56,47 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

     # Whether it is a single request or a batch request
     is_single: bool = True

-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")

-        self.is_single = False
+        # Derive the batch size
         if self.text is not None:
             if isinstance(self.text, str):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.text)
         else:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.input_ids)

+        # Handle parallel sampling
+        # When parallel sampling is used, we always treat the input as a batch.
         if self.sampling_params is None:
             self.parallel_sample_num = 1
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
             self.parallel_sample_num = self.sampling_params[0].get("n", 1)
-            for sp in self.sampling_params:
-                # TODO cope with the case that the parallel_sample_num is different for different samples
-                assert self.parallel_sample_num == sp.get(
-                    "n", 1
-                ), "The parallel_sample_num should be the same for all samples in sample params."
+            assert all(
+                self.parallel_sample_num == sampling_params.get("n", 1)
+                for sampling_params in self.sampling_params
+            ), "The parallel_sample_num should be the same for all samples in sample params."

-        if self.parallel_sample_num > 1:
-            if self.is_single:
-                self.is_single = False
-                if self.text is not None:
-                    self.text = [self.text]
-                if self.input_ids is not None:
-                    self.input_ids = [self.input_ids]
+        if self.parallel_sample_num > 1 and self.is_single:
+            self.is_single = False
+            if self.text is not None:
+                self.text = [self.text]
+            if self.input_ids is not None:
+                self.input_ids = [self.input_ids]

+        # Fill in default arguments
         if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
@@ -114,8 +112,8 @@ class GenerateReqInput:
             if self.parallel_sample_num == 1:
                 num = self.batch_size
             else:
-                # The first bs samples are used for caching the prefix for parallel sampling
-                num = self.batch_size + self.parallel_sample_num * self.batch_size
+                # Expand parallel_sample_num
+                num = self.batch_size * self.parallel_sample_num

             if self.image_data is None:
                 self.image_data = [None] * num
@@ -128,14 +126,11 @@ class GenerateReqInput:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
-            else:
-                assert self.parallel_sample_num == 1

             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
                 assert isinstance(self.rid, list), "The rid should be a list."
-                assert self.parallel_sample_num == 1

             if self.return_logprob is None:
                 self.return_logprob = [False] * num
@@ -158,6 +153,26 @@ class GenerateReqInput:
             else:
                 assert self.parallel_sample_num == 1

+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
+
+    def __getitem__(self, i):
+        return GenerateReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            image_data=self.image_data[i],
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+            return_logprob=self.return_logprob[i],
+            logprob_start_len=self.logprob_start_len[i],
+            top_logprobs_num=self.top_logprobs_num[i],
+            return_text_in_logprobs=self.return_text_in_logprobs,
+            stream=self.stream,
+            modalities=self.modalities[i] if self.modalities else None,
+            lora_path=self.lora_path[i] if self.lora_path is not None else None,
+        )

 @dataclass
 class TokenizedGenerateReqInput:
@@ -195,20 +210,29 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
     # Whether it is a single request or a batch request
     is_single: bool = True

-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")

+        # Derive the batch size
         if self.text is not None:
-            self.is_single = isinstance(self.text, str)
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.text)
         else:
-            self.is_single = isinstance(self.input_ids[0], int)
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.input_ids)

+        # Fill in default arguments
         if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
@@ -216,73 +240,31 @@ class EmbeddingReqInput:
                 self.sampling_params = {}
             self.sampling_params["max_new_tokens"] = 1
         else:
-            # support select operation
-            self.batch_size = (
-                len(self.text) if self.text is not None else len(self.input_ids)
-            )
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
             else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
+                assert isinstance(self.rid, list), "The rid should be a list."
             if self.sampling_params is None:
                 self.sampling_params = [{}] * self.batch_size
             for i in range(self.batch_size):
                 self.sampling_params[i]["max_new_tokens"] = 1

+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
+
+    def __getitem__(self, i):
+        return EmbeddingReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+        )

-@dataclass
-class TokenizedEmbeddingReqInput:
-    # The request id
-    rid: str
-    # The input text
-    input_text: str
-    # The input token ids
-    input_ids: List[int]
-    # Dummy sampling params for compatibility
-    sampling_params: SamplingParams
-
-
-RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]
-
-
-@dataclass
-class RewardReqInput:
-    # The input prompt. It can be a single prompt or a batch of prompts. Can be either chat format or a string.
-    conv: RewardReqConv
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
-    # Dummy sampling params for compatibility
-    sampling_params: Union[List[Dict], Dict] = None
-    # Whether it is a single request or a batch request
-    is_single: bool = True
-
-    def post_init(self):
-        self.is_single = isinstance(self.conv[0], dict)
-
-        if self.is_single:
-            if self.rid is None:
-                self.rid = uuid.uuid4().hex
-            if self.sampling_params is None:
-                self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
-        else:
-            # support select operation
-            self.batch_size = len(self.conv)
-            if self.rid is None:
-                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
-            else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
-            if self.sampling_params is None:
-                self.sampling_params = [{}] * self.batch_size
-            for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1

 @dataclass
-class TokenizedRewardReqInput:
+class TokenizedEmbeddingReqInput:
     # The request id
     rid: str
     # The input text
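The upshot of the GenerateReqInput changes is that a single prompt whose sampling params request `n > 1` is normalized into a batch of size 1, and the new `__getitem__`/`regenerate_rid` helpers let the tokenizer manager slice and re-send per-sample copies. A small self-contained sketch of the normalization rule (a distilled illustration with toy fields, not the real dataclass):

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class MiniGenerateReq:
    """Toy stand-in for GenerateReqInput, keeping only text and sampling_params."""

    text: Optional[Union[str, List[str]]] = None
    sampling_params: Optional[Union[dict, List[dict]]] = None
    is_single: bool = True

    def normalize_batch_and_arguments(self):
        # Derive the batch size
        if isinstance(self.text, str):
            self.is_single = True
            self.batch_size = 1
        else:
            self.is_single = False
            self.batch_size = len(self.text)

        # Parallel sampling ("n") always turns the input into a batch
        if isinstance(self.sampling_params, dict):
            self.parallel_sample_num = self.sampling_params.get("n", 1)
        else:
            self.parallel_sample_num = 1

        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            self.text = [self.text]


req = MiniGenerateReq(text="The capital of France is", sampling_params={"n": 3})
req.normalize_batch_and_arguments()
assert req.is_single is False and req.text == ["The capital of France is"]
assert req.batch_size == 1 and req.parallel_sample_num == 3
```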
python/sglang/srt/managers/scheduler.py (+2, -5)

@@ -43,7 +43,6 @@ from sglang.srt.managers.io_struct import (
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
@@ -394,9 +393,7 @@ class Scheduler:
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
-            elif isinstance(
-                recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
-            ):
+            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                 self.handle_embedding_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
@@ -487,7 +484,7 @@ class Scheduler:
     def handle_embedding_request(
         self,
-        recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
+        recv_req: TokenizedEmbeddingReqInput,
     ):
         req = Req(
             recv_req.rid,
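With TokenizedRewardReqInput gone, the scheduler's receive loop only distinguishes two tokenized request types; reward requests now arrive as TokenizedEmbeddingReqInput. A toy sketch of the same isinstance-based dispatch (stand-in classes and print handlers, not the real Scheduler):

```python
from dataclasses import dataclass


@dataclass
class TokenizedGenerateReqInput:  # stand-in for the real dataclass
    rid: str


@dataclass
class TokenizedEmbeddingReqInput:  # reward requests now arrive as this type too
    rid: str


def process_input_requests(recv_reqs):
    for recv_req in recv_reqs:
        if isinstance(recv_req, TokenizedGenerateReqInput):
            print(f"handle_generate_request({recv_req.rid})")
        elif isinstance(recv_req, TokenizedEmbeddingReqInput):
            print(f"handle_embedding_request({recv_req.rid})")
        else:
            raise ValueError(f"Invalid request: {recv_req}")


process_input_requests(
    [TokenizedGenerateReqInput("gen-1"), TokenizedEmbeddingReqInput("emb-1")]
)
```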
python/sglang/srt/managers/tokenizer_manager.py (+116, -279)

@@ -16,6 +16,7 @@ limitations under the License.
 """TokenizerManager is a process that tokenizes the text."""

 import asyncio
+import copy
 import dataclasses
 import json
 import logging
@@ -51,11 +52,8 @@ from sglang.srt.managers.io_struct import (
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
     ProfileReq,
-    RewardReqConv,
-    RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
@@ -157,7 +155,7 @@ class TokenizerManager:
     async def generate_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
@@ -172,122 +170,54 @@ class TokenizerManager:
                 "Please add `--is-embedding` when launching the server or try another model."
             )

-        obj.post_init()
+        obj.normalize_batch_and_arguments()
         is_single = obj.is_single

         if is_single:
-            async for response in self._handle_single_request(obj, request):
+            tokenized_obj = await self._tokenize_one_request(obj)
+            self.send_to_scheduler.send_pyobj(tokenized_obj)
+            async for response in self._wait_one_response(obj, request):
                 yield response
         else:
             async for response in self._handle_batch_request(obj, request):
                 yield response

-    async def _send_single_request(
+    async def _tokenize_one_request(
         self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
     ):
-        if not is_cache_for_prefill:
-            # The normal case with a single prompt
-            if index is None:
-                rid = obj.rid
-                if isinstance(obj, RewardReqInput):
-                    input_text = self._apply_chat_template(obj.conv)
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = obj.text if obj.text is not None else None
-                    input_ids = obj.input_ids
-
-                sampling_params = self._get_sampling_params(obj.sampling_params)
-
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data, input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob
-                    logprob_start_len = obj.logprob_start_len
-                    top_logprobs_num = obj.top_logprobs_num
-            else:
-                rid = obj.rid[index]
-                if isinstance(obj, RewardReqInput):
-                    input_text = self._apply_chat_template(obj.conv[input_id_index])
-                    input_ids = self.tokenizer.encode(input_text)
-                elif obj.input_ids is None:
-                    input_text = obj.text[input_id_index]
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    input_text = (
-                        obj.text[input_id_index] if obj.text is not None else None
-                    )
-                    input_ids = obj.input_ids[input_id_index]
-
-                sampling_params = self._get_sampling_params(obj.sampling_params[index])
-
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data[index], input_text or input_ids, obj
-                    )
-                    if image_inputs and "input_ids" in image_inputs:
-                        input_ids = image_inputs["input_ids"]
-                    return_logprob = obj.return_logprob[index]
-                    logprob_start_len = obj.logprob_start_len[index]
-                    top_logprobs_num = obj.top_logprobs_num[index]
-
-            self._validate_input_length(input_ids)
-        else:
-            # A prefill request to cache the common prompt for parallel sampling
-            assert self.is_generation
-            if obj.text is not None:
-                if isinstance(obj.text, list):
-                    input_text = obj.text[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_text = obj.text
-                    rid = obj.rid[0]
-                if self.tokenizer is not None:
-                    input_ids = self.tokenizer.encode(input_text)
-                else:
-                    assert obj.input_ids is not None
-                    input_ids = obj.input_ids
-                    if isinstance(obj.input_ids, list) and isinstance(
-                        obj.input_ids[0], list
-                    ):
-                        # when obj["input_ids"] is List[List[int]]
-                        input_ids = obj.input_ids[input_id_index]
-                        rid = obj.rid[index]
-                    else:
-                        input_ids = obj.input_ids
-                        rid = obj.rid[0]
-            else:
-                input_text = None
-                if isinstance(obj.input_ids, list) and isinstance(
-                    obj.input_ids[0], list
-                ):
-                    # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[input_id_index]
-                    rid = obj.rid[index]
-                else:
-                    input_ids = obj.input_ids
-                    rid = obj.rid[0]
-
-            sampling_params = SamplingParams(**obj.sampling_params[0])
-            sampling_params.max_new_tokens = 0
-            image_inputs = await self.image_processor.process_images_async(
-                obj.image_data[0], input_text or input_ids, obj
-            )
-            if image_inputs and "input_ids" in image_inputs:
-                input_ids = image_inputs["input_ids"]
-            return_logprob = obj.return_logprob[0]
-            logprob_start_len = obj.logprob_start_len[0]
-            top_logprobs_num = obj.top_logprobs_num[0]
-
-        # Send to the controller
-        if self.is_generation:
+        """Tokenize one request."""
+        # Tokenize
+        input_text = obj.text
+        if obj.input_ids is None:
+            input_ids = self.tokenizer.encode(input_text)
+        else:
+            input_ids = obj.input_ids
+
+        if self.is_generation:
+            image_inputs = await self.image_processor.process_images_async(
+                obj.image_data, input_text or input_ids, obj
+            )
+            if image_inputs and "input_ids" in image_inputs:
+                input_ids = image_inputs["input_ids"]
+            return_logprob = obj.return_logprob
+            logprob_start_len = obj.logprob_start_len
+            top_logprobs_num = obj.top_logprobs_num
+
+        if len(input_ids) >= self.context_len:
+            raise ValueError(
+                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"model's context length ({self.context_len} tokens)."
+            )
+
+        # Parse sampling parameters
+        sampling_params = SamplingParams(**obj.sampling_params)
+        sampling_params.normalize(self.tokenizer)
+        sampling_params.verify()
+
+        # Build return object
+        if isinstance(obj, GenerateReqInput):
             tokenized_obj = TokenizedGenerateReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 image_inputs,
@@ -296,219 +226,126 @@ class TokenizerManager:
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-                (
-                    obj.lora_path[input_id_index]
-                    if isinstance(obj.lora_path, list)
-                    else obj.lora_path
-                ),
+                obj.lora_path,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
-                rid,
+                obj.rid,
                 input_text,
                 input_ids,
                 sampling_params,
             )
-        else:
-            assert isinstance(obj, RewardReqInput)
-            tokenized_obj = TokenizedRewardReqInput(
-                rid,
-                input_text,
-                input_ids,
-                sampling_params,
-            )
-        self.send_to_scheduler.send_pyobj(tokenized_obj)
-        return rid, input_ids
+
+        return tokenized_obj

-    async def _handle_single_request(
-        self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        input_id_index: Optional[int] = None,
-        is_cache_for_prefill: Optional[bool] = False,
-    ):
-        rid, input_ids = await self._send_single_request(
-            obj,
-            index,
-            input_id_index=input_id_index,
-            is_cache_for_prefill=is_cache_for_prefill,
-        )
-
-        # Recv results
-        event = asyncio.Event()
-        state = ReqState([], False, event)
-        self.rid_to_state[rid] = state
-
-        if not is_cache_for_prefill:
-            async for response in self._wait_for_response(state, obj, rid, request):
-                yield response
-        else:
-            await state.event.wait()
-            assert state.finished
-            del self.rid_to_state[rid]
-            yield input_ids
-
-    async def _handle_batch_request(
-        self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        request: Optional[fastapi.Request] = None,
-    ):
-        batch_size = obj.batch_size
-        if self.is_generation:
-            parallel_sample_num = obj.parallel_sample_num
-
-            if parallel_sample_num != 1:
-                # Send prefill requests to cache the common prefix
-                parallel_sample_num += 1
-                input_id_result = [] if obj.input_ids is None else None
-                for i in range(batch_size):
-                    async for input_id in self._handle_single_request(
-                        obj,
-                        request,
-                        index=i,
-                        input_id_index=i,
-                        is_cache_for_prefill=True,
-                    ):
-                        if input_id_result is not None:
-                            input_id_result.append(input_id)
-                if input_id_result is not None:
-                    obj.input_ids = input_id_result
-        else:
-            parallel_sample_num = 1
-
-        # First send out all requests
-        generators = []
-        for i in range(batch_size):
-            for j in range(parallel_sample_num):
-                if j == 0 and parallel_sample_num != 1:
-                    continue
-                index = i * parallel_sample_num + j
-                if parallel_sample_num != 1:
-                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
-                    index += batch_size - 1 - i
-
-                rid, _ = await self._send_single_request(
-                    obj, index, input_id_index=i, is_cache_for_prefill=False
-                )
-
-                event = asyncio.Event()
-                state = ReqState([], False, event)
-                self.rid_to_state[rid] = state
-
-                generators.append(
-                    self._wait_for_response(
-                        state,
-                        obj,
-                        rid,
-                        request,
-                        index=index,
-                        response_index=len(generators),
-                    )
-                )
-
-        # Then process the responses based on streaming option
-        is_stream = hasattr(obj, "stream") and obj.stream
-
-        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
-        output_list = [None] * len(tasks)
-
-        # Fetch results
-        while tasks:
-            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-            for task in done:
-                cur_index = tasks.index(task)
-                try:
-                    result = task.result()
-                    if is_stream:
-                        yield result
-                    else:
-                        output_list[result["index"]] = result
-                    tasks[cur_index] = asyncio.create_task(
-                        generators[cur_index].__anext__()
-                    )
-                except StopAsyncIteration:
-                    del generators[cur_index]
-                    del tasks[cur_index]
-
-        if not is_stream:
-            yield output_list
-
-    def _validate_input_length(self, input_ids: List[int]):
-        if len(input_ids) >= self.context_len:
-            raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
-            )
-
-    def _get_sampling_params(self, sampling_params_data: dict):
-        sampling_params = SamplingParams(**sampling_params_data)
-        if sampling_params.max_new_tokens != 0:
-            sampling_params.normalize(self.tokenizer)
-            sampling_params.verify()
-        return sampling_params
-
-    def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]:
-        if isinstance(conv, str):
-            return conv
-        elif isinstance(conv, list):
-            if isinstance(conv[0], str):
-                return conv
-            else:
-                return self.tokenizer.apply_chat_template(conv, tokenize=False)
-
-    async def _wait_for_response(
-        self,
-        state: ReqState,
-        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        rid: str,
-        request: Optional[fastapi.Request] = None,
-        index: Optional[int] = None,
-        response_index: int = 0,
-    ):
-        while True:
-            try:
-                await asyncio.wait_for(state.event.wait(), timeout=4)
-            except asyncio.TimeoutError:
-                if request is not None and await request.is_disconnected():
-                    for rid in [obj.rid] if obj.is_single else obj.rid:
-                        self.abort_request(rid)
-                    raise ValueError(f"Abort request {rid}")
-                continue
-
-            if self.is_generation:
-                out = self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob if index is None else obj.return_logprob[index],
-                    (
-                        obj.top_logprobs_num
-                        if index is None
-                        else obj.top_logprobs_num[index]
-                    ),
-                    obj.return_text_in_logprobs,
-                )
-            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
-                out = state.out_list[-1]
-
-            out["index"] = response_index
-
-            state.out_list = []
-            if state.finished:
-                # Log requests
-                if self.server_args.log_requests:
-                    logger.info(f"in={obj}, out={out}")
-                del self.rid_to_state[rid]
-                yield out
-                break
-
-            state.event.clear()
-            yield out
+    async def _wait_one_response(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
+    ):
+        """Wait for the response of one request."""
+        event = asyncio.Event()
+        state = ReqState([], False, event)
+        self.rid_to_state[obj.rid] = state
+
+        while True:
+            try:
+                await asyncio.wait_for(state.event.wait(), timeout=4)
+            except asyncio.TimeoutError:
+                if request is not None and await request.is_disconnected():
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")
+                continue
+
+            if isinstance(obj, GenerateReqInput):
+                out = self.convert_logprob_style(
+                    state.out_list[-1],
+                    obj.return_logprob,
+                    obj.top_logprobs_num,
+                    obj.return_text_in_logprobs,
+                )
+            else:  # isinstance(obj, (EmbeddingReqInput,))
+                out = state.out_list[-1]
+
+            state.out_list = []
+            if state.finished:
+                if self.server_args.log_requests:
+                    # Log requests
+                    logger.info(f"in={obj}, out={out}")
+                del self.rid_to_state[obj.rid]
+                yield out
+                break
+
+            state.event.clear()
+            yield out
+
+    async def _handle_batch_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
+    ):
+        batch_size = obj.batch_size
+
+        generators = []
+        rids = []
+        if getattr(obj, "parallel_sample_num", 1) == 1:
+            # Send all requests
+            for i in range(batch_size):
+                tmp_obj = obj[i]
+                tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                generators.append(self._wait_one_response(tmp_obj, request))
+                rids.append(tmp_obj.rid)
+        else:
+            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+
+            # Tokenize all requests
+            objs = [obj[i] for i in range(batch_size)]
+            tokenized_objs = await asyncio.gather(
+                *(self._tokenize_one_request(obj) for obj in objs)
+            )
+
+            # Cache the common prefix for parallel sampling
+            for i in range(batch_size):
+                tmp_obj = copy.copy(objs[i])
+                tokenized_obj = copy.copy(tokenized_objs[i])
+                tokenized_obj.rid = tmp_obj.regenerate_rid()
+                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
+                tokenized_obj.sampling_params.max_new_tokens = 0
+                tokenized_obj.stream = False
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                await self._wait_one_response(tmp_obj, request).__anext__()
+
+            # Expand requests, assign new rids for them, and send them
+            for i in range(batch_size):
+                for _ in range(obj.parallel_sample_num):
+                    tmp_obj = copy.copy(objs[i])
+                    tokenized_obj = copy.copy(tokenized_objs[i])
+                    tokenized_obj.rid = tmp_obj.regenerate_rid()
+                    self.send_to_scheduler.send_pyobj(tokenized_obj)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
+
+        # Wait for all requests
+        is_stream = hasattr(obj, "stream") and obj.stream
+        if not is_stream:
+            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
+            yield outputs
+        else:
+            rid_to_index = {rid: i for i, rid in enumerate(rids)}
+            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
+            while task_map:
+                done, _ = await asyncio.wait(
+                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+                )
+                for task in done:
+                    gen = task_map.pop(task)
+                    try:
+                        result = task.result()
+                        result["index"] = rid_to_index[result["meta_info"]["id"]]
+                        yield result
+                        new_task = asyncio.create_task(gen.__anext__())
+                        task_map[new_task] = gen
+                    except StopAsyncIteration:
+                        pass

     def flush_cache(self):
         req = FlushCacheReq()
         self.send_to_scheduler.send_pyobj(req)
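The new `_handle_batch_request` path for parallel sampling reduces to: tokenize each prompt once, send one copy with `max_new_tokens = 0` so the scheduler caches the shared prefix, then send `n` copies with freshly generated request ids and collect their responses. A self-contained sketch of that send pattern (a toy request class and an in-memory `sent` list standing in for `send_to_scheduler.send_pyobj`; an illustration of the idea, not the real implementation):

```python
import copy
import uuid
from dataclasses import dataclass


@dataclass
class FakeTokenizedReq:
    """Stand-in for TokenizedGenerateReqInput; only the fields used here."""

    rid: str
    input_ids: list
    max_new_tokens: int = 16
    stream: bool = False


sent = []  # stand-in for send_to_scheduler.send_pyobj


def handle_parallel_sampling(tokenized: FakeTokenizedReq, n: int):
    # 1) Send a copy with max_new_tokens=0 so the shared prefix gets cached.
    prefix_req = copy.copy(tokenized)
    prefix_req.rid = uuid.uuid4().hex
    prefix_req.max_new_tokens = 0
    prefix_req.stream = False
    sent.append(prefix_req)

    # 2) Expand into n requests, each with a freshly generated rid.
    rids = []
    for _ in range(n):
        expanded = copy.copy(tokenized)
        expanded.rid = uuid.uuid4().hex
        sent.append(expanded)
        rids.append(expanded.rid)
    return rids


rids = handle_parallel_sampling(FakeTokenizedReq("seed", [1, 2, 3]), n=3)
assert len(sent) == 4 and len(set(rids)) == 3
```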
python/sglang/srt/openai_api/adapter.py (+52, -66)

@@ -71,6 +71,7 @@ from sglang.srt.openai_api.protocol import (
     TopLogprob,
     UsageInfo,
 )
+from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -314,6 +315,8 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
                 )
             except Exception as e:
+                logger.error(f"error: {get_exception_traceback()}")
                 responses = []

         error_json = {
             "id": f"batch_req_{uuid.uuid4()}",
             "custom_id": request_data.get("custom_id"),
@@ -363,7 +366,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
         }
     except Exception as e:
-        logger.error("error in SGLang:", e)
+        logger.error(f"error: {e}")
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
@@ -469,80 +472,67 @@ async def v1_retrieve_file_content(file_id: str):
 def v1_generate_request(
     all_requests: List[CompletionRequest], request_ids: List[str] = None
 ):
+    if len(all_requests) > 1:
+        first_prompt_type = type(all_requests[0].prompt)
+        for request in all_requests:
+            assert (
+                type(request.prompt) is first_prompt_type
+            ), "All prompts must be of the same type in file input settings"
+            if request.n > 1:
+                raise ValueError(
+                    "Parallel sampling is not supported for completions from files"
+                )
+
     prompts = []
     sampling_params_list = []
     return_logprobs = []
     logprob_start_lens = []
     top_logprobs_nums = []

-    # NOTE: with openai API, the prompt's logprobs are always not computed
-    first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
-        assert (
-            type(request.prompt) is first_prompt_type
-        ), "All prompts must be of the same type in file input settings"
-        if len(all_requests) > 1 and request.n > 1:
-            raise ValueError(
-                "Parallel sampling is not supported for completions from files"
-            )
+        # NOTE: with openai API, the prompt's logprobs are always not computed
         if request.echo and request.logprobs:
             logger.warning(
                 "Echo is not compatible with logprobs. "
-                "To compute logprobs of input prompt, please use SGLang /request API."
+                "To compute logprobs of input prompt, please use the native /generate API."
             )

-    for request in all_requests:
         prompts.append(request.prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
+                "stop": request.stop,
+                "stop_token_ids": request.stop_token_ids,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
+                "regex": request.regex,
+                "json_schema": request.json_schema,
+                "n": request.n,
+                "ignore_eos": request.ignore_eos,
+                "no_stop_trim": request.no_stop_trim,
+            }
+        )
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
         logprob_start_lens.append(-1)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
-        sampling_params = []
-        if isinstance(request.no_stop_trim, list):
-            num_reqs = len(request.prompt)
-        else:
-            num_reqs = 1
-        for i in range(num_reqs):
-            sampling_params.append(
-                {
-                    "temperature": request.temperature,
-                    "max_new_tokens": request.max_tokens,
-                    "min_new_tokens": request.min_tokens,
-                    "stop": request.stop,
-                    "stop_token_ids": request.stop_token_ids,
-                    "top_p": request.top_p,
-                    "presence_penalty": request.presence_penalty,
-                    "frequency_penalty": request.frequency_penalty,
-                    "repetition_penalty": request.repetition_penalty,
-                    "regex": request.regex,
-                    "json_schema": request.json_schema,
-                    "n": request.n,
-                    "ignore_eos": request.ignore_eos,
-                    "no_stop_trim": (
-                        request.no_stop_trim
-                        if not isinstance(request.no_stop_trim, list)
-                        else request.no_stop_trim[i]
-                    ),
-                }
-            )
-        if num_reqs == 1:
-            sampling_params_list.append(sampling_params[0])
-        else:
-            sampling_params_list.append(sampling_params)

     if len(all_requests) == 1:
-        prompt = prompts[0]
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
+            prompt_kwargs = {"text": prompts[0]}
+        else:
+            prompt_kwargs = {"input_ids": prompts[0]}
         sampling_params_list = sampling_params_list[0]
-        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
-        if isinstance(prompt, str) or isinstance(prompt[0], str):
-            prompt_kwargs = {"text": prompt}
-        else:
-            prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str):
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
@@ -558,9 +548,7 @@ def v1_generate_request(
         rid=request_ids,
     )

-    if len(all_requests) == 1:
-        return adapted_request, all_requests[0]
-    return adapted_request, all_requests
+    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


 def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
@@ -595,7 +583,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
         if isinstance(request, list) and request[idx].echo:
             echo = True
             text = request[idx].prompt + text
-        if (not isinstance(request, list)) and echo:
+        if echo and not isinstance(request, list):
             prompt_index = idx // request.n
             text = prompts[prompt_index] + text
@@ -709,7 +697,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             async for content in tokenizer_manager.generate_request(
                 adapted_request, raw_request
             ):
-                index = content["index"]
+                index = content.get("index", 0)

                 stream_buffer = stream_buffers.get(index, "")
                 n_prev_token = n_prev_tokens.get(index, 0)
@@ -945,19 +933,18 @@ def v1_chat_generate_request(
         sampling_params_list.append(sampling_params)

         image_data_list.append(image_data)
-        modalities_list.extend(modalities)
+        modalities_list.append(modalities)
     if len(all_requests) == 1:
-        input_ids = input_ids[0]
-        if isinstance(input_ids, str):
-            prompt_kwargs = {"text": input_ids}
+        if isinstance(input_ids[0], str):
+            prompt_kwargs = {"text": input_ids[0]}
         else:
-            prompt_kwargs = {"input_ids": input_ids}
+            prompt_kwargs = {"input_ids": input_ids[0]}
         sampling_params_list = sampling_params_list[0]
         image_data_list = image_data_list[0]
         return_logprobs = return_logprobs[0]
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
-        modalities_list = modalities_list[:1]
+        modalities_list = modalities_list[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
@@ -976,9 +963,8 @@ def v1_chat_generate_request(
         rid=request_ids,
         modalities=modalities_list,
     )

-    if len(all_requests) == 1:
-        return adapted_request, all_requests[0]
-    return adapted_request, all_requests
+    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


 def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
@@ -1116,7 +1102,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             async for content in tokenizer_manager.generate_request(
                 adapted_request, raw_request
             ):
-                index = content["index"]
+                index = content.get("index", 0)

                 is_first = is_firsts.get(index, True)
                 stream_buffer = stream_buffers.get(index, "")
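On the OpenAI-compatible side, the user-visible effects are that streamed chunks may omit `index` (the adapter now falls back to 0) and that `n > 1` completions go through the new batch path in the tokenizer manager. A hedged usage sketch with the `openai` client, assuming a local sglang server has already been launched on port 30000 (the model path, port, and `model="default"` are placeholders for your own setup):

```python
# Assumes: python3 -m sglang.launch_server --model-path <your-model> --port 30000
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    max_tokens=16,
    temperature=0.5,
    n=3,  # parallel sampling: handled by the new batch path
)
for choice in response.choices:
    print(choice.index, choice.text)
```

Note that, per the warning in the diff, prompt logprobs with `echo` are not computed through this API; the native /generate endpoint is the suggested route for that.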
python/sglang/srt/server.py (+2, -3)

@@ -53,7 +53,6 @@ from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
-    RewardReqInput,
     UpdateWeightReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
@@ -91,7 +90,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
-tokenizer_manager = None
+tokenizer_manager: TokenizerManager = None

 app.add_middleware(
     CORSMiddleware,
@@ -254,7 +253,7 @@ app.post("/encode")(encode_request)
 app.put("/encode")(encode_request)


-async def judge_request(obj: RewardReqInput, request: Request):
+async def judge_request(obj: EmbeddingReqInput, request: Request):
     """Handle a reward model request."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
test/srt/run_suite.py (+1, -1)

@@ -8,7 +8,7 @@ suites = {
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
         "models/test_lora.py",
-        "models/test_reward_models.py",
+        # "models/test_reward_models.py",
         "sampling/penaltylib",
         "test_chunked_prefill.py",
         "test_double_sparsity.py",
test/srt/test_openai_server.py (+5, -0)

+"""
+python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
+python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
+"""
+
 import json
 import time
 import unittest
test/srt/test_skip_tokenizer_init.py (+3, -0)

+"""
+python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
+"""
 import json
 import unittest
test/srt/test_srt_endpoint.py (+11, -1)

 """
 python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
 python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_parallel_sample
 """
 import json
@@ -36,11 +37,17 @@ class TestSRTEndpoint(unittest.TestCase):
         return_text=False,
         n=1,
         stream=False,
+        batch=False,
     ):
+        if batch:
+            text = ["The capital of France is"]
+        else:
+            text = "The capital of France is"
+
         response = requests.post(
             self.base_url + "/generate",
             json={
-                "text": "The capital of France is",
+                "text": text,
                 "sampling_params": {
                     "temperature": 0 if n == 1 else 0.5,
                     "max_new_tokens": 16,
@@ -67,6 +74,9 @@ class TestSRTEndpoint(unittest.TestCase):
     def test_simple_decode(self):
         self.run_decode()

+    def test_simple_decode_batch(self):
+        self.run_decode(batch=True)
+
     def test_parallel_sample(self):
         self.run_decode(n=3)
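The new `batch=True` case sends the same prompt as a one-element list to the native /generate endpoint. The equivalent request can be sent by hand; this sketch assumes a server is already running locally on port 30000:

```python
import requests

response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": ["The capital of France is"],  # a list makes it a batch request
        "sampling_params": {
            "temperature": 0.5,
            "max_new_tokens": 16,
            "n": 3,  # three samples per prompt, via the new parallel-sampling path
        },
    },
)
print(response.json())
```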
test/srt/test_vision_openai_server.py (+1, -0)

 """
 Usage:
 python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch
 python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_multi_images_chat_completion
 """
 import base64