zhaoyu6 / sglang, commit 838dcda1 (unverified)

Simplify tokenizer manager (#1899)

Authored by Lianmin Zheng on Nov 03, 2024; committed by GitHub on Nov 03, 2024.
Parent: efbc116a

Showing 3 changed files with 50 additions and 49 deletions (+50 -49):
docs/references/custom_chat_template.md          +11 -2
python/sglang/srt/managers/io_struct.py          +12 -4
python/sglang/srt/managers/tokenizer_manager.py  +27 -43
docs/references/custom_chat_template.md
@@ -11,8 +11,10 @@ If needed, you can also override the chat template when launching the server:
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2
 ```
-If the chat template you are looking for is missing, you are welcome to contribute it.
-Meanwhile, you can also temporarily register your chat template as follows:
+If the chat template you are looking for is missing, you are welcome to contribute it or load it from a file.
+
+## JSON Format
+You can load the JSON format, which is defined by `conversation.py`.
 
 ```json
 {
@@ -29,3 +31,10 @@ Meanwhile, you can also temporarily register your chat template as follows:
 ```
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json
 ```
+
+## Jinja Format
+You can also use the Jinja template format, defined by Hugging Face transformers https://huggingface.co/docs/transformers/main/en/chat_templating
+
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.jinja
+```
\ No newline at end of file
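As a side note for readers (not part of this commit): before pointing `--chat-template` at a Jinja file, the template can be rendered directly with Hugging Face `transformers` to catch syntax errors early. The sketch below is illustrative only and assumes the placeholder model name and `./my_model_template.jinja` path from the docs above.

```python
# Hypothetical pre-flight check for a custom Jinja chat template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

with open("./my_model_template.jinja") as f:
    custom_template = f.read()

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]

# apply_chat_template accepts an explicit chat_template string, so the file can
# be rendered without touching the tokenizer's built-in template.
prompt = tokenizer.apply_chat_template(
    messages,
    chat_template=custom_template,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
```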
python/sglang/srt/managers/io_struct.py
@@ -114,8 +114,7 @@ class GenerateReqInput:
         if self.parallel_sample_num == 1:
             num = self.batch_size
         else:
-            # FIXME support cascade inference
-            # first bs samples are used for caching the prefix for parallel sampling
+            # The first bs samples are used for caching the prefix for parallel sampling
             num = self.batch_size + self.parallel_sample_num * self.batch_size
 
         if self.image_data is None:
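To make the arithmetic in the hunk above concrete, here is a small illustrative sketch (not code from this commit) that mirrors the `num = self.batch_size + self.parallel_sample_num * self.batch_size` branch with plain local names.

```python
# Illustrative only: local stand-ins for the dataclass fields in the hunk above.
def total_requests(batch_size: int, parallel_sample_num: int) -> int:
    if parallel_sample_num == 1:
        return batch_size
    # The first `batch_size` requests only cache the shared prefix; the actual
    # samples follow, `parallel_sample_num` per original prompt.
    return batch_size + parallel_sample_num * batch_size

assert total_requests(batch_size=2, parallel_sample_num=1) == 2
assert total_requests(batch_size=2, parallel_sample_num=3) == 2 + 3 * 2  # 8 requests in total
```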
@@ -196,6 +195,9 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
+    # Whether it is a single request or a batch request
+    is_single: bool = True
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
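The condition in `post_init` is cut off above, but it is the usual "exactly one of `text` or `input_ids`" check. A hypothetical standalone version of that validation pattern might look like this (the function name and error message are illustrative, not taken from io_struct.py).

```python
from typing import List, Optional

# Hypothetical standalone version of the check the truncated condition performs:
# exactly one of `text` and `input_ids` must be provided.
def validate_inputs(text: Optional[str], input_ids: Optional[List[int]]) -> None:
    if (text is None and input_ids is None) or (
        text is not None and input_ids is not None
    ):
        raise ValueError("Either text or input_ids should be provided, but not both.")

validate_inputs("hello", None)    # ok
validate_inputs(None, [1, 2, 3])  # ok
# validate_inputs(None, None)     # would raise ValueError
```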
@@ -241,15 +243,21 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams
 
 
+RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]
+
+
 @dataclass
 class RewardReqInput:
-    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
-    conv: Union[List[List[Dict]], List[Dict]]
+    # The input prompt. It can be a single prompt or a batch of prompts.
+    # Can be either chat format or a string.
+    conv: RewardReqConv
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
     # Whether it is a single request or a batch request
     is_single: bool = True
 
     def post_init(self):
         self.is_single = isinstance(self.conv[0], dict)
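Since `RewardReqConv` now accepts several shapes, the short sketch below (using only the types shown in the hunk above) illustrates which inputs the new `post_init` logic treats as a single request versus a batch.

```python
from typing import Dict, List, Union

# Same union as in the hunk above.
RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]

# A single chat-format conversation is a list of message dicts.
single_chat: RewardReqConv = [
    {"role": "user", "content": "Is this answer helpful?"},
    {"role": "assistant", "content": "Yes."},
]
# A batch of chat-format conversations is a list of such lists.
batch_chat: RewardReqConv = [single_chat, single_chat]

# post_init marks the request as single when the first element is a dict.
assert isinstance(single_chat[0], dict)     # -> is_single = True
assert not isinstance(batch_chat[0], dict)  # -> is_single = False (a batch)
```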
python/sglang/srt/managers/tokenizer_manager.py
@@ -51,6 +51,7 @@ from sglang.srt.managers.io_struct import (
     GetMemPoolSizeReq,
     GetMemPoolSizeReqOutput,
     ProfileReq,
+    RewardReqConv,
     RewardReqInput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -89,6 +90,7 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        # Parse args
         self.server_args = server_args
 
         # Init inter-process communication
@@ -114,6 +116,7 @@ class TokenizerManager:
             self.context_len = server_args.context_length or get_context_length(
                 self.hf_config
             )
+            # Create image processor placeholder
             self.image_processor = get_dummy_image_processor()
@@ -165,7 +168,8 @@ class TokenizerManager:
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
-                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
+                "This model does not appear to be an embedding model by default. "
+                "Please add `--is-embedding` when launching the server or try another model."
             )
 
         obj.post_init()
@@ -187,12 +191,8 @@ class TokenizerManager:
         if not is_cache_for_prefill:  # The normal case with a single prompt
             if index is None:
                 rid = obj.rid
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv
-                    input_text = self.tokenizer.apply_chat_template(
-                        conv, tokenize=False
-                    )
+                if isinstance(obj, RewardReqInput):
+                    input_text = self._apply_chat_template(obj.conv)
                     input_ids = self.tokenizer.encode(input_text)
                 elif obj.input_ids is None:
                     input_text = obj.text
@@ -213,12 +213,8 @@ class TokenizerManager:
                 top_logprobs_num = obj.top_logprobs_num
             else:
                 rid = obj.rid[index]
-                if hasattr(obj, "conv"):
-                    # reward model
-                    conv = obj.conv[index]
-                    input_text = self.tokenizer.apply_chat_template(
-                        conv, tokenize=False
-                    )
+                if isinstance(obj, RewardReqInput):
+                    input_text = self._apply_chat_template(obj.conv[input_id_index])
                     input_ids = self.tokenizer.encode(input_text)
                 elif obj.input_ids is None:
                     input_text = obj.text[input_id_index]
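For context, the per-index reward branch above boils down to "render the chat template, then encode". The sketch below is illustrative only: it runs those two steps over a tiny batch with a Hugging Face tokenizer standing in for `self.tokenizer`, and the model name is just a placeholder.

```python
# Illustrative only: mirrors the two steps in the per-index branch above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

convs = [
    [{"role": "user", "content": "Rate this answer."}],
    [{"role": "user", "content": "Rate that answer."}],
]

batch_input_ids = []
for input_id_index in range(len(convs)):
    # Render the conversation to text, then encode it to token ids.
    input_text = tokenizer.apply_chat_template(convs[input_id_index], tokenize=False)
    batch_input_ids.append(tokenizer.encode(input_text))

print([len(ids) for ids in batch_input_ids])
```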
@@ -349,8 +345,9 @@ class TokenizerManager:
             async for response in self._wait_for_response(state, obj, rid, request):
                 yield response
         else:
             assert self.is_generation
-            await self._wait_for_cache_prefill_response(state, obj, rid, request)
+            await state.event.wait()
+            assert state.finished
+            del self.rid_to_state[rid]
             yield input_ids
 
     async def _handle_batch_request(
@@ -456,6 +453,15 @@ class TokenizerManager:
         sampling_params.verify()
         return sampling_params
 
+    def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]:
+        if isinstance(conv, str):
+            return conv
+        elif isinstance(conv, list):
+            if isinstance(conv[0], str):
+                return conv
+            else:
+                return self.tokenizer.apply_chat_template(conv, tokenize=False)
+
     async def _wait_for_response(
         self,
         state: ReqState,
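The new helper accepts all of the `RewardReqConv` shapes. A quick usage sketch (standalone, with a Hugging Face tokenizer standing in for `self.tokenizer`; the model name is a placeholder) shows what each shape returns.

```python
from typing import Dict, List, Union

from transformers import AutoTokenizer

# Any chat-capable tokenizer works here; the Llama-2 name is just the placeholder
# used elsewhere in this commit's docs.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# Standalone copy of the dispatch added above, with the tokenizer passed in
# explicitly instead of living on the TokenizerManager.
def apply_chat_template(
    conv: Union[List[List[Dict]], List[Dict], str, List[str]]
) -> Union[str, List[str]]:
    if isinstance(conv, str):
        return conv  # already a rendered prompt
    elif isinstance(conv, list):
        if isinstance(conv[0], str):
            return conv  # a batch of already-rendered prompts
        else:
            # chat format: one conversation (list of dicts) or a batch of them
            return tokenizer.apply_chat_template(conv, tokenize=False)

print(apply_chat_template("already a prompt"))
print(apply_chat_template(["prompt a", "prompt b"]))
print(apply_chat_template([{"role": "user", "content": "Hello"}]))
```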
@@ -491,12 +497,11 @@ class TokenizerManager:
                 out["index"] = response_index
 
-            # Log requests
-            if self.server_args.log_requests and state.finished:
-                logger.info(f"in={obj}, out={out}")
-
             state.out_list = []
             if state.finished:
+                # Log requests
+                if self.server_args.log_requests:
+                    logger.info(f"in={obj}, out={out}")
+
                 del self.rid_to_state[rid]
                 yield out
                 break
@@ -504,27 +509,6 @@ class TokenizerManager:
             state.event.clear()
             yield out
 
-    async def _wait_for_cache_prefill_response(
-        self,
-        state: ReqState,
-        obj: GenerateReqInput,
-        rid: str,
-        request: Optional[fastapi.Request] = None,
-    ):
-        while True:
-            try:
-                await asyncio.wait_for(state.event.wait(), timeout=4)
-                break
-            except asyncio.TimeoutError:
-                if request is not None and await request.is_disconnected():
-                    for rid in obj.rid:
-                        self.abort_request(rid)
-                    raise ValueError(f"Abort request {rid}")
-                continue
-
-        assert state.finished
-        del self.rid_to_state[rid]
-
     def flush_cache(self):
         req = FlushCacheReq()
         self.send_to_scheduler.send_pyobj(req)
@@ -553,6 +537,7 @@ class TokenizerManager:
         self.send_to_scheduler.send_pyobj(req)
         self.mem_pool_size = asyncio.Future()
 
+        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
         if self.server_args.dp_size == 1:
             res = await self.mem_pool_size
             return res.size
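The new FIXME notes that the single shared `self.mem_pool_size` future can race when several requests are in flight. Purely as a sketch of the pattern the comment hints at, and not code from this repository, one conventional fix keys a future per request id; the names below (`rid_to_future`, `request_mem_pool_size`, `handle_result`) are hypothetical.

```python
import asyncio
import uuid
from typing import Callable, Dict

# Hypothetical per-request-future registry, not an sglang API.
rid_to_future: Dict[str, asyncio.Future] = {}

async def request_mem_pool_size(send: Callable[[str], None]) -> int:
    rid = uuid.uuid4().hex
    fut: asyncio.Future = asyncio.get_running_loop().create_future()
    rid_to_future[rid] = fut
    send(rid)  # ask the scheduler, tagging the request id
    try:
        return await fut  # each request awaits its own future
    finally:
        del rid_to_future[rid]

def handle_result(rid: str, size: int) -> None:
    # Called when the scheduler replies; resolves only the matching future.
    rid_to_future[rid].set_result(size)
```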
@@ -638,7 +623,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)
             logger.info(
-                f"gracefully exiting... remaining number of requests {remain_num_req}"
+                f"Gracefully exiting... remaining number of requests {remain_num_req}"
             )
             if remain_num_req > 0:
                 await asyncio.sleep(5)
@@ -695,7 +680,6 @@ class TokenizerManager:
                     "token_ids": recv_obj.output_ids[i],
                     "meta_info": recv_obj.meta_info[i],
                 }
-
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
@@ -747,7 +731,7 @@ class TokenizerManager:
             token_texts = self.tokenizer.batch_decode(token_ids)
             return [
                 (logprob, token_id, token_text)
-                for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+                for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
             ]
 
     def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
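The change above only drops a stray comma from the comprehension's target list. A tiny standalone example with made-up logprobs and decoded texts shows the `(logprob, token_id, token_text)` triples it produces.

```python
# Standalone illustration of the comprehension above, with made-up values.
token_logprobs = [(-0.10, 42), (-1.25, 7), (-0.03, 99)]  # (logprob, token_id) pairs
token_texts = ["Hello", " world", "!"]                   # as returned by batch_decode

triples = [
    (logprob, token_id, token_text)
    for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
]
assert triples[0] == (-0.10, 42, "Hello")
```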