Unverified Commit 838dcda1 authored by Lianmin Zheng, committed by GitHub

Simplify tokenizer manager (#1899)

parent efbc116a
@@ -11,8 +11,10 @@ If needed, you can also override the chat template when launching the server:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2
```
If the chat template you are looking for is missing, you are welcome to contribute it.
Meanwhile, you can also temporarily register your chat template as follows:
If the chat template you are looking for is missing, you are welcome to contribute it or load it from a file.
## JSON Format
You can load the JSON format, which is defined by `conversation.py`.
```json
{
@@ -28,4 +30,11 @@ Meanwhile, you can also temporarily register your chat template as follows:
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json
```
## Jinja Format
You can also use the Jinja template format defined by Hugging Face Transformers: https://huggingface.co/docs/transformers/main/en/chat_templating
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.jinja
```
\ No newline at end of file
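As an aside on the new Jinja section above: the snippet below is a minimal, illustrative sketch (not part of this commit) showing how a ChatML-style Jinja template can be previewed with Hugging Face `transformers` before being passed to the server via `--chat-template`. The special tokens and the file name are assumptions for illustration only.
```python
# Illustrative sketch (not part of this commit): preview a ChatML-style Jinja
# chat template with Hugging Face transformers before passing the file to
# `--chat-template`. Special tokens and the file name are examples only.
from transformers import AutoTokenizer

jinja_template = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)

# Any tokenizer works here; this model matches the commands in the doc above.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.chat_template = jinja_template  # override the built-in template

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))

# Write the template to disk so it can be loaded with:
#   python -m sglang.launch_server ... --chat-template ./my_model_template.jinja
with open("my_model_template.jinja", "w") as f:
    f.write(jinja_template)
```
Previewing the rendered prompt this way makes it easier to catch template mistakes before they surface as odd generations from the server.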
@@ -114,8 +114,7 @@ class GenerateReqInput:
if self.parallel_sample_num == 1:
num = self.batch_size
else:
# FIXME support cascade inference
# first bs samples are used for caching the prefix for parallel sampling
# The first bs samples are used for caching the prefix for parallel sampling
num = self.batch_size + self.parallel_sample_num * self.batch_size
if self.image_data is None:
@@ -196,6 +195,9 @@ class EmbeddingReqInput:
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None
# Whether it is a single request or a batch request
is_single: bool = True
def post_init(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
@@ -241,15 +243,21 @@ class TokenizedEmbeddingReqInput:
sampling_params: SamplingParams
RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]
@dataclass
class RewardReqInput:
# The input prompt in the chat format. It can be a single prompt or a batch of prompts.
conv: Union[List[List[Dict]], List[Dict]]
# The input prompt. It can be a single prompt or a batch of prompts. Can be either chat format or a string.
conv: RewardReqConv
# The request id.
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None
# Whether it is a single request or a batch request
is_single: bool = True
def post_init(self):
self.is_single = isinstance(self.conv[0], dict)
......
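For clarity on the widened `RewardReqConv` alias introduced above, here is an illustrative sketch (not from this commit) of the four input shapes it admits; the payloads are made up for the example.
```python
# Illustrative examples (not from this commit) of the input shapes admitted by
# RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]].

# A single conversation in chat format (List[Dict]).
single_chat = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]

# A batch of conversations in chat format (List[List[Dict]]).
batch_chat = [single_chat, single_chat]

# A single prompt that has already been rendered to a plain string (str).
single_str = "User: What is the capital of France?\nAssistant: Paris."

# A batch of plain-string prompts (List[str]).
batch_str = [single_str, single_str]
```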
@@ -51,6 +51,7 @@ from sglang.srt.managers.io_struct import (
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
ProfileReq,
RewardReqConv,
RewardReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
@@ -89,6 +90,7 @@ class TokenizerManager:
server_args: ServerArgs,
port_args: PortArgs,
):
# Parse args
self.server_args = server_args
# Init inter-process communication
@@ -114,6 +116,7 @@ class TokenizerManager:
self.context_len = server_args.context_length or get_context_length(
self.hf_config
)
# Create image processor placeholder
self.image_processor = get_dummy_image_processor()
@@ -165,7 +168,8 @@ class TokenizerManager:
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
"This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
"This model does not appear to be an embedding model by default. "
"Please add `--is-embedding` when launching the server or try another model."
)
obj.post_init()
@@ -187,12 +191,8 @@ class TokenizerManager:
if not is_cache_for_prefill: # The normal case with a single prompt
if index is None:
rid = obj.rid
if hasattr(obj, "conv"):
# reward model
conv = obj.conv
input_text = self.tokenizer.apply_chat_template(
conv, tokenize=False
)
if isinstance(obj, RewardReqInput):
input_text = self._apply_chat_template(obj.conv)
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
input_text = obj.text
@@ -213,12 +213,8 @@ class TokenizerManager:
top_logprobs_num = obj.top_logprobs_num
else:
rid = obj.rid[index]
if hasattr(obj, "conv"):
# reward model
conv = obj.conv[index]
input_text = self.tokenizer.apply_chat_template(
conv, tokenize=False
)
if isinstance(obj, RewardReqInput):
input_text = self._apply_chat_template(obj.conv[input_id_index])
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
input_text = obj.text[input_id_index]
@@ -349,8 +345,9 @@ class TokenizerManager:
async for response in self._wait_for_response(state, obj, rid, request):
yield response
else:
assert self.is_generation
await self._wait_for_cache_prefill_response(state, obj, rid, request)
await state.event.wait()
assert state.finished
del self.rid_to_state[rid]
yield input_ids
async def _handle_batch_request(
@@ -456,6 +453,15 @@ class TokenizerManager:
sampling_params.verify()
return sampling_params
def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]:
if isinstance(conv, str):
return conv
elif isinstance(conv, list):
if isinstance(conv[0], str):
return conv
else:
return self.tokenizer.apply_chat_template(conv, tokenize=False)
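As a rough, standalone sketch of the dispatch performed by the new `_apply_chat_template` helper (the wrapper function, model name, and `__main__` usage below are illustrative assumptions, not part of this commit): strings and lists of strings pass through unchanged, while chat-format conversations are rendered with the tokenizer's template.
```python
# Standalone sketch (not from this commit) mirroring the dispatch done by
# _apply_chat_template for the RewardReqConv input shapes.
from typing import Dict, List, Union

from transformers import AutoTokenizer

RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]


def apply_chat_template(tokenizer, conv: RewardReqConv) -> Union[str, List[str]]:
    if isinstance(conv, str):
        return conv  # already a rendered prompt
    if conv and isinstance(conv[0], str):
        return conv  # a batch of rendered prompts
    # A single conversation (or batch of conversations) in chat format.
    return tokenizer.apply_chat_template(conv, tokenize=False)


if __name__ == "__main__":
    # The model name is only an example; any tokenizer with a chat template works.
    tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    print(apply_chat_template(tok, [{"role": "user", "content": "Hello!"}]))
    print(apply_chat_template(tok, "Already a plain prompt"))
```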
async def _wait_for_response(
self,
state: ReqState,
@@ -491,12 +497,11 @@ class TokenizerManager:
out["index"] = response_index
# Log requests
if self.server_args.log_requests and state.finished:
logger.info(f"in={obj}, out={out}")
state.out_list = []
if state.finished:
# Log requests
if self.server_args.log_requests:
logger.info(f"in={obj}, out={out}")
del self.rid_to_state[rid]
yield out
break
@@ -504,27 +509,6 @@ class TokenizerManager:
state.event.clear()
yield out
async def _wait_for_cache_prefill_response(
self,
state: ReqState,
obj: GenerateReqInput,
rid: str,
request: Optional[fastapi.Request] = None,
):
while True:
try:
await asyncio.wait_for(state.event.wait(), timeout=4)
break
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
for rid in obj.rid:
self.abort_request(rid)
raise ValueError(f"Abort request {rid}")
continue
assert state.finished
del self.rid_to_state[rid]
def flush_cache(self):
req = FlushCacheReq()
self.send_to_scheduler.send_pyobj(req)
@@ -553,6 +537,7 @@ class TokenizerManager:
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
# FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
if self.server_args.dp_size == 1:
res = await self.mem_pool_size
return res.size
@@ -638,7 +623,7 @@ class TokenizerManager:
while True:
remain_num_req = len(self.rid_to_state)
logger.info(
f"gracefully exiting... remaining number of requests {remain_num_req}"
f"Gracefully exiting... remaining number of requests {remain_num_req}"
)
if remain_num_req > 0:
await asyncio.sleep(5)
@@ -695,7 +680,6 @@ class TokenizerManager:
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
@@ -747,7 +731,7 @@ class TokenizerManager:
token_texts = self.tokenizer.batch_decode(token_ids)
return [
(logprob, token_id, token_text)
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
]
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
......