Unverified Commit e0e09fce authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

[Session] Update session control interface (#2635)

parent 9c05c689
...@@ -27,6 +27,14 @@ from sglang.srt.managers.schedule_batch import BaseFinishReason ...@@ -27,6 +27,14 @@ from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
@dataclass
class SessionParams:
    """Per-request session controls for continual prompting.

    The tokenizer manager builds this from the request's ``session_params``
    dict via ``SessionParams(**obj.session_params)``, so the field names are
    the keys callers send.
    """

    # Session to attach the request to. The scheduler looks it up in its
    # `sessions` table; a non-None id that is not found there aborts the
    # request with "session id ... does not exist".
    id: Optional[str] = None
    # Request id of an earlier request in the session to continue from.
    # Looked up in the session's `req_nodes`; an unknown rid marks the new
    # request to be aborted.
    rid: Optional[str] = None
    # When truthy, the earlier request's tokens (its original input ids plus
    # output ids) are truncated to `[:offset]` before the new input ids are
    # appended.
    offset: Optional[int] = None
    # When truthy, the request named by `rid` is aborted if unfinished and its
    # child nodes are cleared; with `rid` None every node of the session is
    # cleared.
    replace: Optional[bool] = None
@dataclass @dataclass
class GenerateReqInput: class GenerateReqInput:
# The input prompt. It can be a single prompt or a batch of prompts. # The input prompt. It can be a single prompt or a batch of prompts.
...@@ -58,10 +66,8 @@ class GenerateReqInput: ...@@ -58,10 +66,8 @@ class GenerateReqInput:
# LoRA related # LoRA related
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
# Session id info for continual prompting # Session info for continual prompting
session: Optional[ session_params: Optional[Union[List[Dict], Dict]] = None
Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
] = None
def normalize_batch_and_arguments(self): def normalize_batch_and_arguments(self):
if ( if (
...@@ -223,9 +229,8 @@ class TokenizedGenerateReqInput: ...@@ -223,9 +229,8 @@ class TokenizedGenerateReqInput:
# The input embeds # The input embeds
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# Session id info for continual prompting # Session info for continual prompting
session_id: Optional[str] = None session_params: Optional[SessionParams] = None
session_rid: Optional[str] = None
@dataclass @dataclass
...@@ -468,6 +473,7 @@ class ProfileReq(Enum): ...@@ -468,6 +473,7 @@ class ProfileReq(Enum):
@dataclass @dataclass
class OpenSessionReqInput: class OpenSessionReqInput:
capacity_of_str_len: int capacity_of_str_len: int
session_id: Optional[str] = None
@dataclass @dataclass
...@@ -477,4 +483,5 @@ class CloseSessionReqInput: ...@@ -477,4 +483,5 @@ class CloseSessionReqInput:
@dataclass @dataclass
class OpenSessionReqOutput: class OpenSessionReqOutput:
session_id: str session_id: Optional[str]
success: bool
...@@ -22,7 +22,7 @@ import warnings ...@@ -22,7 +22,7 @@ import warnings
from collections import deque from collections import deque
from concurrent import futures from concurrent import futures
from types import SimpleNamespace from types import SimpleNamespace
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
import psutil import psutil
import setproctitle import setproctitle
...@@ -498,8 +498,10 @@ class Scheduler: ...@@ -498,8 +498,10 @@ class Scheduler:
else: else:
self.stop_profile() self.stop_profile()
elif isinstance(recv_req, OpenSessionReqInput): elif isinstance(recv_req, OpenSessionReqInput):
session_id = self.open_session(recv_req) session_id, success = self.open_session(recv_req)
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id)) self.send_to_tokenizer.send_pyobj(
OpenSessionReqOutput(session_id=session_id, success=success)
)
elif isinstance(recv_req, CloseSessionReqInput): elif isinstance(recv_req, CloseSessionReqInput):
self.close_session(recv_req) self.close_session(recv_req)
else: else:
...@@ -510,7 +512,11 @@ class Scheduler: ...@@ -510,7 +512,11 @@ class Scheduler:
recv_req: TokenizedGenerateReqInput, recv_req: TokenizedGenerateReqInput,
): ):
# Create a new request # Create a new request
if recv_req.session_id is None or recv_req.session_id not in self.sessions: if (
recv_req.session_params is None
or recv_req.session_params.id is None
or recv_req.session_params.id not in self.sessions
):
if recv_req.input_embeds is not None: if recv_req.input_embeds is not None:
# Generate fake input_ids based on the length of input_embeds # Generate fake input_ids based on the length of input_embeds
...@@ -532,15 +538,18 @@ class Scheduler: ...@@ -532,15 +538,18 @@ class Scheduler:
) )
req.tokenizer = self.tokenizer req.tokenizer = self.tokenizer
if recv_req.session_id is not None: if (
recv_req.session_params is not None
and recv_req.session_params.id is not None
):
req.finished_reason = FINISH_ABORT( req.finished_reason = FINISH_ABORT(
f"Invalid request: session id {recv_req.session_id} does not exist" f"Invalid request: session id {recv_req.session_params.id} does not exist"
) )
self.waiting_queue.append(req) self.waiting_queue.append(req)
return return
else: else:
# Create a new request from a previsou session # Create a new request from a previous session
session = self.sessions[recv_req.session_id] session = self.sessions[recv_req.session_params.id]
req = session.create_req(recv_req, self.tokenizer) req = session.create_req(recv_req, self.tokenizer)
if isinstance(req.finished_reason, FINISH_ABORT): if isinstance(req.finished_reason, FINISH_ABORT):
self.waiting_queue.append(req) self.waiting_queue.append(req)
...@@ -1500,16 +1509,20 @@ class Scheduler: ...@@ -1500,16 +1509,20 @@ class Scheduler:
) )
logger.info("Profiler is done") logger.info("Profiler is done")
def open_session(self, recv_req: OpenSessionReqInput) -> str: def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
# handle error # handle error
session_id = recv_req.session_id session_id = recv_req.session_id
if session_id in self.sessions: if session_id in self.sessions:
logger.warning(f"session id {session_id} already exist, cannot open.") logger.warning(f"session id {session_id} already exist, cannot open.")
return session_id, False
elif session_id is None:
logger.warning(f"session id is None, cannot open.")
return session_id, False
else: else:
self.sessions[session_id] = Session( self.sessions[session_id] = Session(
recv_req.capacity_of_str_len, session_id recv_req.capacity_of_str_len, session_id
) )
return session_id return session_id, True
def close_session(self, recv_req: CloseSessionReqInput): def close_session(self, recv_req: CloseSessionReqInput):
# handle error # handle error
......
...@@ -10,41 +10,116 @@ ...@@ -10,41 +10,116 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import logging
import uuid import uuid
from typing import Dict, Optional
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req from sglang.srt.managers.schedule_batch import Req
class SessionReqNode:
    """A node in a session's tree of requests.

    Each node wraps one request object (``req``) and links to the request it
    continues from (``parent``) and to the requests that continue from it
    (``childs``). The wrapped request is expected to expose ``rid``,
    ``finished_reason`` and a writable ``to_abort`` attribute.
    """

    def __init__(self, req, parent=None, childs=None):
        """Create a node and, if ``parent`` is given, register under it.

        Args:
            req: The request object held by this node.
            parent: Optional parent node; this node is appended to its
                ``childs`` list.
            childs: Optional initial list of child nodes. A falsy value
                (``None`` or an empty list) yields a fresh empty list.
        """
        self.req = req
        self.parent = parent
        if parent is not None:
            parent.childs.append(self)
        self.childs = childs if childs else []

    def clear_childs(self, req_dict):
        """Clear every descendant of this node, keeping the node itself.

        Args:
            req_dict: Mapping from request id to node; entries of all
                descendants are removed from it.
        """
        for child in self.childs:
            child.clear(req_dict)
        self.childs = []

    def clear(self, req_dict):
        """Abort this subtree's unfinished requests and drop it from ``req_dict``.

        Descendants are cleared first, then this node. Removal uses
        ``pop`` with a default so that clearing a node whose entry is already
        gone (e.g. a repeated or overlapping clear) does not raise.

        Args:
            req_dict: Mapping from request id to node, mutated in place.
        """
        for child in self.childs:
            child.clear(req_dict)
        self.abort()
        req_dict.pop(self.req.rid, None)

    def abort(self):
        """Flag the wrapped request for abortion if it has not finished."""
        # Identity check: a finish-reason object must never be mistaken for
        # "unfinished" because of a custom __eq__.
        if self.req.finished_reason is None:
            self.req.to_abort = True

    def __str__(self):
        """Render the subtree rooted here as an indented text tree of rids."""
        return self._str_helper(self.req.rid)

    def _str_helper(self, prefix=""):
        """Build the text rendering for this subtree.

        Args:
            prefix: Text already accumulated for the current output line
                (the chain of rids leading to this node).

        Returns:
            One newline-terminated line per leaf under this node. The first
            child continues the current line with " -- "; later children
            start a new line, indented to the parent's column, with " \\- ".
        """
        if not self.childs:
            return prefix + "\n"

        first_child, *other_childs = self.childs
        indent = " " * len(prefix)
        pieces = [first_child._str_helper(prefix + " -- " + first_child.req.rid)]
        for child in other_childs:
            # Escaped backslash: same runtime text as the original " \- "
            # without relying on an invalid escape sequence.
            pieces.append(child._str_helper(indent + " \\- " + child.req.rid))
        return "".join(pieces)
