change / sglang · Commits · 838dcda1

Commit 838dcda1 (unverified)
Authored Nov 03, 2024 by Lianmin Zheng; committed via GitHub on Nov 03, 2024

Simplify tokenizer manager (#1899)

parent efbc116a

Showing 3 changed files with 50 additions and 49 deletions (+50 / -49):

    docs/references/custom_chat_template.md            +11 / -2
    python/sglang/srt/managers/io_struct.py             +12 / -4
    python/sglang/srt/managers/tokenizer_manager.py     +27 / -43
docs/references/custom_chat_template.md  (+11 / -2)

@@ -11,8 +11,10 @@ If needed, you can also override the chat template when launching the server:
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2
 ```
-If the chat template you are looking for is missing, you are welcome to contribute it.
-Meanwhile, you can also temporarily register your chat template as follows:
+If the chat template you are looking for is missing, you are welcome to contribute it or load it from a file.
+
+## JSON Format
+
+You can load the JSON format, which is defined by `conversation.py`.
 
 ```json
 {
@@ -29,3 +31,10 @@ Meanwhile, you can also temporarily register your chat template as follows:
 ```
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json
 ```
+
+## Jinja Format
+You can also use the Jinja template format, defined by Hugging Face transformers https://huggingface.co/docs/transformers/main/en/chat_templating
+
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.jinja
+```
\ No newline at end of file
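
For reference, a minimal sketch (not part of this commit) of how a Jinja chat template file like the one the new doc section points to can be exercised directly with Hugging Face transformers; the model name, file path, and messages are placeholders:

```python
# Minimal sketch, not part of the commit: render a custom Jinja chat template
# with Hugging Face transformers. Model name, path, and messages are placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# Override the tokenizer's built-in template with the contents of a .jinja file.
with open("./my_model_template.jinja") as f:
    tokenizer.chat_template = f.read()

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
]
# tokenize=False returns the rendered prompt string instead of token ids.
print(tokenizer.apply_chat_template(messages, tokenize=False))
```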
python/sglang/srt/managers/io_struct.py  (+12 / -4)

@@ -114,8 +114,7 @@ class GenerateReqInput:
         if self.parallel_sample_num == 1:
             num = self.batch_size
         else:
-            # FIXME support cascade inference
-            # first bs samples are used for caching the prefix for parallel sampling
+            # The first bs samples are used for caching the prefix for parallel sampling
             num = self.batch_size + self.parallel_sample_num * self.batch_size
 
         if self.image_data is None:
@@ -196,6 +195,9 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
+    # Whether it is a single request or a batch request
+    is_single: bool = True
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -241,15 +243,21 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams
 
 
+RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]
+
+
 @dataclass
 class RewardReqInput:
-    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
-    conv: Union[List[List[Dict]], List[Dict]]
+    # The input prompt. It can be a single prompt or a batch of prompts.
+    # Can be either chat format or a string.
+    conv: RewardReqConv
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
+    # Whether it is a single request or a batch request
+    is_single: bool = True
 
     def post_init(self):
         self.is_single = isinstance(self.conv[0], dict)
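
The widened `RewardReqConv` alias means `RewardReqInput.conv` now accepts pre-rendered strings as well as chat-format conversations. As an illustration only (not part of the diff), the shapes it admits look like this; the prompt contents are placeholders:

```python
# Illustration only (not part of the commit): the input shapes admitted by the
# new RewardReqConv alias. Prompt contents are placeholders.
from sglang.srt.managers.io_struct import RewardReqInput

# List[Dict]: a single conversation in chat format.
single_conv = RewardReqInput(conv=[{"role": "user", "content": "Is this answer helpful?"}])

# List[List[Dict]]: a batch of chat-format conversations.
batch_conv = RewardReqInput(
    conv=[
        [{"role": "user", "content": "Question A"}],
        [{"role": "user", "content": "Question B"}],
    ]
)

# str / List[str]: prompts that are already rendered to plain text, which the
# tokenizer manager now passes through unchanged (see _apply_chat_template below).
single_str = RewardReqInput(conv="Already-rendered prompt text")
batch_str = RewardReqInput(conv=["Prompt 1", "Prompt 2"])
```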
python/sglang/srt/managers/tokenizer_manager.py  (+27 / -43)

@@ -51,6 +51,7 @@ from sglang.srt.managers.io_struct import (
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
     ProfileReq,
+    RewardReqConv,
     RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -89,6 +90,7 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Parse args
         self.server_args = server_args
 
         # Init inter-process communication
@@ -114,6 +116,7 @@ class TokenizerManager:
         self.context_len = server_args.context_length or get_context_length(
             self.hf_config
         )
+
         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
@@ -165,7 +168,8 @@ class TokenizerManager:
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
-                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
+                "This model does not appear to be an embedding model by default. "
+                "Please add `--is-embedding` when launching the server or try another model."
             )
 
         obj.post_init()
@@ -187,12 +191,8 @@ class TokenizerManager:
         if not is_cache_for_prefill:  # The normal case with a single prompt
             if index is None:
                 rid = obj.rid
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv
-                    input_text = self.tokenizer.apply_chat_template(
-                        conv, tokenize=False
-                    )
+                if isinstance(obj, RewardReqInput):
+                    input_text = self._apply_chat_template(obj.conv)
                     input_ids = self.tokenizer.encode(input_text)
                 elif obj.input_ids is None:
                     input_text = obj.text
@@ -213,12 +213,8 @@ class TokenizerManager:
                 top_logprobs_num = obj.top_logprobs_num
             else:
                 rid = obj.rid[index]
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv[index]
-                    input_text = self.tokenizer.apply_chat_template(
-                        conv, tokenize=False
-                    )
+                if isinstance(obj, RewardReqInput):
+                    input_text = self._apply_chat_template(obj.conv[input_id_index])
                     input_ids = self.tokenizer.encode(input_text)
                 elif obj.input_ids is None:
                     input_text = obj.text[input_id_index]
@@ -349,8 +345,9 @@ class TokenizerManager:
             async for response in self._wait_for_response(state, obj, rid, request):
                 yield response
         else:
-            assert self.is_generation
-            await self._wait_for_cache_prefill_response(state, obj, rid, request)
+            await state.event.wait()
+            assert state.finished
+            del self.rid_to_state[rid]
             yield input_ids
 
     async def _handle_batch_request(
@@ -456,6 +453,15 @@ class TokenizerManager:
             sampling_params.verify()
         return sampling_params
 
+    def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]:
+        if isinstance(conv, str):
+            return conv
+        elif isinstance(conv, list):
+            if isinstance(conv[0], str):
+                return conv
+            else:
+                return self.tokenizer.apply_chat_template(conv, tokenize=False)
+
     async def _wait_for_response(
         self,
         state: ReqState,
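
The new helper centralizes how reward-model inputs are rendered: strings and lists of strings pass through unchanged, while chat-format conversations are rendered with the tokenizer's chat template. A standalone sketch of the same dispatch, for illustration only (the model name and inputs are placeholders):

```python
# Illustration only, not part of the diff: the same dispatch as
# TokenizerManager._apply_chat_template, written against a Hugging Face
# tokenizer directly. Model name and inputs are placeholders.
from typing import List, Union

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")


def apply_chat_template(conv) -> Union[str, List[str]]:
    if isinstance(conv, str):
        return conv  # already a rendered prompt
    elif isinstance(conv, list):
        if isinstance(conv[0], str):
            return conv  # a batch of rendered prompts
        else:
            # chat-format messages -> rendered prompt string(s)
            return tokenizer.apply_chat_template(conv, tokenize=False)


print(apply_chat_template("already rendered prompt"))            # str, unchanged
print(apply_chat_template(["prompt 1", "prompt 2"]))             # List[str], unchanged
print(apply_chat_template([{"role": "user", "content": "Hi"}]))  # List[Dict], rendered
```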
@@ -491,12 +497,11 @@ class TokenizerManager:
             out["index"] = response_index
 
-            # Log requests
-            if self.server_args.log_requests and state.finished:
-                logger.info(f"in={obj}, out={out}")
-
             state.out_list = []
             if state.finished:
+                # Log requests
+                if self.server_args.log_requests:
+                    logger.info(f"in={obj}, out={out}")
+
                 del self.rid_to_state[rid]
                 yield out
                 break
@@ -504,27 +509,6 @@ class TokenizerManager:
             state.event.clear()
             yield out
 
-    async def _wait_for_cache_prefill_response(
-        self,
-        state: ReqState,
-        obj: GenerateReqInput,
-        rid: str,
-        request: Optional[fastapi.Request] = None,
-    ):
-        while True:
-            try:
-                await asyncio.wait_for(state.event.wait(), timeout=4)
-                break
-            except asyncio.TimeoutError:
-                if request is not None and await request.is_disconnected():
-                    for rid in obj.rid:
-                        self.abort_request(rid)
-                    raise ValueError(f"Abort request {rid}")
-                continue
-
-        assert state.finished
-        del self.rid_to_state[rid]
-
     def flush_cache(self):
         req = FlushCacheReq()
         self.send_to_scheduler.send_pyobj(req)
@@ -553,6 +537,7 @@ class TokenizerManager:
         self.send_to_scheduler.send_pyobj(req)
         self.mem_pool_size = asyncio.Future()
 
+        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
         if self.server_args.dp_size == 1:
             res = await self.mem_pool_size
             return res.size
@@ -638,7 +623,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)
             logger.info(
-                f"gracefully exiting... remaining number of requests {remain_num_req}"
+                f"Gracefully exiting... remaining number of requests {remain_num_req}"
             )
             if remain_num_req > 0:
                 await asyncio.sleep(5)
@@ -695,7 +680,6 @@ class TokenizerManager:
                     "token_ids": recv_obj.output_ids[i],
                     "meta_info": recv_obj.meta_info[i],
                 }
-
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
...
@@ -747,7 +731,7 @@ class TokenizerManager:
token_texts
=
self
.
tokenizer
.
batch_decode
(
token_ids
)
token_texts
=
self
.
tokenizer
.
batch_decode
(
token_ids
)
return
[
return
[
(
logprob
,
token_id
,
token_text
)
(
logprob
,
token_id
,
token_text
)
for
(
logprob
,
token_id
),
token_text
,
in
zip
(
token_logprobs
,
token_texts
)
for
(
logprob
,
token_id
),
token_text
in
zip
(
token_logprobs
,
token_texts
)
]
]
def
detokenize_top_logprobs_tokens
(
self
,
top_logprobs
,
decode_to_text
:
bool
):
def
detokenize_top_logprobs_tokens
(
self
,
top_logprobs
,
decode_to_text
:
bool
):
...
...