sglang / Commits / 838dcda1

Simplify tokenizer manager (#1899)

Unverified commit 838dcda1, authored Nov 03, 2024 by Lianmin Zheng; committed by GitHub on Nov 03, 2024. Parent: efbc116a.
Showing 3 changed files with 50 additions and 49 deletions (+50, -49).
docs/references/custom_chat_template.md          +11   -2
python/sglang/srt/managers/io_struct.py          +12   -4
python/sglang/srt/managers/tokenizer_manager.py  +27   -43
docs/references/custom_chat_template.md

@@ -11,8 +11,10 @@ If needed, you can also override the chat template when launching the server:
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2
 ```
-If the chat template you are looking for is missing, you are welcome to contribute it.
-Meanwhile, you can also temporarily register your chat template as follows:
+If the chat template you are looking for is missing, you are welcome to contribute it or load it from a file.
+
+## JSON Format
+You can load the JSON format, which is defined by `conversation.py`.
 ```json
 {
@@ -28,4 +30,11 @@ Meanwhile, you can also temporarily register your chat template as follows:
 ```
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json
 ```
+
+## Jinja Format
+
+You can also use the Jinja template format, defined by Hugging Face transformers https://huggingface.co/docs/transformers/main/en/chat_templating
+
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.jinja
+```
\ No newline at end of file
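For reference, a JSON chat template passed with `--chat-template ./my_model_template.json` is a single JSON object in the style of `conversation.py`. The sketch below is illustrative only; the field names and separators are assumptions modeled on a ChatML-style template, not the exact example shipped in the docs.

```json
{
  "name": "my_model",
  "system": "<|im_start|>system",
  "user": "<|im_start|>user",
  "assistant": "<|im_start|>assistant",
  "sep_style": "CHATML",
  "sep": "<|im_end|>\n",
  "stop_str": ["<|im_end|>"]
}
```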
python/sglang/srt/managers/io_struct.py

@@ -114,8 +114,7 @@ class GenerateReqInput:
         if self.parallel_sample_num == 1:
             num = self.batch_size
         else:
-            # FIXME support cascade inference
-            # first bs samples are used for caching the prefix for parallel sampling
+            # The first bs samples are used for caching the prefix for parallel sampling
             num = self.batch_size + self.parallel_sample_num * self.batch_size

         if self.image_data is None:
@@ -196,6 +195,9 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
     # Whether it is a single request or a batch request
     is_single: bool = True

     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -241,15 +243,21 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams


+RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]


 @dataclass
 class RewardReqInput:
-    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
-    conv: Union[List[List[Dict]], List[Dict]]
+    # The input prompt. It can be a single prompt or a batch of prompts.
+    # Can be either chat format or a string.
+    conv: RewardReqConv
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
     # Whether it is a single request or a batch request
     is_single: bool = True

     def post_init(self):
         self.is_single = isinstance(self.conv[0], dict)
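The widened `RewardReqConv` alias documents the four shapes a reward request may carry in `RewardReqInput.conv`. A minimal sketch of the accepted forms, with made-up message contents and the standard `role`/`content` chat dict layout assumed:

```python
from typing import Dict, List

# Each value below matches one member of RewardReqConv.

# A single conversation in chat format (List[Dict]).
single_chat: List[Dict] = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]

# A batch of conversations in chat format (List[List[Dict]]).
batch_chat: List[List[Dict]] = [single_chat, single_chat]

# A single already-formatted prompt (str).
single_str: str = "User: What is the capital of France?\nAssistant: Paris."

# A batch of already-formatted prompts (List[str]).
batch_str: List[str] = [single_str, single_str]
```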
python/sglang/srt/managers/tokenizer_manager.py

@@ -51,6 +51,7 @@ from sglang.srt.managers.io_struct import (
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
     ProfileReq,
+    RewardReqConv,
     RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -89,6 +90,7 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Parse args
         self.server_args = server_args

         # Init inter-process communication
@@ -114,6 +116,7 @@ class TokenizerManager:
         self.context_len = server_args.context_length or get_context_length(
             self.hf_config
         )

+        # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
@@ -165,7 +168,8 @@ class TokenizerManager:
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
-                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
+                "This model does not appear to be an embedding model by default. "
+                "Please add `--is-embedding` when launching the server or try another model."
             )

         obj.post_init()
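As the reworded error message suggests, serving an embedding model requires passing `--is-embedding` at launch. A hypothetical invocation (the model path is a placeholder):

```
python -m sglang.launch_server --model-path <your-embedding-model> --is-embedding --port 30000
```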
@@ -187,12 +191,8 @@
         if not is_cache_for_prefill:  # The normal case with a single prompt
             if index is None:
                 rid = obj.rid
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv
-                    input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
+                if isinstance(obj, RewardReqInput):
+                    input_text = self._apply_chat_template(obj.conv)
                     input_ids = self.tokenizer.encode(input_text)
                 elif obj.input_ids is None:
                     input_text = obj.text
@@ -213,12 +213,8 @@
                 top_logprobs_num = obj.top_logprobs_num
             else:
                 rid = obj.rid[index]
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv[index]
-                    input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
+                if isinstance(obj, RewardReqInput):
+                    input_text = self._apply_chat_template(obj.conv[input_id_index])
                     input_ids = self.tokenizer.encode(input_text)
                 elif obj.input_ids is None:
                     input_text = obj.text[input_id_index]
@@ -349,8 +345,9 @@ class TokenizerManager:
             async for response in self._wait_for_response(state, obj, rid, request):
                 yield response
         else:
             assert self.is_generation
-            await self._wait_for_cache_prefill_response(state, obj, rid, request)
+            await state.event.wait()
+            assert state.finished
+            del self.rid_to_state[rid]
             yield input_ids

     async def _handle_batch_request(
@@ -456,6 +453,15 @@ class TokenizerManager:
         sampling_params.verify()
         return sampling_params

+    def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]:
+        if isinstance(conv, str):
+            return conv
+        elif isinstance(conv, list):
+            if isinstance(conv[0], str):
+                return conv
+            else:
+                return self.tokenizer.apply_chat_template(conv, tokenize=False)
+
     async def _wait_for_response(
         self,
         state: ReqState,
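The new `_apply_chat_template` helper returns strings untouched and only invokes the Hugging Face tokenizer for chat-format input. A standalone sketch of that last branch using `transformers` directly; the model name is reused from the docs above and is assumed to ship a chat template:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

conv = [{"role": "user", "content": "Rate this answer: Paris is the capital of France."}]

# tokenize=False yields the formatted prompt string rather than token ids,
# which the tokenizer manager then encodes on its own.
prompt = tokenizer.apply_chat_template(conv, tokenize=False)
print(prompt)
```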
@@ -491,12 +497,11 @@
                     out["index"] = response_index

-                # Log requests
-                if self.server_args.log_requests and state.finished:
-                    logger.info(f"in={obj}, out={out}")
-
                 state.out_list = []
                 if state.finished:
+                    # Log requests
+                    if self.server_args.log_requests:
+                        logger.info(f"in={obj}, out={out}")
+
                     del self.rid_to_state[rid]
                     yield out
                     break
@@ -504,27 +509,6 @@
                 state.event.clear()
                 yield out

-    async def _wait_for_cache_prefill_response(
-        self,
-        state: ReqState,
-        obj: GenerateReqInput,
-        rid: str,
-        request: Optional[fastapi.Request] = None,
-    ):
-        while True:
-            try:
-                await asyncio.wait_for(state.event.wait(), timeout=4)
-                break
-            except asyncio.TimeoutError:
-                if request is not None and await request.is_disconnected():
-                    for rid in obj.rid:
-                        self.abort_request(rid)
-                    raise ValueError(f"Abort request {rid}")
-                continue
-
-        assert state.finished
-        del self.rid_to_state[rid]
-
     def flush_cache(self):
         req = FlushCacheReq()
         self.send_to_scheduler.send_pyobj(req)
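The deleted helper shows the waiting pattern used around `ReqState.event`: block on the event in short slices so a disconnected client can still be detected and aborted. A self-contained sketch of that pattern with generic names (not sglang's API):

```python
import asyncio


async def client_disconnected() -> bool:
    # Stand-in for fastapi.Request.is_disconnected(); always "connected" here.
    return False


async def wait_with_disconnect_check(event: asyncio.Event) -> None:
    # Wait for completion in 4-second slices, checking the client between slices.
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=4)
            return
        except asyncio.TimeoutError:
            if await client_disconnected():
                raise ValueError("Abort request: client disconnected")


async def demo() -> None:
    event = asyncio.Event()
    asyncio.get_running_loop().call_later(1, event.set)  # simulate the scheduler finishing
    await wait_with_disconnect_check(event)
    print("request finished")


asyncio.run(demo())
```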
@@ -553,6 +537,7 @@ class TokenizerManager:
         self.send_to_scheduler.send_pyobj(req)
         self.mem_pool_size = asyncio.Future()

+        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
         if self.server_args.dp_size == 1:
             res = await self.mem_pool_size
             return res.size
@@ -638,7 +623,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)
             logger.info(
-                f"gracefully exiting... remaining number of requests {remain_num_req}"
+                f"Gracefully exiting... remaining number of requests {remain_num_req}"
             )
             if remain_num_req > 0:
                 await asyncio.sleep(5)
@@ -695,7 +680,6 @@ class TokenizerManager:
                     "token_ids": recv_obj.output_ids[i],
                     "meta_info": recv_obj.meta_info[i],
                 }
-
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
@@ -747,7 +731,7 @@ class TokenizerManager:
         token_texts = self.tokenizer.batch_decode(token_ids)
         return [
             (logprob, token_id, token_text)
-            for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
         ]

     def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
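The corrected comprehension pairs each `(logprob, token_id)` tuple with its decoded text. A tiny sketch of the same zip pattern with made-up values, replacing `batch_decode` with a precomputed list so it runs without a tokenizer:

```python
from typing import List, Tuple

token_logprobs: List[Tuple[float, int]] = [(-0.12, 101), (-1.30, 2057), (-0.05, 102)]
token_texts: List[str] = ["<s>", "hello", "</s>"]  # what tokenizer.batch_decode would return

triples = [
    (logprob, token_id, token_text)
    for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
]
print(triples)  # [(-0.12, 101, '<s>'), (-1.3, 2057, 'hello'), (-0.05, 102, '</s>')]
```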