Unverified Commit 838dcda1 authored by Lianmin Zheng, committed by GitHub

Simplify tokenizer manager (#1899)

parent efbc116a
@@ -11,8 +11,10 @@ If needed, you can also override the chat template when launching the server:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2
```
-If the chat template you are looking for is missing, you are welcome to contribute it.
+If the chat template you are looking for is missing, you are welcome to contribute it or load it from a file.
Meanwhile, you can also temporarily register your chat template as follows:
## JSON Format
You can load the JSON format, which is defined by `conversation.py`.
```json
{
@@ -29,3 +31,10 @@ Meanwhile, you can also temporarily register your chat template as follows:
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json
```
## Jinja Format
You can also use the Jinja template format, defined by Hugging Face transformers https://huggingface.co/docs/transformers/main/en/chat_templating
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.jinja
```
\ No newline at end of file
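For orientation (not part of the diff), here is a minimal sketch of how a Jinja chat template file like the one referenced above is typically applied with Hugging Face transformers; the template path and the example messages are illustrative placeholders:

```python
# Minimal sketch, not part of this commit: apply a custom Jinja chat
# template with Hugging Face transformers. The template path and the
# example messages are illustrative placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

# Override the tokenizer's built-in template with the contents of a .jinja file.
with open("./my_model_template.jinja") as f:
    tokenizer.chat_template = f.read()

messages = [{"role": "user", "content": "What is the capital of France?"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
```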
@@ -114,8 +114,7 @@ class GenerateReqInput:
if self.parallel_sample_num == 1:
num = self.batch_size
else:
-# FIXME support cascade inference
-# first bs samples are used for caching the prefix for parallel sampling
+# The first bs samples are used for caching the prefix for parallel sampling
num = self.batch_size + self.parallel_sample_num * self.batch_size
if self.image_data is None:
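A quick worked example of the request count computed above (numbers are illustrative only):

```python
# Illustrative numbers: with parallel sampling enabled, the first
# `batch_size` requests cache the shared prefix and the remaining
# `parallel_sample_num * batch_size` requests are the actual samples.
batch_size = 2
parallel_sample_num = 3
num = batch_size + parallel_sample_num * batch_size
assert num == 8  # 2 prefix-caching requests + 3 samples for each of the 2 prompts
```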
@@ -196,6 +195,9 @@ class EmbeddingReqInput:
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None
# Whether it is a single request or a batch request
is_single: bool = True
def post_init(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
@@ -241,15 +243,21 @@ class TokenizedEmbeddingReqInput:
sampling_params: SamplingParams
RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]
@dataclass
class RewardReqInput:
-# The input prompt in the chat format. It can be a single prompt or a batch of prompts.
-conv: Union[List[List[Dict]], List[Dict]]
+# The input prompt. It can be a single prompt or a batch of prompts. Can be either chat format or a string.
+conv: RewardReqConv
# The request id.
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None
# Whether it is a single request or a batch request
is_single: bool = True
def post_init(self):
self.is_single = isinstance(self.conv[0], dict)
...
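For reference, a small self-contained sketch of the input shapes the new `RewardReqConv` alias admits (the example conversations are made up):

```python
from typing import Dict, List, Union

# Mirrors the alias added above: a reward request's `conv` may be a single
# chat, a batch of chats, a single raw string, or a batch of raw strings.
RewardReqConv = Union[List[List[Dict]], List[Dict], str, List[str]]

single_chat: RewardReqConv = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]
batch_of_chats: RewardReqConv = [single_chat, single_chat]
single_string: RewardReqConv = "What is the capital of France? Paris."
batch_of_strings: RewardReqConv = ["prompt one", "prompt two"]
```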
@@ -51,6 +51,7 @@ from sglang.srt.managers.io_struct import (
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
ProfileReq,
RewardReqConv,
RewardReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
@@ -89,6 +90,7 @@ class TokenizerManager:
server_args: ServerArgs,
port_args: PortArgs,
):
# Parse args
self.server_args = server_args
# Init inter-process communication
@@ -114,6 +116,7 @@ class TokenizerManager:
self.context_len = server_args.context_length or get_context_length(
self.hf_config
)
# Create image processor placeholder
self.image_processor = get_dummy_image_processor()
@@ -165,7 +168,8 @@ class TokenizerManager:
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
-"This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
+"This model does not appear to be an embedding model by default. "
+"Please add `--is-embedding` when launching the server or try another model."
)
obj.post_init()
@@ -187,12 +191,8 @@ class TokenizerManager:
if not is_cache_for_prefill: # The normal case with a single prompt
if index is None:
rid = obj.rid
-if hasattr(obj, "conv"):
-# reward model
-conv = obj.conv
-input_text = self.tokenizer.apply_chat_template(
-conv, tokenize=False
-)
+if isinstance(obj, RewardReqInput):
+input_text = self._apply_chat_template(obj.conv)
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
input_text = obj.text
@@ -213,12 +213,8 @@ class TokenizerManager:
top_logprobs_num = obj.top_logprobs_num
else:
rid = obj.rid[index]
-if hasattr(obj, "conv"):
-# reward model
-conv = obj.conv[index]
-input_text = self.tokenizer.apply_chat_template(
-conv, tokenize=False
-)
+if isinstance(obj, RewardReqInput):
+input_text = self._apply_chat_template(obj.conv[input_id_index])
input_ids = self.tokenizer.encode(input_text)
elif obj.input_ids is None:
input_text = obj.text[input_id_index]
@@ -349,8 +345,9 @@ class TokenizerManager:
async for response in self._wait_for_response(state, obj, rid, request):
yield response
else:
-assert self.is_generation
-await self._wait_for_cache_prefill_response(state, obj, rid, request)
+await state.event.wait()
+assert state.finished
+del self.rid_to_state[rid]
yield input_ids
async def _handle_batch_request(
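The `else` branch above replaces the timeout-polling helper (deleted further down) with a plain wait on the request's event; a generic sketch of that pattern, using stand-in names rather than this codebase's classes:

```python
import asyncio
from dataclasses import dataclass, field

@dataclass
class ReqStateSketch:
    # Stand-in for the per-request state: the response handler marks the
    # request finished and then sets the event.
    finished: bool = False
    event: asyncio.Event = field(default_factory=asyncio.Event)

async def finish_later(state: ReqStateSketch):
    await asyncio.sleep(0.01)
    state.finished = True
    state.event.set()

async def main():
    state = ReqStateSketch()
    asyncio.create_task(finish_later(state))
    await state.event.wait()  # block until the handler signals completion
    assert state.finished

asyncio.run(main())
```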
@@ -456,6 +453,15 @@ class TokenizerManager:
sampling_params.verify()
return sampling_params
def _apply_chat_template(self, conv: RewardReqConv) -> Union[str, List[str]]:
if isinstance(conv, str):
return conv
elif isinstance(conv, list):
if isinstance(conv[0], str):
return conv
else:
return self.tokenizer.apply_chat_template(conv, tokenize=False)
async def _wait_for_response(
self,
state: ReqState,
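A self-contained restatement of the `_apply_chat_template` dispatch added in this hunk (the `tokenizer` argument stands in for the manager's tokenizer):

```python
from typing import List, Union

def apply_chat_template_sketch(tokenizer, conv) -> Union[str, List[str]]:
    # Plain strings (single or batched) pass through untouched; chat-format
    # inputs are rendered to text with the tokenizer's chat template.
    if isinstance(conv, str):
        return conv
    if isinstance(conv, list) and isinstance(conv[0], str):
        return conv
    return tokenizer.apply_chat_template(conv, tokenize=False)

# The pass-through cases need no tokenizer at all:
assert apply_chat_template_sketch(None, "raw prompt") == "raw prompt"
assert apply_chat_template_sketch(None, ["p1", "p2"]) == ["p1", "p2"]
```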
@@ -491,12 +497,11 @@ class TokenizerManager:
out["index"] = response_index
# Log requests
if self.server_args.log_requests and state.finished:
logger.info(f"in={obj}, out={out}")
state.out_list = []
if state.finished:
# Log requests
if self.server_args.log_requests:
logger.info(f"in={obj}, out={out}")
del self.rid_to_state[rid]
yield out
break
@@ -504,27 +509,6 @@ class TokenizerManager:
state.event.clear()
yield out
async def _wait_for_cache_prefill_response(
self,
state: ReqState,
obj: GenerateReqInput,
rid: str,
request: Optional[fastapi.Request] = None,
):
while True:
try:
await asyncio.wait_for(state.event.wait(), timeout=4)
break
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
for rid in obj.rid:
self.abort_request(rid)
raise ValueError(f"Abort request {rid}")
continue
assert state.finished
del self.rid_to_state[rid]
def flush_cache(self):
req = FlushCacheReq()
self.send_to_scheduler.send_pyobj(req)
@@ -553,6 +537,7 @@ class TokenizerManager:
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
# FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
if self.server_args.dp_size == 1:
res = await self.mem_pool_size
return res.size
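The FIXME added above flags that a single shared `self.mem_pool_size` future can be clobbered by concurrent callers; a toy sketch (generic names, not this codebase's API) of the per-request-future pattern it suggests:

```python
import asyncio
import uuid

class FutureRegistry:
    """Toy illustration: each request gets its own Future keyed by a
    request id, instead of all callers sharing one attribute."""

    def __init__(self):
        self._pending: dict[str, asyncio.Future] = {}

    def create(self):
        rid = uuid.uuid4().hex
        fut = asyncio.get_running_loop().create_future()
        self._pending[rid] = fut
        return rid, fut

    def resolve(self, rid: str, result) -> None:
        self._pending.pop(rid).set_result(result)

async def main():
    reg = FutureRegistry()
    rid, fut = reg.create()
    # Simulate the scheduler replying later for this specific request id.
    asyncio.get_running_loop().call_later(0.01, reg.resolve, rid, 42)
    print(await fut)  # -> 42

asyncio.run(main())
```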
@@ -638,7 +623,7 @@ class TokenizerManager:
while True:
remain_num_req = len(self.rid_to_state)
logger.info(
-f"gracefully exiting... remaining number of requests {remain_num_req}"
+f"Gracefully exiting... remaining number of requests {remain_num_req}"
)
if remain_num_req > 0:
await asyncio.sleep(5)
@@ -695,7 +680,6 @@ class TokenizerManager:
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
@@ -747,7 +731,7 @@ class TokenizerManager:
token_texts = self.tokenizer.batch_decode(token_ids)
return [
(logprob, token_id, token_text)
-for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
]
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
...
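The final hunk only drops a redundant trailing comma from the unpacking target in the comprehension; a tiny illustration of the pairing it performs (values are made up):

```python
# Made-up values for the (logprob, token_id) / token_text pairing.
token_logprobs = [(-0.1, 11), (-2.3, 42)]
token_texts = ["Hello", " world"]

triples = [
    (logprob, token_id, token_text)
    for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
]
assert triples == [(-0.1, 11, "Hello"), (-2.3, 42, " world")]
```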