Unverified Commit e0e09fce authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

[Session] Update session control interface (#2635)

parent 9c05c689
......@@ -27,6 +27,14 @@ from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
@dataclass
class SessionParams:
id: Optional[str] = None
rid: Optional[str] = None
offset: Optional[int] = None
replace: Optional[bool] = None
@dataclass
class GenerateReqInput:
# The input prompt. It can be a single prompt or a batch of prompts.
......@@ -58,10 +66,8 @@ class GenerateReqInput:
# LoRA related
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
# Session id info for continual prompting
session: Optional[
Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
] = None
# Session info for continual prompting
session_params: Optional[Union[List[Dict], Dict]] = None
def normalize_batch_and_arguments(self):
if (
......@@ -223,9 +229,8 @@ class TokenizedGenerateReqInput:
# The input embeds
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# Session id info for continual prompting
session_id: Optional[str] = None
session_rid: Optional[str] = None
# Session info for continual prompting
session_params: Optional[SessionParams] = None
@dataclass
......@@ -468,6 +473,7 @@ class ProfileReq(Enum):
@dataclass
class OpenSessionReqInput:
capacity_of_str_len: int
session_id: Optional[str] = None
@dataclass
......@@ -477,4 +483,5 @@ class CloseSessionReqInput:
@dataclass
class OpenSessionReqOutput:
session_id: str
session_id: Optional[str]
success: bool
......@@ -22,7 +22,7 @@ import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
import psutil
import setproctitle
......@@ -498,8 +498,10 @@ class Scheduler:
else:
self.stop_profile()
elif isinstance(recv_req, OpenSessionReqInput):
session_id = self.open_session(recv_req)
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
session_id, success = self.open_session(recv_req)
self.send_to_tokenizer.send_pyobj(
OpenSessionReqOutput(session_id=session_id, success=success)
)
elif isinstance(recv_req, CloseSessionReqInput):
self.close_session(recv_req)
else:
......@@ -510,7 +512,11 @@ class Scheduler:
recv_req: TokenizedGenerateReqInput,
):
# Create a new request
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
if (
recv_req.session_params is None
or recv_req.session_params.id is None
or recv_req.session_params.id not in self.sessions
):
if recv_req.input_embeds is not None:
# Generate fake input_ids based on the length of input_embeds
......@@ -532,15 +538,18 @@ class Scheduler:
)
req.tokenizer = self.tokenizer
if recv_req.session_id is not None:
if (
recv_req.session_params is not None
and recv_req.session_params.id is not None
):
req.finished_reason = FINISH_ABORT(
f"Invalid request: session id {recv_req.session_id} does not exist"
f"Invalid request: session id {recv_req.session_params.id} does not exist"
)
self.waiting_queue.append(req)
return
else:
# Create a new request from a previsou session
session = self.sessions[recv_req.session_id]
# Create a new request from a previous session
session = self.sessions[recv_req.session_params.id]
req = session.create_req(recv_req, self.tokenizer)
if isinstance(req.finished_reason, FINISH_ABORT):
self.waiting_queue.append(req)
......@@ -1500,16 +1509,20 @@ class Scheduler:
)
logger.info("Profiler is done")
def open_session(self, recv_req: OpenSessionReqInput) -> str:
def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
# handle error
session_id = recv_req.session_id
if session_id in self.sessions:
logger.warning(f"session id {session_id} already exist, cannot open.")
return session_id, False
elif session_id is None:
logger.warning(f"session id is None, cannot open.")
return session_id, False
else:
self.sessions[session_id] = Session(
recv_req.capacity_of_str_len, session_id
)
return session_id
return session_id, True
def close_session(self, recv_req: CloseSessionReqInput):
# handle error
......
......@@ -10,41 +10,116 @@
# limitations under the License.
# ==============================================================================
import logging
import uuid
from typing import Dict, Optional
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
from sglang.srt.managers.schedule_batch import Req
class SessionReqNode:
def __init__(self, req, parent=None, childs=None):
self.req = req
self.parent = parent
if parent is not None:
parent.childs.append(self)
self.childs = [] if not childs else childs
def clear_childs(self, req_dict):
for req_node in self.childs:
req_node.clear(req_dict)
self.childs = []
def clear(self, req_dict):
for req_node in self.childs:
req_node.clear(req_dict)
if self.req.finished_reason == None:
self.req.to_abort = True
del req_dict[self.req.rid]
def abort(self):
if self.req.finished_reason == None:
self.req.to_abort = True
def __str__(self):
return self._str_helper(self.req.rid)
def _str_helper(self, prefix=""):
if len(self.childs) == 0:
return prefix + "\n"
else:
origin_prefix = prefix
prefix += " -- " + self.childs[0].req.rid
ret = self.childs[0]._str_helper(prefix)
for child in self.childs[1:]:
prefix = " " * len(origin_prefix) + " \- " + child.req.rid
ret += child._str_helper(prefix)
return ret
class Session:
def __init__(self, capacity_of_str_len: int, session_id: str = None):
def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
self.session_id = session_id if session_id is not None else uuid.uuid4().hex
self.capacity_of_str_len = capacity_of_str_len
self.reqs: List[Req] = []
self.req_nodes: Dict[str, SessionReqNode] = {}
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
if req.session_rid is not None:
while len(self.reqs) > 0:
if self.reqs[-1].rid == req.session_rid:
break
self.reqs = self.reqs[:-1]
assert req.session_params is not None
session_params = req.session_params
last_req_node = None
last_req = None
abort = False
if session_params.replace:
if session_params.rid is None:
for _, req_node in self.req_nodes.items():
req_node.clear(self.req_nodes)
else:
if session_params.rid not in self.req_nodes:
abort = True
else:
last_req_node = self.req_nodes[session_params.rid]
last_req_node.abort()
last_req = last_req_node.req
last_req_node.clear_childs(self.req_nodes)
else:
self.reqs = []
if len(self.reqs) > 0:
if session_params.rid is not None:
if session_params.rid not in self.req_nodes:
abort = True
else:
last_req_node = self.req_nodes[session_params.rid]
last_req = last_req_node.req
if not last_req.finished():
logging.warning(
"The request in a session is appending to a request that hasn't finished."
)
abort = True
if last_req is not None:
# trim bos token if it is an append
if req.input_ids[0] == tokenizer.bos_token_id:
req.input_ids = req.input_ids[1:]
input_ids = (
self.reqs[-1].origin_input_ids
+ self.reqs[-1].output_ids[
: self.reqs[-1].sampling_params.max_new_tokens
]
+ req.input_ids
last_req.origin_input_ids
+ last_req.output_ids[: last_req.sampling_params.max_new_tokens]
)
if session_params.offset and session_params.offset != 0:
input_ids = input_ids[: session_params.offset] + req.input_ids
else:
input_ids += req.input_ids
input_ids_unpadded = (
self.reqs[-1].origin_input_ids_unpadded
+ self.reqs[-1].output_ids[
: self.reqs[-1].sampling_params.max_new_tokens
]
+ req.input_ids
last_req.origin_input_ids_unpadded
+ last_req.output_ids[: last_req.sampling_params.max_new_tokens]
)
if session_params.offset and session_params.offset != 0:
input_ids_unpadded = (
input_ids_unpadded[: session_params.offset] + req.input_ids
)
else:
input_ids_unpadded += req.input_ids
else:
input_ids = req.input_ids
input_ids_unpadded = req.input_ids
......@@ -57,13 +132,13 @@ class Session:
lora_path=req.lora_path,
session_id=self.session_id,
)
if len(self.reqs) > 0:
new_req.image_inputs = self.reqs[-1].image_inputs
if last_req is not None:
new_req.image_inputs = last_req.image_inputs
new_req.tokenizer = tokenizer
if req.session_rid is not None and len(self.reqs) == 0:
new_req.finished_reason = FINISH_ABORT(
f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
)
if abort:
new_req.to_abort = True
else:
self.reqs.append(new_req)
new_req_node = SessionReqNode(new_req, last_req_node)
self.req_nodes[req.rid] = new_req_node
return new_req
......@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
SessionParams,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
......@@ -264,8 +265,9 @@ class TokenizerManager:
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
session_id = obj.session[0] if obj.session else None
session_rid = obj.session[1] if obj.session else None
session_params = (
SessionParams(**obj.session_params) if obj.session_params else None
)
if obj.input_ids is not None and len(input_ids) >= self.context_len:
raise ValueError(
......@@ -292,8 +294,7 @@ class TokenizerManager:
obj.stream,
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_id=session_id,
session_rid=session_rid,
session_params=session_params,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
......@@ -552,12 +553,16 @@ class TokenizerManager:
):
self.auto_create_handle_loop()
session_id = uuid.uuid4().hex
obj.session_id = session_id
if obj.session_id is None:
obj.session_id = uuid.uuid4().hex
elif obj.session_id in self.session_futures:
return None
self.send_to_scheduler.send_pyobj(obj)
self.session_futures[session_id] = asyncio.Future()
session_id = await self.session_futures[session_id]
del self.session_futures[session_id]
self.session_futures[obj.session_id] = asyncio.Future()
session_id = await self.session_futures[obj.session_id]
del self.session_futures[obj.session_id]
return session_id
async def close_session(
......@@ -709,7 +714,7 @@ class TokenizerManager:
)
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
recv_obj.session_id if recv_obj.success else None
)
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if self.server_args.dp_size == 1:
......
......@@ -259,6 +259,10 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
try:
session_id = await tokenizer_manager.open_session(obj, request)
if session_id is None:
raise Exception(
"Failed to open the session. Check if a session with the same id is still open."
)
return session_id
except Exception as e:
return _create_error_response(e)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment