Unverified Commit 5942dfc0 authored by Ying Sheng, committed by GitHub

[feat] Add session control (#2073)

parent 63a395b9
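
This commit adds a session-control API for continual prompting: POST /open_session returns a fresh session id, /generate accepts a session_id (plus an optional session_rid to roll the session back to an earlier request), and /close_session discards the session. A minimal client-side sketch of the intended flow, mirroring the test script at the end of this diff (it assumes a server is already running at http://localhost:30000; the prompt text is illustrative):

import requests

url = "http://localhost:30000"  # assumed local server, as in the test script below

# Open a session; the server replies with its unique session id.
session_id = requests.post(url + "/open_session", json={"capacity_of_str_len": 1000}).json()

# Every /generate call inside the session returns an updated session_id,
# which must be passed to the next turn of the conversation.
response = requests.post(
    url + "/generate",
    json={
        "text": "Hello",
        "session_id": session_id,
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
).json()
session_id = response["session_id"]

# Close the session when the conversation is finished.
requests.post(url + "/close_session", json={"session_id": session_id})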
......@@ -175,6 +175,7 @@ class DetokenizerManager:
output_strs=output_strs,
meta_info=recv_obj.meta_info,
finished_reason=recv_obj.finished_reason,
session_ids=recv_obj.session_ids,
)
)
......
......@@ -56,6 +56,10 @@ class GenerateReqInput:
# LoRA related
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
# Session id info for continual prompting
session_id: Optional[Union[List[str], str]] = None
session_rid: Optional[Union[List[str], str]] = None
def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
......@@ -200,6 +204,10 @@ class TokenizedGenerateReqInput:
# LoRA related
lora_path: Optional[str] = None # None means just use the base model
# Session id info for continual prompting
session_id: Optional[str] = None
session_rid: Optional[str] = None
@dataclass
class EmbeddingReqInput:
......@@ -293,6 +301,8 @@ class BatchTokenIDOut:
meta_info: List[Dict]
finished_reason: List[BaseFinishReason]
no_stop_trim: List[bool]
# The updated session unique id
session_ids: List[str]
@dataclass
......@@ -305,6 +315,8 @@ class BatchStrOut:
meta_info: List[Dict]
# The finish reason
finished_reason: List[BaseFinishReason]
# The updated session unique id
session_ids: List[str]
@dataclass
......@@ -357,3 +369,18 @@ class GetMemPoolSizeReq:
@dataclass
class GetMemPoolSizeReqOutput:
size: int
@dataclass
class OpenSessionReqInput:
capacity_of_str_len: int
@dataclass
class CloseSessionReqInput:
session_id: str
@dataclass
class OpenSessionReqOutput:
session_id: str
......@@ -180,6 +180,7 @@ class Req:
origin_input_ids: Tuple[int],
sampling_params: SamplingParams,
lora_path: Optional[str] = None,
session_id: Optional[str] = None,
):
# Input and output info
self.rid = rid
......@@ -188,6 +189,8 @@ class Req:
self.origin_input_ids = origin_input_ids
self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
self.session_id = session_id
self.sampling_params = sampling_params
self.lora_path = lora_path
......
......@@ -37,9 +37,12 @@ from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchTokenIDOut,
CloseSessionReqInput,
FlushCacheReq,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
......@@ -59,6 +62,7 @@ from sglang.srt.managers.schedule_policy import (
PrefillAdder,
SchedulePolicy,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
......@@ -106,6 +110,9 @@ class Scheduler:
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics
# Session info
self.sessions = {}
# Init inter-process communication
context = zmq.Context(2)
......@@ -509,6 +516,11 @@ class Scheduler:
self.start_profile()
else:
self.stop_profile()
elif isinstance(recv_req, OpenSessionReqInput):
session_id = self.open_session(recv_req)
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
elif isinstance(recv_req, CloseSessionReqInput):
self.close_session(recv_req)
elif isinstance(recv_req, GetMemPoolSizeReq):
self.send_to_tokenizer.send_pyobj(
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
......@@ -520,6 +532,7 @@ class Scheduler:
self,
recv_req: TokenizedGenerateReqInput,
):
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
req = Req(
recv_req.rid,
recv_req.input_text,
......@@ -528,6 +541,21 @@ class Scheduler:
lora_path=recv_req.lora_path,
)
req.tokenizer = self.tokenizer
if recv_req.session_id is not None:
req.finished_reason = FINISH_ABORT(
f"Invalid request: session id {recv_req.session_id} does not exist"
)
self.waiting_queue.append(req)
return
else:
# Handle sessions
session = self.sessions[recv_req.session_id]
req, new_session_id = session.create_req(recv_req, self.tokenizer)
del self.sessions[recv_req.session_id]
self.sessions[new_session_id] = session
if isinstance(req.finished_reason, FINISH_ABORT):
self.waiting_queue.append(req)
return
# Image inputs
if recv_req.image_inputs is not None:
......@@ -1151,6 +1179,7 @@ class Scheduler:
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_no_stop_trim = []
output_session_ids = []
else: # embedding or reward model
output_embeddings = []
......@@ -1178,6 +1207,7 @@ class Scheduler:
req.sampling_params.spaces_between_special_tokens
)
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
output_session_ids.append(req.session_id)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
......@@ -1228,6 +1258,7 @@ class Scheduler:
output_meta_info,
output_finished_reason,
output_no_stop_trim,
output_session_ids,
)
)
else: # embedding or reward model
......@@ -1330,6 +1361,25 @@ class Scheduler:
)
logger.info("Profiler is done")
def open_session(self, recv_req: OpenSessionReqInput) -> str:
# Handle the case where the session id already exists
session_id = recv_req.session_id
if session_id in self.sessions:
logger.warning(f"session id {session_id} already exist, cannot open.")
else:
self.sessions[session_id] = Session(
recv_req.capacity_of_str_len, session_id
)
return session_id
def close_session(self, recv_req: CloseSessionReqInput):
# Handle the case where the session id does not exist
session_id = recv_req.session_id
if session_id not in self.sessions:
logger.warning(f"session id {session_id} does not exist, cannot delete.")
else:
del self.sessions[session_id]
def run_scheduler_process(
server_args: ServerArgs,
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import copy
import uuid
from dataclasses import dataclass
from typing import Optional
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
class Session:
def __init__(self, capacity_of_str_len: int, session_id: str = None):
self.session_id = session_id if session_id is not None else uuid.uuid4().hex
self.capacity_of_str_len = capacity_of_str_len
self.reqs: List[Req] = []
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
# Renew the session id; the new id is returned to the client with the output
self.session_id = uuid.uuid4().hex
if req.session_rid is not None:
while len(self.reqs) > 0:
if self.reqs[-1].rid == req.session_rid:
break
self.reqs = self.reqs[:-1]
if len(self.reqs) > 0:
input_ids = (
self.reqs[-1].origin_input_ids
+ self.reqs[-1].output_ids[
: self.reqs[-1].sampling_params.max_new_tokens
]
+ req.input_ids
)
else:
input_ids = req.input_ids
new_req = Req(
req.rid,
None,
input_ids,
req.sampling_params,
lora_path=req.lora_path,
session_id=self.session_id,
)
new_req.tokenizer = tokenizer
if req.session_rid is not None and len(self.reqs) == 0:
new_req.finished_reason = FINISH_ABORT(
f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
)
else:
self.reqs.append(new_req)
return new_req, self.session_id
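
For readability, the session_rid rollback in create_req above truncates the stored request list until the request with the matching rid is the most recent entry, emptying it when nothing matches, which triggers the FINISH_ABORT path. A small self-contained sketch of just that truncation, with plain strings standing in for Req objects and the function name rollback being hypothetical:

from typing import List, Optional

def rollback(history: List[str], session_rid: Optional[str]) -> List[str]:
    # Drop entries from the end until the matching rid is last; an unmatched
    # rid leaves the history empty, mirroring the abort case in create_req.
    if session_rid is None:
        return history
    while history and history[-1] != session_rid:
        history = history[:-1]
    return history

print(rollback(["r1", "r2", "r3"], "r2"))  # ['r1', 'r2']
print(rollback(["r1", "r2", "r3"], "r9"))  # []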
......@@ -23,6 +23,7 @@ import os
import signal
import sys
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union
import fastapi
......@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
CloseSessionReqInput,
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
......@@ -146,6 +150,9 @@ class TokenizerManager:
self.model_update_lock = asyncio.Lock()
self.model_update_result = None
# For session info
self.session_futures = {}  # session_id -> asyncio.Future
# Others
self.gracefully_exit = False
......@@ -211,6 +218,8 @@ class TokenizerManager:
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
session_id = obj.session_id
session_rid = obj.session_rid
if len(input_ids) >= self.context_len:
raise ValueError(
......@@ -236,6 +245,8 @@ class TokenizerManager:
top_logprobs_num,
obj.stream,
obj.lora_path,
session_id=session_id,
session_rid=session_rid,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
......@@ -451,6 +462,26 @@ class TokenizerManager:
else:
return False, "Another update is in progress. Please try again later."
async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
if self.to_create_loop:
self.create_handle_loop()
session_id = uuid.uuid4().hex
obj.session_id = session_id
self.send_to_scheduler.send_pyobj(obj)
self.session_futures[session_id] = asyncio.Future()
session_id = await self.session_futures[session_id]
del self.session_futures[session_id]
return session_id
async def close_session(
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
assert not self.to_create_loop, "close session should not be the first request"
await self.send_to_scheduler.send_pyobj(obj)
def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected.
async def abort_request():
......@@ -521,6 +552,11 @@ class TokenizerManager:
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
continue
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
)
continue
assert isinstance(
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
......@@ -536,11 +572,13 @@ class TokenizerManager:
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
"session_id": recv_obj.session_ids[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
"session_id": recv_obj.session_ids[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
......
......@@ -50,8 +50,10 @@ from sglang.srt.managers.data_parallel_controller import (
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
CloseSessionReqInput,
EmbeddingReqInput,
GenerateReqInput,
OpenSessionReqInput,
UpdateWeightReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
......@@ -215,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
)
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
try:
session_id = await tokenizer_manager.open_session(obj, request)
return session_id
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request):
"""Close the session"""
try:
await tokenizer_manager.close_session(obj, request)
return Response(status_code=200)
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@time_func_latency
async def generate_request(obj: GenerateReqInput, request: Request):
"""Handle a generate request."""
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# FIXME: Make it a CI test
import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
url = "http://localhost:30000"
# Open a session
response = requests.post(
url + "/open_session",
json={"capacity_of_str_len": 1000},
)
session_id = response.json()
print("session_id", session_id, "\n")
# Prefill only
prompt = "chunk 1"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 0,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
# Generate
prompt = "Chunk 2"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
rid = response.json()["meta_info"]["id"]
# Generate
prompt = "Chunk 3"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 2,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
rid_to_del = response.json()["meta_info"]["id"]
# Interrupt and re-generate
prompt = "Chunk 4"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"session_rid": rid,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
# Query a session based on a deleted request, should see finish reason abort
prompt = "Chunk 4"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"session_rid": rid_to_del,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
# Close session
ret = requests.post(
url + "/close_session",
json={"session_id": session_id},
)
print(ret, "\n")
# Query a deleted session, should see finish reason abort
prompt = "chunk 1"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 0,
},
},
)
print(response.json(), "\n")