class Session: class Session:
def __init__(self, capacity_of_str_len: int, session_id: str = None): def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
self.session_id = session_id if session_id is not None else uuid.uuid4().hex self.session_id = session_id if session_id is not None else uuid.uuid4().hex
self.capacity_of_str_len = capacity_of_str_len self.capacity_of_str_len = capacity_of_str_len
self.reqs: List[Req] = [] self.req_nodes: Dict[str, SessionReqNode] = {}
def create_req(self, req: TokenizedGenerateReqInput, tokenizer): def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
if req.session_rid is not None: assert req.session_params is not None
while len(self.reqs) > 0: session_params = req.session_params
if self.reqs[-1].rid == req.session_rid:
break last_req_node = None
self.reqs = self.reqs[:-1] last_req = None
abort = False
if session_params.replace:
if session_params.rid is None:
for _, req_node in self.req_nodes.items():
req_node.clear(self.req_nodes)
else:
if session_params.rid not in self.req_nodes:
abort = True
else:
last_req_node = self.req_nodes[session_params.rid]
last_req_node.abort()
last_req = last_req_node.req
last_req_node.clear_childs(self.req_nodes)
else: else:
self.reqs = [] if session_params.rid is not None:
if len(self.reqs) > 0: if session_params.rid not in self.req_nodes:
abort = True
else:
last_req_node = self.req_nodes[session_params.rid]
last_req = last_req_node.req
if not last_req.finished():
logging.warning(
"The request in a session is appending to a request that hasn't finished."
)
abort = True
if last_req is not None:
# trim bos token if it is an append
if req.input_ids[0] == tokenizer.bos_token_id:
req.input_ids = req.input_ids[1:]
input_ids = ( input_ids = (
self.reqs[-1].origin_input_ids last_req.origin_input_ids
+ self.reqs[-1].output_ids[ + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
: self.reqs[-1].sampling_params.max_new_tokens
]
+ req.input_ids
) )
if session_params.offset and session_params.offset != 0:
input_ids = input_ids[: session_params.offset] + req.input_ids
else:
input_ids += req.input_ids
input_ids_unpadded = ( input_ids_unpadded = (
self.reqs[-1].origin_input_ids_unpadded last_req.origin_input_ids_unpadded
+ self.reqs[-1].output_ids[ + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
: self.reqs[-1].sampling_params.max_new_tokens
]
+ req.input_ids
) )
if session_params.offset and session_params.offset != 0:
input_ids_unpadded = (
input_ids_unpadded[: session_params.offset] + req.input_ids
)
else:
input_ids_unpadded += req.input_ids
else: else:
input_ids = req.input_ids input_ids = req.input_ids
input_ids_unpadded = req.input_ids input_ids_unpadded = req.input_ids
...@@ -57,13 +132,13 @@ class Session: ...@@ -57,13 +132,13 @@ class Session:
lora_path=req.lora_path, lora_path=req.lora_path,
session_id=self.session_id, session_id=self.session_id,
) )
if len(self.reqs) > 0: if last_req is not None:
new_req.image_inputs = self.reqs[-1].image_inputs new_req.image_inputs = last_req.image_inputs
new_req.tokenizer = tokenizer new_req.tokenizer = tokenizer
if req.session_rid is not None and len(self.reqs) == 0: if abort:
new_req.finished_reason = FINISH_ABORT( new_req.to_abort = True
f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
)
else: else:
self.reqs.append(new_req) new_req_node = SessionReqNode(new_req, last_req_node)
self.req_nodes[req.rid] = new_req_node
return new_req return new_req
...@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput, OpenSessionReqInput,
OpenSessionReqOutput, OpenSessionReqOutput,
ProfileReq, ProfileReq,
SessionParams,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqInput,
...@@ -264,8 +265,9 @@ class TokenizerManager: ...@@ -264,8 +265,9 @@ class TokenizerManager:
return_logprob = obj.return_logprob return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num top_logprobs_num = obj.top_logprobs_num
session_id = obj.session[0] if obj.session else None session_params = (
session_rid = obj.session[1] if obj.session else None SessionParams(**obj.session_params) if obj.session_params else None
)
if obj.input_ids is not None and len(input_ids) >= self.context_len: if obj.input_ids is not None and len(input_ids) >= self.context_len:
raise ValueError( raise ValueError(
...@@ -292,8 +294,7 @@ class TokenizerManager: ...@@ -292,8 +294,7 @@ class TokenizerManager:
obj.stream, obj.stream,
lora_path=obj.lora_path, lora_path=obj.lora_path,
input_embeds=input_embeds, input_embeds=input_embeds,
session_id=session_id, session_params=session_params,
session_rid=session_rid,
) )
elif isinstance(obj, EmbeddingReqInput): elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput( tokenized_obj = TokenizedEmbeddingReqInput(
...@@ -552,12 +553,16 @@ class TokenizerManager: ...@@ -552,12 +553,16 @@ class TokenizerManager:
): ):
self.auto_create_handle_loop() self.auto_create_handle_loop()
session_id = uuid.uuid4().hex if obj.session_id is None:
obj.session_id = session_id obj.session_id = uuid.uuid4().hex
elif obj.session_id in self.session_futures:
return None
self.send_to_scheduler.send_pyobj(obj) self.send_to_scheduler.send_pyobj(obj)
self.session_futures[session_id] = asyncio.Future()
session_id = await self.session_futures[session_id] self.session_futures[obj.session_id] = asyncio.Future()
del self.session_futures[session_id] session_id = await self.session_futures[obj.session_id]
del self.session_futures[obj.session_id]
return session_id return session_id
async def close_session( async def close_session(
...@@ -709,7 +714,7 @@ class TokenizerManager: ...@@ -709,7 +714,7 @@ class TokenizerManager:
) )
elif isinstance(recv_obj, OpenSessionReqOutput): elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result( self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id recv_obj.session_id if recv_obj.success else None
) )
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput): elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if self.server_args.dp_size == 1: if self.server_args.dp_size == 1:
......
...@@ -259,6 +259,10 @@ async def open_session(obj: OpenSessionReqInput, request: Request): ...@@ -259,6 +259,10 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id.""" """Open a session, and return its unique session id."""
try: try:
session_id = await tokenizer_manager.open_session(obj, request) session_id = await tokenizer_manager.open_session(obj, request)
if session_id is None:
raise Exception(
"Failed to open the session. Check if a session with the same id is still open."
)
return session_id return session_id
except Exception as e: except Exception as e:
return _create_error_response(e) return _create_error_response(e)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment