Unverified commit f70f7258, authored by Qubitium, committed by GitHub
Browse files

Fix rid state map leak + Refactor .finished (#505)


Co-authored-by: ZX <zx@lbx.dev>
parent c0ae70c8
...@@ -10,6 +10,7 @@ import zmq ...@@ -10,6 +10,7 @@ import zmq
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
...@@ -44,6 +45,8 @@ class DataParallelWorkerThread(threading.Thread): ...@@ -44,6 +45,8 @@ class DataParallelWorkerThread(threading.Thread):
requests = [] requests = []
while not self.request_queue.empty(): while not self.request_queue.empty():
requests.append(self.request_queue.get()) requests.append(self.request_queue.get())
out_pyobjs: List[BatchTokenIDOut] = []
try: try:
out_pyobjs = await self.step(requests) out_pyobjs = await self.step(requests)
except Exception: except Exception:
...@@ -61,7 +64,7 @@ class DataParallelWorkerThread(threading.Thread): ...@@ -61,7 +64,7 @@ class DataParallelWorkerThread(threading.Thread):
# async sleep for receiving the subsequent request and avoiding cache miss # async sleep for receiving the subsequent request and avoiding cache miss
if len(out_pyobjs) != 0: if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs]) has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
if has_finished: if has_finished:
await asyncio.sleep(self.request_dependency_delay) await asyncio.sleep(self.request_dependency_delay)
await asyncio.sleep(global_config.wait_for_new_request_delay) await asyncio.sleep(global_config.wait_for_new_request_delay)
......
...@@ -15,25 +15,47 @@ class ForwardMode(IntEnum): ...@@ -15,25 +15,47 @@ class ForwardMode(IntEnum):
EXTEND = auto() EXTEND = auto()
DECODE = auto() DECODE = auto()
class BaseFinishReason:
def __init__(self, is_error: bool = False):
self.is_error = is_error
class FinishReason(IntEnum): def __str__(self):
EOS_TOKEN = auto() raise NotImplementedError("Subclasses must implement this method")
LENGTH = auto()
STOP_STR = auto()
ABORT = auto() class FINISH_MATCHED_TOKEN(BaseFinishReason):
def __init__(self, matched: int | List[int]):
@staticmethod super().__init__()
def to_str(reason): self.matched = matched
if reason == FinishReason.EOS_TOKEN:
return None def __str__(self) -> str:
elif reason == FinishReason.LENGTH: return f"FINISH_MATCHED_TOKEN: {self.matched}"
return "length"
elif reason == FinishReason.STOP_STR:
return "stop" class FINISH_LENGTH(BaseFinishReason):
elif reason == FinishReason.ABORT: def __init__(self, length: int):
return "abort" super().__init__()
else: self.length = length
return None
def __str__(self) -> str:
return f"FINISH_LENGTH: {self.length}"
class FINISH_MATCHED_STR(BaseFinishReason):
    """Finish reason recording the stop string that was matched.

    The matched string is kept on ``self.matched``; the base class is
    initialised with its default arguments.
    """

    def __init__(self, matched: str):
        # Store the payload first; the base initialiser does not touch it.
        self.matched = matched
        super().__init__()

    def __str__(self) -> str:
        """Return the reason tag followed by the matched stop string."""
        text = f"FINISH_MATCHED_STR: {self.matched}"
        return text
class FINISH_ABORT(BaseFinishReason):
    """Finish reason for an aborted request.

    The base class is initialised with ``is_error=True``.
    """

    def __init__(self):
        super().__init__(is_error=True)

    def __str__(self) -> str:
        """Return the fixed tag for this reason."""
        tag = "FINISH_ABORT"
        return tag
class Req: class Req:
...@@ -61,11 +83,10 @@ class Req: ...@@ -61,11 +83,10 @@ class Req:
self.sampling_params = None self.sampling_params = None
self.stream = False self.stream = False
# Check finish
self.tokenizer = None self.tokenizer = None
self.finished = False
self.finish_reason = None # Check finish
self.hit_stop_str = None self.finished_reason = None
# Prefix info # Prefix info
self.extend_input_len = 0 self.extend_input_len = 0
...@@ -90,6 +111,10 @@ class Req: ...@@ -90,6 +111,10 @@ class Req:
self.regex_fsm_state = 0 self.regex_fsm_state = 0
self.jump_forward_map = None self.jump_forward_map = None
def finished(self) -> bool:
    """Report whether the request has reached a finish condition.

    A request counts as finished once ``self.finished_reason`` holds
    anything other than ``None``.
    """
    reason = self.finished_reason
    return reason is not None
def partial_decode(self, ids): def partial_decode(self, ids):
first_token = self.tokenizer.convert_ids_to_tokens(ids[0]) first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
first_token = ( first_token = (
...@@ -101,23 +126,21 @@ class Req: ...@@ -101,23 +126,21 @@ class Req:
return self.sampling_params.max_new_tokens return self.sampling_params.max_new_tokens
def check_finished(self): def check_finished(self):
if self.finished: if self.finished():
return return
if ( if (
len(self.prev_output_ids) + len(self.output_ids) len(self.prev_output_ids) + len(self.output_ids)
>= self.sampling_params.max_new_tokens >= self.sampling_params.max_new_tokens
): ):
self.finished = True self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids))
self.finish_reason = FinishReason.LENGTH
return return
if ( if (
self.output_ids[-1] == self.tokenizer.eos_token_id self.output_ids[-1] == self.tokenizer.eos_token_id
and self.sampling_params.ignore_eos == False and not self.sampling_params.ignore_eos
): ):
self.finished = True self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id)
self.finish_reason = FinishReason.EOS_TOKEN
return return
if len(self.sampling_params.stop_strs) > 0: if len(self.sampling_params.stop_strs) > 0:
...@@ -128,9 +151,7 @@ class Req: ...@@ -128,9 +151,7 @@ class Req:
for stop_str in self.sampling_params.stop_strs: for stop_str in self.sampling_params.stop_strs:
# FIXME: (minor) try incremental match in prev_output_str # FIXME: (minor) try incremental match in prev_output_str
if stop_str in tail_str or stop_str in self.prev_output_str: if stop_str in tail_str or stop_str in self.prev_output_str:
self.finished = True self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
self.finish_reason = FinishReason.STOP_STR
self.hit_stop_str = stop_str
return return
def jump_forward_and_retokenize(self, jump_forward_str, next_state): def jump_forward_and_retokenize(self, jump_forward_str, next_state):
......
...@@ -45,7 +45,7 @@ class ControllerSingle: ...@@ -45,7 +45,7 @@ class ControllerSingle:
# async sleep for receiving the subsequent request and avoiding cache miss # async sleep for receiving the subsequent request and avoiding cache miss
slept = False slept = False
if len(out_pyobjs) != 0: if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs]) has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
if has_finished: if has_finished:
if self.request_dependency_delay > 0: if self.request_dependency_delay > 0:
slept = True slept = True
......
...@@ -19,7 +19,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -19,7 +19,7 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq, FlushCacheReq,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
) )
from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
...@@ -595,20 +595,19 @@ class ModelTpServer: ...@@ -595,20 +595,19 @@ class ModelTpServer:
output_rids = [] output_rids = []
prev_output_strs = [] prev_output_strs = []
output_tokens = [] output_tokens = []
output_hit_stop_str = []
output_skip_special_tokens = [] output_skip_special_tokens = []
output_spaces_between_special_tokens = [] output_spaces_between_special_tokens = []
output_meta_info = [] output_meta_info = []
output_finished = [] output_finished_reason: List[BaseFinishReason] = []
finished_indices = [] finished_indices = []
unfinished_indices = [] unfinished_indices = []
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
if req.finished: if req.finished():
finished_indices.append(i) finished_indices.append(i)
else: else:
unfinished_indices.append(i) unfinished_indices.append(i)
if req.finished or ( if req.finished() or (
( (
req.stream req.stream
and ( and (
...@@ -620,7 +619,6 @@ class ModelTpServer: ...@@ -620,7 +619,6 @@ class ModelTpServer:
output_rids.append(req.rid) output_rids.append(req.rid)
prev_output_strs.append(req.prev_output_str) prev_output_strs.append(req.prev_output_str)
output_tokens.append(req.output_ids) output_tokens.append(req.output_ids)
output_hit_stop_str.append(req.hit_stop_str)
output_skip_special_tokens.append( output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens req.sampling_params.skip_special_tokens
) )
...@@ -632,8 +630,7 @@ class ModelTpServer: ...@@ -632,8 +630,7 @@ class ModelTpServer:
"prompt_tokens": len(req.origin_input_ids), "prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.prev_output_ids) + len(req.output_ids), "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": FinishReason.to_str(req.finish_reason), "finish_reason": str(req.finished_reason),
"hit_stop_str": req.hit_stop_str,
} }
if req.return_logprob: if req.return_logprob:
( (
...@@ -650,7 +647,7 @@ class ModelTpServer: ...@@ -650,7 +647,7 @@ class ModelTpServer:
req.normalized_prompt_logprob, req.normalized_prompt_logprob,
) )
output_meta_info.append(meta_info) output_meta_info.append(meta_info)
output_finished.append(req.finished) output_finished_reason.append(req.finished_reason)
# Send to detokenizer # Send to detokenizer
if output_rids: if output_rids:
...@@ -659,11 +656,10 @@ class ModelTpServer: ...@@ -659,11 +656,10 @@ class ModelTpServer:
output_rids, output_rids,
prev_output_strs, prev_output_strs,
output_tokens, output_tokens,
output_hit_stop_str,
output_skip_special_tokens, output_skip_special_tokens,
output_spaces_between_special_tokens, output_spaces_between_special_tokens,
output_meta_info, output_meta_info,
output_finished, output_finished_reason,
) )
) )
...@@ -720,8 +716,7 @@ class ModelTpServer: ...@@ -720,8 +716,7 @@ class ModelTpServer:
if self.running_batch: if self.running_batch:
for req in self.running_batch.reqs: for req in self.running_batch.reqs:
if req.rid == recv_req.rid: if req.rid == recv_req.rid:
req.finished = True req.finished_reason = FINISH_ABORT()
req.finish_reason = FinishReason.ABORT
break break
......
...@@ -9,6 +9,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer ...@@ -9,6 +9,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback, graceful_registry from sglang.utils import get_exception_traceback, graceful_registry
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
...@@ -34,49 +35,47 @@ class DetokenizerManager: ...@@ -34,49 +35,47 @@ class DetokenizerManager:
async def handle_loop(self): async def handle_loop(self):
while True: while True:
recv_obj = await self.recv_from_router.recv_pyobj() recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
assert isinstance(recv_obj, BatchTokenIDOut)
if isinstance(recv_obj, BatchTokenIDOut):
output_tokens = recv_obj.output_tokens output_tokens = recv_obj.output_tokens
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
output_strs = self.tokenizer.batch_decode( output_strs = self.tokenizer.batch_decode(
output_tokens, output_tokens,
skip_special_tokens=recv_obj.skip_special_tokens[0], skip_special_tokens=recv_obj.skip_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[ spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
0 0
], ],
) )
# Trim stop str # Trim stop str
# TODO(lmzheng): handle the case where multiple stop strs are hit # TODO(lmzheng): handle the case where multiple stop strs are hit
for i in range(len(output_strs)): for i in range(len(output_strs)):
if len(output_tokens[i]) > 0: if len(output_tokens[i]) > 0:
first_token = self.tokenizer.convert_ids_to_tokens( first_token = self.tokenizer.convert_ids_to_tokens(
int(output_tokens[i][0]) int(output_tokens[i][0])
)
if not isinstance(first_token, str):
first_token = first_token.decode("utf-8", errors="ignore")
if first_token.startswith("▁"):
output_strs[i] = " " + output_strs[i]
output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
if recv_obj.hit_stop_str[i] is not None:
pos = output_strs[i].find(recv_obj.hit_stop_str[i])
if pos != -1:
output_strs[i] = output_strs[i][:pos]
self.send_to_tokenizer.send_pyobj(
BatchStrOut(
recv_obj.rids,
output_strs,
recv_obj.meta_info,
recv_obj.finished,
) )
if not isinstance(first_token, str):
first_token = first_token.decode("utf-8", errors="ignore")
if first_token.startswith("▁"):
output_strs[i] = " " + output_strs[i]
output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
if pos != -1:
output_strs[i] = output_strs[i][:pos]
self.send_to_tokenizer.send_pyobj(
BatchStrOut(
rids=recv_obj.rids,
output_str=output_strs,
meta_info=recv_obj.meta_info,
finished_reason=recv_obj.finished_reason,
) )
else: )
raise ValueError(f"Invalid object: {recv_obj}")
def start_detokenizer_process( def start_detokenizer_process(
......
...@@ -3,6 +3,7 @@ from dataclasses import dataclass ...@@ -3,6 +3,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from sglang.srt.sampling_params import SamplingParams from sglang.srt.sampling_params import SamplingParams
from sglang.srt.managers.controller.infer_batch import BaseFinishReason
@dataclass @dataclass
...@@ -105,21 +106,19 @@ class TokenizedGenerateReqInput: ...@@ -105,21 +106,19 @@ class TokenizedGenerateReqInput:
@dataclass @dataclass
class BatchTokenIDOut: class BatchTokenIDOut:
rids: List[str] rids: List[str]
prev_output_strs : List[str] prev_output_strs: List[str]
output_tokens: List[List[int]] output_tokens: List[List[int]]
hit_stop_str: List[Optional[str]]
skip_special_tokens: List[bool] skip_special_tokens: List[bool]
spaces_between_special_tokens: List[bool] spaces_between_special_tokens: List[bool]
meta_info: List[Dict] meta_info: List[Dict]
finished: List[bool] finished_reason: List[BaseFinishReason]
@dataclass @dataclass
class BatchStrOut: class BatchStrOut:
rids: List[str] rids: List[str]
output_str: List[str] output_str: List[str]
meta_info: List[Dict] meta_info: List[Dict]
finished: List[bool] finished_reason: List[BaseFinishReason]
@dataclass @dataclass
...@@ -134,4 +133,4 @@ class AbortReq: ...@@ -134,4 +133,4 @@ class AbortReq:
@dataclass @dataclass
class DetokenizeReqInput: class DetokenizeReqInput:
input_ids: List[int] input_ids: List[int]
\ No newline at end of file
...@@ -4,7 +4,7 @@ import dataclasses ...@@ -4,7 +4,7 @@ import dataclasses
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
from typing import List from typing import List, Dict
import numpy as np import numpy as np
import transformers import transformers
...@@ -26,6 +26,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -26,6 +26,7 @@ from sglang.srt.managers.io_struct import (
GenerateReqInput, GenerateReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
) )
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
...@@ -89,7 +90,7 @@ class TokenizerManager: ...@@ -89,7 +90,7 @@ class TokenizerManager:
) )
self.to_create_loop = True self.to_create_loop = True
self.rid_to_state = {} # Dict[str -> ReqState] self.rid_to_state: Dict[str, ReqState] = {}
async def get_pixel_values(self, image_data): async def get_pixel_values(self, image_data):
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
...@@ -183,12 +184,17 @@ class TokenizerManager: ...@@ -183,12 +184,17 @@ class TokenizerManager:
if self.server_args.log_requests and state.finished: if self.server_args.log_requests and state.finished:
logger.info(f"in={obj.text}, out={out}") logger.info(f"in={obj.text}, out={out}")
yield out
state.out_list = [] state.out_list = []
if state.finished: if state.finished:
del self.rid_to_state[rid] del self.rid_to_state[rid]
yield out
break break
event.clear() event.clear()
yield out
else: else:
if obj.stream: if obj.stream:
raise ValueError("Do not support stream for batch mode.") raise ValueError("Do not support stream for batch mode.")
...@@ -298,24 +304,23 @@ class TokenizerManager: ...@@ -298,24 +304,23 @@ class TokenizerManager:
async def handle_loop(self): async def handle_loop(self):
while True: while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj() recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
assert isinstance(recv_obj, BatchStrOut)
if isinstance(recv_obj, BatchStrOut): for i, rid in enumerate(recv_obj.rids):
for i, rid in enumerate(recv_obj.rids): state = self.rid_to_state.get(rid, None)
state = self.rid_to_state.get(rid, None) if state is None:
if state is None: continue
continue
recv_obj.meta_info[i]["id"] = rid
out_dict = {
"text": recv_obj.output_str[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()
recv_obj.meta_info[i]["id"] = rid
out_dict = {
"text": recv_obj.output_str[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished[i]
state.event.set()
else:
raise ValueError(f"Invalid object: {recv_obj}.")
def convert_logprob_style( def convert_logprob_style(
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment