Unverified Commit 5942dfc0 authored by Ying Sheng, committed by GitHub

[feat] Add session control (#2073)

parent 63a395b9
@@ -175,6 +175,7 @@ class DetokenizerManager:
                    output_strs=output_strs,
                    meta_info=recv_obj.meta_info,
                    finished_reason=recv_obj.finished_reason,
                    session_ids=recv_obj.session_ids,
                )
            )
...
@@ -56,6 +56,10 @@ class GenerateReqInput:
    # LoRA related
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    # Session id info for continual prompting
    session_id: Optional[Union[List[str], str]] = None
    session_rid: Optional[Union[List[str], str]] = None

    def normalize_batch_and_arguments(self):
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
@@ -200,6 +204,10 @@ class TokenizedGenerateReqInput:
    # LoRA related
    lora_path: Optional[str] = None  # None means just use the base model
    # Session id info for continual prompting
    session_id: Optional[str] = None
    session_rid: Optional[str] = None


@dataclass
class EmbeddingReqInput:
@@ -293,6 +301,8 @@ class BatchTokenIDOut:
    meta_info: List[Dict]
    finished_reason: List[BaseFinishReason]
    no_stop_trim: List[bool]
    # The updated session unique id
    session_ids: List[str]
@dataclass
@@ -305,6 +315,8 @@ class BatchStrOut:
    meta_info: List[Dict]
    # The finish reason
    finished_reason: List[BaseFinishReason]
    # The updated session unique id
    session_ids: List[str]
@dataclass
@@ -357,3 +369,18 @@ class GetMemPoolSizeReq:
@dataclass
class GetMemPoolSizeReqOutput:
    size: int


@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int


@dataclass
class CloseSessionReqInput:
    session_id: str


@dataclass
class OpenSessionReqOutput:
    session_id: str
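
For orientation, here is a minimal sketch of how these new dataclasses are meant to be populated; the concrete values are illustrative assumptions, not part of the commit:

# Sketch only: the payload shapes for the session-control round trips.
open_req = OpenSessionReqInput(capacity_of_str_len=1000)
# The scheduler replies with OpenSessionReqOutput(session_id=...).

# A generation request that continues an open session. session_rid optionally
# pins the session history to an earlier request id within that session.
gen_req = GenerateReqInput(
    text="next chunk",
    session_id="<id from OpenSessionReqOutput>",
    session_rid=None,
)

close_req = CloseSessionReqInput(session_id="<current session id>")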
@@ -180,6 +180,7 @@ class Req:
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        lora_path: Optional[str] = None,
        session_id: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
@@ -188,6 +189,8 @@ class Req:
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
        self.session_id = session_id

        self.sampling_params = sampling_params
        self.lora_path = lora_path
...
@@ -37,9 +37,12 @@ from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    FlushCacheReq,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
@@ -59,6 +62,7 @@ from sglang.srt.managers.schedule_policy import (
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
@@ -106,6 +110,9 @@ class Scheduler:
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics

        # Session info
        self.sessions = {}

        # Init inter-process communication
        context = zmq.Context(2)
@@ -509,6 +516,11 @@ class Scheduler:
                self.start_profile()
            else:
                self.stop_profile()
        elif isinstance(recv_req, OpenSessionReqInput):
            session_id = self.open_session(recv_req)
            self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
        elif isinstance(recv_req, CloseSessionReqInput):
            self.close_session(recv_req)
        elif isinstance(recv_req, GetMemPoolSizeReq):
            self.send_to_tokenizer.send_pyobj(
                GetMemPoolSizeReqOutput(self.max_total_num_tokens)
@@ -520,14 +532,30 @@ class Scheduler:
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                lora_path=recv_req.lora_path,
            )
            req.tokenizer = self.tokenizer

            if recv_req.session_id is not None:
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_id} does not exist"
                )
                self.waiting_queue.append(req)
                return
        else:
            # Handle sessions
            session = self.sessions[recv_req.session_id]
            req, new_session_id = session.create_req(recv_req, self.tokenizer)
            del self.sessions[recv_req.session_id]
            self.sessions[new_session_id] = session
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.waiting_queue.append(req)
                return

        # Image inputs
        if recv_req.image_inputs is not None:
@@ -1151,6 +1179,7 @@ class Scheduler:
                output_skip_special_tokens = []
                output_spaces_between_special_tokens = []
                output_no_stop_trim = []
                output_session_ids = []
            else:  # embedding or reward model
                output_embeddings = []
@@ -1178,6 +1207,7 @@ class Scheduler:
                        req.sampling_params.spaces_between_special_tokens
                    )
                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)
                    output_session_ids.append(req.session_id)

                meta_info = {
                    "prompt_tokens": len(req.origin_input_ids),
@@ -1228,6 +1258,7 @@ class Scheduler:
                    output_meta_info,
                    output_finished_reason,
                    output_no_stop_trim,
                    output_session_ids,
                )
            )
        else:  # embedding or reward model
@@ -1330,6 +1361,25 @@ class Scheduler:
        )
        logger.info("Profiler is done")
    def open_session(self, recv_req: OpenSessionReqInput) -> str:
        # Handle errors
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exists, cannot open.")
        else:
            self.sessions[session_id] = Session(
                recv_req.capacity_of_str_len, session_id
            )
        return session_id

    def close_session(self, recv_req: CloseSessionReqInput):
        # Handle errors
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]
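
Note that open_session and close_session only maintain the self.sessions table; the id rotation happens in handle_generate_request above, because Session.create_req renews the session id on every call. A toy illustration of that rotation (standalone names, assumed for exposition):

import uuid

# Each generate call consumes the old session id and mints a new one.
sessions = {"abc": object()}   # stand-in for a Session instance
session = sessions.pop("abc")  # mirrors: del self.sessions[recv_req.session_id]
new_id = uuid.uuid4().hex      # what Session.create_req does internally
sessions[new_id] = session     # mirrors: self.sessions[new_session_id] = session
# A client must thread the session_id returned in each response into its next
# request; replaying the old id aborts with FINISH_ABORT.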
def run_scheduler_process(
    server_args: ServerArgs,
...
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import copy
import uuid
from dataclasses import dataclass
from typing import Optional
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
class Session:
    def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
        self.session_id = session_id if session_id is not None else uuid.uuid4().hex
        self.capacity_of_str_len = capacity_of_str_len
        self.reqs: List[Req] = []

    def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
        # Renew the session id, so the previous id cannot be reused
        self.session_id = uuid.uuid4().hex
        if req.session_rid is not None:
            # Roll the history back to the request with rid == session_rid
            while len(self.reqs) > 0:
                if self.reqs[-1].rid == req.session_rid:
                    break
                self.reqs = self.reqs[:-1]
        if len(self.reqs) > 0:
            # Prepend the last turn's prompt and generated tokens to the new input
            input_ids = (
                self.reqs[-1].origin_input_ids
                + self.reqs[-1].output_ids[
                    : self.reqs[-1].sampling_params.max_new_tokens
                ]
                + req.input_ids
            )
        else:
            input_ids = req.input_ids
        new_req = Req(
            req.rid,
            None,
            input_ids,
            req.sampling_params,
            lora_path=req.lora_path,
            session_id=self.session_id,
        )
        new_req.tokenizer = tokenizer
        if req.session_rid is not None and len(self.reqs) == 0:
            new_req.finished_reason = FINISH_ABORT(
                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
            )
        else:
            self.reqs.append(new_req)
        return new_req, self.session_id
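
The rollback-and-concatenate rule in create_req can be stated without the Req plumbing. A self-contained sketch of the same history logic, using toy types that are not part of this commit (the max_new_tokens slice on output_ids is omitted for brevity):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyTurn:
    rid: str
    input_ids: List[int]
    output_ids: List[int] = field(default_factory=list)


def chain_input_ids(
    history: List[ToyTurn], new_input_ids: List[int], session_rid: Optional[str]
) -> Optional[List[int]]:
    # Roll the history back until the pinned request is the last entry.
    if session_rid is not None:
        while history and history[-1].rid != session_rid:
            history.pop()
        if not history:
            return None  # the case create_req marks with FINISH_ABORT
    if history:
        last = history[-1]
        return last.input_ids + last.output_ids + new_input_ids
    return new_input_ids


# Turn "a" produced tokens [10, 11]; interrupting turn "b" and continuing
# from "a" replays [1, 2] + [10, 11] before the new chunk [5].
history = [ToyTurn("a", [1, 2], [10, 11]), ToyTurn("b", [3, 4], [12])]
assert chain_input_ids(history, [5], session_rid="a") == [1, 2, 10, 11, 5]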
@@ -23,6 +23,7 @@ import os
import signal
import sys
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union

import fastapi
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
@@ -146,6 +150,9 @@ class TokenizerManager:
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None

        # For session info
        self.session_futures = {}  # session_id -> asyncio future

        # Others
        self.gracefully_exit = False
@@ -211,6 +218,8 @@ class TokenizerManager:
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            session_id = obj.session_id
            session_rid = obj.session_rid

            if len(input_ids) >= self.context_len:
                raise ValueError(
@@ -236,6 +245,8 @@ class TokenizerManager:
                top_logprobs_num,
                obj.stream,
                obj.lora_path,
                session_id=session_id,
                session_rid=session_rid,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
@@ -451,6 +462,26 @@ class TokenizerManager:
        else:
            return False, "Another update is in progress. Please try again later."
    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        session_id = uuid.uuid4().hex
        obj.session_id = session_id
        self.send_to_scheduler.send_pyobj(obj)

        self.session_futures[session_id] = asyncio.Future()
        session_id = await self.session_futures[session_id]
        del self.session_futures[session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        assert not self.to_create_loop, "close session should not be the first request"
        await self.send_to_scheduler.send_pyobj(obj)
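
open_session above is an instance of a future-keyed request/response bridge between the async API and the ZMQ receive loop. A standalone sketch of the pattern, with illustrative names only:

import asyncio

pending: dict = {}  # key -> asyncio.Future, resolved by the receive loop


async def request(key: str) -> str:
    fut = asyncio.get_running_loop().create_future()
    pending[key] = fut  # register before the reply can arrive
    # ... send the request over the wire here ...
    result = await fut  # resolved by on_reply()
    del pending[key]
    return result


def on_reply(key: str, value: str) -> None:
    pending[key].set_result(value)  # hands the value to the awaiting coroutine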
    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
@@ -521,6 +552,11 @@ class TokenizerManager:
                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                    continue
                elif isinstance(recv_obj, OpenSessionReqOutput):
                    self.session_futures[recv_obj.session_id].set_result(
                        recv_obj.session_id
                    )
                    continue

                assert isinstance(
                    recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
@@ -536,11 +572,13 @@ class TokenizerManager:
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                        "session_id": recv_obj.session_ids[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    out_dict = {
                        "token_ids": recv_obj.output_ids[i],
                        "meta_info": recv_obj.meta_info[i],
                        "session_id": recv_obj.session_ids[i],
                    }
                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
...
@@ -50,8 +50,10 @@ from sglang.srt.managers.data_parallel_controller import (
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
    CloseSessionReqInput,
    EmbeddingReqInput,
    GenerateReqInput,
    OpenSessionReqInput,
    UpdateWeightReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
@@ -215,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
    )
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
try:
session_id = await tokenizer_manager.open_session(obj, request)
return session_id
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request):
"""Close the session"""
try:
await tokenizer_manager.close_session(obj, request)
return Response(status_code=200)
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@time_func_latency
async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
...
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# FIXME: Make it a CI test
import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
url = "http://localhost:30000"
# Open a session
response = requests.post(
url + "/open_session",
json={"capacity_of_str_len": 1000},
)
session_id = response.json()
print("session_id", session_id, "\n")
# Prefill only
prompt = "chunk 1"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 0,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
# Generate
prompt = "Chunk 2"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
rid = response.json()["meta_info"]["id"]
# Generate
prompt = "Chunk 3"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 2,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
rid_to_del = response.json()["meta_info"]["id"]
# Interrupt and re-generate
prompt = "Chunk 4"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"session_rid": rid,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
# Query a session based on a deleted request, should see finish reason abort
prompt = "Chunk 4"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"session_rid": rid_to_del,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
# Close session
ret = requests.post(
url + "/close_session",
json={"session_id": session_id},
)
print(ret, "\n")
# Query a deleted session, should see finish reason abort
prompt = "chunk 1"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 0,
},
},
)
print(response.json(), "\n")