Commit e0e09fce (unverified), authored by Ying Sheng, committed by GitHub

[Session] Update session control interface (#2635)

parent 9c05c689
@@ -27,6 +27,14 @@ from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
 
 
+@dataclass
+class SessionParams:
+    id: Optional[str] = None
+    rid: Optional[str] = None
+    offset: Optional[int] = None
+    replace: Optional[bool] = None
+
+
 @dataclass
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
@@ -58,10 +66,8 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
-    # Session id info for continual prompting
-    session: Optional[
-        Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
-    ] = None
+    # Session info for continual prompting
+    session_params: Optional[Union[List[Dict], Dict]] = None
 
     def normalize_batch_and_arguments(self):
         if (
@@ -223,9 +229,8 @@ class TokenizedGenerateReqInput:
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
-    # Session id info for continual prompting
-    session_id: Optional[str] = None
-    session_rid: Optional[str] = None
+    # Session info for continual prompting
+    session_params: Optional[SessionParams] = None
 
 
 @dataclass
@@ -468,6 +473,7 @@ class ProfileReq(Enum):
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
+    session_id: Optional[str] = None
 
 
 @dataclass
@@ -477,4 +483,5 @@ class CloseSessionReqInput:
 @dataclass
 class OpenSessionReqOutput:
-    session_id: str
+    session_id: Optional[str]
+    success: bool
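
For orientation: the old positional "session" tuple on GenerateReqInput becomes a named "session_params" dict, which the tokenizer manager parses into SessionParams. A minimal sketch of the two payload shapes (values are illustrative, not taken from this diff):

# Old interface: a (session_id, rid) tuple attached to the generate request.
old_payload = {
    "input_ids": [1, 2, 3],
    "session": ["my-session", None],
}

# New interface: a dict mirroring the SessionParams dataclass. Judging from the
# diff, "offset" picks the cut point in the resumed token sequence and
# "replace" backtracks (aborts the descendants of rid) instead of appending.
new_payload = {
    "input_ids": [1, 2, 3],
    "session_params": {
        "id": "my-session",  # session to continue
        "rid": None,         # request within the session to branch from
        "offset": 0,         # where to splice the new chunk (0/None = append)
        "replace": False,    # True aborts and replaces the subtree under rid
    },
}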
@@ -22,7 +22,7 @@ import warnings
 from collections import deque
 from concurrent import futures
 from types import SimpleNamespace
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import psutil
 import setproctitle
@@ -498,8 +498,10 @@ class Scheduler:
                 else:
                     self.stop_profile()
             elif isinstance(recv_req, OpenSessionReqInput):
-                session_id = self.open_session(recv_req)
-                self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
+                session_id, success = self.open_session(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    OpenSessionReqOutput(session_id=session_id, success=success)
+                )
             elif isinstance(recv_req, CloseSessionReqInput):
                 self.close_session(recv_req)
             else:
@@ -510,7 +512,11 @@
         recv_req: TokenizedGenerateReqInput,
     ):
         # Create a new request
-        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+        if (
+            recv_req.session_params is None
+            or recv_req.session_params.id is None
+            or recv_req.session_params.id not in self.sessions
+        ):
             if recv_req.input_embeds is not None:
                 # Generate fake input_ids based on the length of input_embeds
@@ -532,15 +538,18 @@
             )
             req.tokenizer = self.tokenizer
 
-            if recv_req.session_id is not None:
+            if (
+                recv_req.session_params is not None
+                and recv_req.session_params.id is not None
+            ):
                 req.finished_reason = FINISH_ABORT(
-                    f"Invalid request: session id {recv_req.session_id} does not exist"
+                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                 )
                 self.waiting_queue.append(req)
                 return
         else:
-            # Create a new request from a previsou session
-            session = self.sessions[recv_req.session_id]
+            # Create a new request from a previous session
+            session = self.sessions[recv_req.session_params.id]
             req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
                 self.waiting_queue.append(req)
@@ -1500,16 +1509,20 @@
         )
         logger.info("Profiler is done")
 
-    def open_session(self, recv_req: OpenSessionReqInput) -> str:
+    def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
         # handle error
         session_id = recv_req.session_id
         if session_id in self.sessions:
             logger.warning(f"session id {session_id} already exist, cannot open.")
+            return session_id, False
+        elif session_id is None:
+            logger.warning(f"session id is None, cannot open.")
+            return session_id, False
         else:
             self.sessions[session_id] = Session(
                 recv_req.capacity_of_str_len, session_id
            )
-            return session_id
+            return session_id, True
 
     def close_session(self, recv_req: CloseSessionReqInput):
         # handle error
...
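
The scheduler now reports failure explicitly instead of returning an id unconditionally. A standalone sketch of the new open_session contract (a plain dict stands in for self.sessions; names are illustrative):

from typing import Dict, Optional, Tuple

def open_session(sessions: Dict[str, str], session_id: Optional[str]) -> Tuple[Optional[str], bool]:
    # Mirrors the new contract: never raise, return (session_id, success)
    # and let the tokenizer manager / HTTP layer surface the failure.
    if session_id in sessions:
        return session_id, False  # duplicate id: refuse to reopen
    if session_id is None:
        return None, False        # ids are minted upstream; None is invalid here
    sessions[session_id] = session_id  # stand-in for Session(capacity, session_id)
    return session_id, True

sessions: Dict[str, str] = {}
assert open_session(sessions, "s1") == ("s1", True)
assert open_session(sessions, "s1") == ("s1", False)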
@@ -10,41 +10,116 @@
 # limitations under the License.
 # ==============================================================================
 
+import logging
 import uuid
+from typing import Dict, Optional
 
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
+from sglang.srt.managers.schedule_batch import Req
+
+
+class SessionReqNode:
+    def __init__(self, req, parent=None, childs=None):
+        self.req = req
+        self.parent = parent
+        if parent is not None:
+            parent.childs.append(self)
+        self.childs = [] if not childs else childs
+
+    def clear_childs(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+        self.childs = []
+
+    def clear(self, req_dict):
+        for req_node in self.childs:
+            req_node.clear(req_dict)
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+        del req_dict[self.req.rid]
+
+    def abort(self):
+        if self.req.finished_reason == None:
+            self.req.to_abort = True
+
+    def __str__(self):
+        return self._str_helper(self.req.rid)
+
+    def _str_helper(self, prefix=""):
+        if len(self.childs) == 0:
+            return prefix + "\n"
+        else:
+            origin_prefix = prefix
+            prefix += " -- " + self.childs[0].req.rid
+            ret = self.childs[0]._str_helper(prefix)
+            for child in self.childs[1:]:
+                prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+                ret += child._str_helper(prefix)
+            return ret
 
 
 class Session:
-    def __init__(self, capacity_of_str_len: int, session_id: str = None):
+    def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
         self.session_id = session_id if session_id is not None else uuid.uuid4().hex
         self.capacity_of_str_len = capacity_of_str_len
-        self.reqs: List[Req] = []
+        self.req_nodes: Dict[str, SessionReqNode] = {}
 
     def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
-        if req.session_rid is not None:
-            while len(self.reqs) > 0:
-                if self.reqs[-1].rid == req.session_rid:
-                    break
-                self.reqs = self.reqs[:-1]
-        else:
-            self.reqs = []
-
-        if len(self.reqs) > 0:
-            input_ids = (
-                self.reqs[-1].origin_input_ids
-                + self.reqs[-1].output_ids[
-                    : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
-            )
-            input_ids_unpadded = (
-                self.reqs[-1].origin_input_ids_unpadded
-                + self.reqs[-1].output_ids[
-                    : self.reqs[-1].sampling_params.max_new_tokens
-                ]
-                + req.input_ids
-            )
+        assert req.session_params is not None
+        session_params = req.session_params
+
+        last_req_node = None
+        last_req = None
+        abort = False
+        if session_params.replace:
+            if session_params.rid is None:
+                for _, req_node in self.req_nodes.items():
+                    req_node.clear(self.req_nodes)
+            else:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req_node.abort()
+                    last_req = last_req_node.req
+                    last_req_node.clear_childs(self.req_nodes)
+        else:
+            if session_params.rid is not None:
+                if session_params.rid not in self.req_nodes:
+                    abort = True
+                else:
+                    last_req_node = self.req_nodes[session_params.rid]
+                    last_req = last_req_node.req
+                    if not last_req.finished():
+                        logging.warning(
+                            "The request in a session is appending to a request that hasn't finished."
+                        )
+                        abort = True
+
+        if last_req is not None:
+            # trim bos token if it is an append
+            if req.input_ids[0] == tokenizer.bos_token_id:
+                req.input_ids = req.input_ids[1:]
+
+            input_ids = (
+                last_req.origin_input_ids
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
+            )
+            if session_params.offset and session_params.offset != 0:
+                input_ids = input_ids[: session_params.offset] + req.input_ids
+            else:
+                input_ids += req.input_ids
+            input_ids_unpadded = (
+                last_req.origin_input_ids_unpadded
+                + last_req.output_ids[: last_req.sampling_params.max_new_tokens]
+            )
+            if session_params.offset and session_params.offset != 0:
+                input_ids_unpadded = (
+                    input_ids_unpadded[: session_params.offset] + req.input_ids
+                )
+            else:
+                input_ids_unpadded += req.input_ids
         else:
             input_ids = req.input_ids
             input_ids_unpadded = req.input_ids
@@ -57,13 +132,13 @@ class Session:
             lora_path=req.lora_path,
             session_id=self.session_id,
         )
-        if len(self.reqs) > 0:
-            new_req.image_inputs = self.reqs[-1].image_inputs
+        if last_req is not None:
+            new_req.image_inputs = last_req.image_inputs
         new_req.tokenizer = tokenizer
-        if req.session_rid is not None and len(self.reqs) == 0:
-            new_req.finished_reason = FINISH_ABORT(
-                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
-            )
+        if abort:
+            new_req.to_abort = True
         else:
-            self.reqs.append(new_req)
+            new_req_node = SessionReqNode(new_req, last_req_node)
+            self.req_nodes[req.rid] = new_req_node
         return new_req
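
The splicing rule above is the heart of the new session semantics. A small standalone sketch (plain lists stand in for the Req token buffers; the helper name is illustrative):

def splice_continuation(parent_prompt, parent_output, new_chunk, offset=None):
    # Mimics Session.create_req: the continuation input is the parent's
    # prompt plus its generated tokens, optionally cut at `offset`,
    # followed by the new chunk.
    tokens = parent_prompt + parent_output
    if offset:  # a falsy offset (None or 0) means plain append
        return tokens[:offset] + new_chunk
    return tokens + new_chunk

# Append: keep everything the parent produced.
assert splice_continuation([1, 2], [3, 4], [9]) == [1, 2, 3, 4, 9]
# Backtrack: offset=-1 drops the parent's last token before appending,
# matching the output_ids[:-1] replay in the tests below.
assert splice_continuation([1, 2], [3, 4], [9], offset=-1) == [1, 2, 3, 9]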
@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -264,8 +265,9 @@ class TokenizerManager:
             return_logprob = obj.return_logprob
             logprob_start_len = obj.logprob_start_len
             top_logprobs_num = obj.top_logprobs_num
-            session_id = obj.session[0] if obj.session else None
-            session_rid = obj.session[1] if obj.session else None
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
+            )
 
         if obj.input_ids is not None and len(input_ids) >= self.context_len:
             raise ValueError(
@@ -292,8 +294,7 @@
                 obj.stream,
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
-                session_id=session_id,
-                session_rid=session_rid,
+                session_params=session_params,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -552,12 +553,16 @@
     ):
         self.auto_create_handle_loop()
 
-        session_id = uuid.uuid4().hex
-        obj.session_id = session_id
+        if obj.session_id is None:
+            obj.session_id = uuid.uuid4().hex
+        elif obj.session_id in self.session_futures:
+            return None
+
         self.send_to_scheduler.send_pyobj(obj)
-        self.session_futures[session_id] = asyncio.Future()
-        session_id = await self.session_futures[session_id]
-        del self.session_futures[session_id]
+
+        self.session_futures[obj.session_id] = asyncio.Future()
+        session_id = await self.session_futures[obj.session_id]
+        del self.session_futures[obj.session_id]
         return session_id
 
     async def close_session(
@@ -709,7 +714,7 @@ class TokenizerManager:
             )
         elif isinstance(recv_obj, OpenSessionReqOutput):
            self.session_futures[recv_obj.session_id].set_result(
-                recv_obj.session_id
+                recv_obj.session_id if recv_obj.success else None
            )
         elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
             if self.server_args.dp_size == 1:
...
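
The open_session round trip above is a future-based rendezvous: the tokenizer manager registers a future keyed by the session id, the scheduler replies with OpenSessionReqOutput, and the future resolves to the id on success or None on failure. A minimal sketch of the pattern (an asyncio.Queue stands in for the ZMQ sockets; names are illustrative):

import asyncio

session_futures = {}

async def open_session(session_id, to_scheduler):
    # Register the future before sending, then await the scheduler's verdict.
    session_futures[session_id] = asyncio.get_running_loop().create_future()
    await to_scheduler.put(("open", session_id))
    result = await session_futures[session_id]  # the id, or None on failure
    del session_futures[session_id]
    return result

async def scheduler_loop(to_scheduler, sessions):
    _, sid = await to_scheduler.get()
    success = sid not in sessions  # mirrors the duplicate check in the scheduler
    if success:
        sessions.add(sid)
    session_futures[sid].set_result(sid if success else None)

async def main():
    q, sessions = asyncio.Queue(), set()
    opener = asyncio.create_task(open_session("s1", q))
    await scheduler_loop(q, sessions)
    assert await opener == "s1"

asyncio.run(main())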
@@ -259,6 +259,10 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
     try:
         session_id = await tokenizer_manager.open_session(obj, request)
+        if session_id is None:
+            raise Exception(
+                "Failed to open the session. Check if a session with the same id is still open."
+            )
         return session_id
     except Exception as e:
         return _create_error_response(e)
...
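
End to end, the new interface is exercised over HTTP as follows. A hedged sketch assuming a local server (the endpoint names and payload shape follow the tests below; the address is an assumption):

import requests

base = "http://localhost:30000"  # assumed server address

# Open a session; reusing an id that is still open yields an error response.
session_id = requests.post(
    base + "/open_session", json={"capacity_of_str_len": 1000}
).json()

# Generate inside the session; rid=None starts from the session root.
resp = requests.post(
    base + "/generate",
    json={
        "text": "Let me tell you something about France.",
        "session_params": {"id": session_id, "rid": None, "offset": 0, "replace": False},
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
).json()
rid = resp["meta_info"]["id"]  # branch from this request in later calls

# Close the session when done.
requests.post(base + "/close_session", json={"session_id": session_id})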
""" """
Usage: Usage:
python3 -m unittest test_session_control.TestSessionControl.test_session_control python3 -m unittest test_session_control.TestSessionControl.test_session_control
python3 -m unittest test_session_control.TestSessionControl.test_session_control_with_branching
python3 -m unittest test_session_control.TestSessionControl.test_session_control_backtrack_with_abort
python3 -m unittest test_session_control.TestSessionControlVision.test_session_control python3 -m unittest test_session_control.TestSessionControlVision.test_session_control
""" """
import asyncio
import json
import unittest import unittest
import aiohttp
import requests import requests
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
...@@ -18,6 +23,10 @@ from sglang.test.test_utils import ( ...@@ -18,6 +23,10 @@ from sglang.test.test_utils import (
) )
def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text
class TestSessionControl(unittest.TestCase): class TestSessionControl(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -31,15 +40,18 @@ class TestSessionControl(unittest.TestCase): ...@@ -31,15 +40,18 @@ class TestSessionControl(unittest.TestCase):
def tearDownClass(cls): def tearDownClass(cls):
kill_process_tree(cls.process.pid) kill_process_tree(cls.process.pid)
def test_session_control(self): def test_session_control(self, gen_len=12):
chunks = [ chunks = [
"Let me tell you something about France.", "Let me tell you something about France.",
"The capital of France is", "The capital of France is",
"The population of the city is",
"A brief history about that city is", "A brief history about that city is",
"To plan a travel, the budget is",
] ]
tokenizer = get_tokenizer(self.model) tokenizer = get_tokenizer(self.model)
chunks_ids = [tokenizer.encode(x) for x in chunks] chunks_ids = [tokenizer.encode(x) for x in chunks]
for i in range(1, len(chunks_ids)):
if chunks_ids[i][0] == tokenizer.bos_token_id:
chunks_ids[i] = chunks_ids[i][1:]
# 1. using session control # 1. using session control
session_id = requests.post( session_id = requests.post(
...@@ -48,6 +60,13 @@ class TestSessionControl(unittest.TestCase): ...@@ -48,6 +60,13 @@ class TestSessionControl(unittest.TestCase):
).json() ).json()
rid = None rid = None
# open an existing session, should get session_id as None
response = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000, "session_id": session_id},
).json()
assert isinstance(response, dict) and "error" in response
first_rid = None first_rid = None
outputs_from_session = [] outputs_from_session = []
for i, chunk_ids in enumerate(chunks_ids): for i, chunk_ids in enumerate(chunks_ids):
...@@ -55,11 +74,16 @@ class TestSessionControl(unittest.TestCase): ...@@ -55,11 +74,16 @@ class TestSessionControl(unittest.TestCase):
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
"input_ids": chunk_ids, "input_ids": chunk_ids,
"session": [session_id, rid], "session_params": {
"id": session_id,
"rid": rid,
"offset": -1,
"replace": True,
},
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": ( "max_new_tokens": (
16 if i > 0 else 0 gen_len if i > 0 else 1
), # prefill only for the first chunk ), # prefill only for the first chunk
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
...@@ -77,10 +101,15 @@ class TestSessionControl(unittest.TestCase): ...@@ -77,10 +101,15 @@ class TestSessionControl(unittest.TestCase):
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
"input_ids": chunks_ids[-1], "input_ids": chunks_ids[-1],
"session": [session_id, first_rid], "session_params": {
"id": session_id,
"rid": first_rid,
"offset": -1,
"replace": True,
},
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": gen_len,
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
...@@ -93,10 +122,15 @@ class TestSessionControl(unittest.TestCase): ...@@ -93,10 +122,15 @@ class TestSessionControl(unittest.TestCase):
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
"input_ids": chunks_ids[-1], "input_ids": chunks_ids[-1],
"session": [session_id, rid], "session_params": {
"id": session_id,
"rid": rid,
"offset": -1,
"replace": True,
},
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": gen_len,
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
...@@ -115,10 +149,15 @@ class TestSessionControl(unittest.TestCase): ...@@ -115,10 +149,15 @@ class TestSessionControl(unittest.TestCase):
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
"input_ids": chunks_ids[-1], "input_ids": chunks_ids[-1],
"session": [session_id, first_rid], "session_params": {
"id": session_id,
"rid": first_rid,
"offset": -1,
"replace": True,
},
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": gen_len,
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
...@@ -127,6 +166,8 @@ class TestSessionControl(unittest.TestCase): ...@@ -127,6 +166,8 @@ class TestSessionControl(unittest.TestCase):
assert response["meta_info"]["finish_reason"]["type"] == "abort" assert response["meta_info"]["finish_reason"]["type"] == "abort"
# 2. not use session control # 2. not use session control
requests.post(self.base_url + "/flush_cache")
input_ids_first_req = None input_ids_first_req = None
input_ids = [] input_ids = []
outputs_normal = [] outputs_normal = []
...@@ -139,7 +180,7 @@ class TestSessionControl(unittest.TestCase): ...@@ -139,7 +180,7 @@ class TestSessionControl(unittest.TestCase):
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": ( "max_new_tokens": (
16 if i > 0 else 0 gen_len if i > 0 else 1
), # prefill only for the first chunk ), # prefill only for the first chunk
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
...@@ -150,7 +191,7 @@ class TestSessionControl(unittest.TestCase): ...@@ -150,7 +191,7 @@ class TestSessionControl(unittest.TestCase):
output_ids = tokenizer.encode(response["text"]) output_ids = tokenizer.encode(response["text"])
if output_ids[0] == tokenizer.bos_token_id: if output_ids[0] == tokenizer.bos_token_id:
output_ids = output_ids[1:] output_ids = output_ids[1:]
input_ids += output_ids input_ids += output_ids[:-1]
outputs_normal.append(response["text"]) outputs_normal.append(response["text"])
if i == 0: if i == 0:
input_ids_first_req = input_ids.copy() input_ids_first_req = input_ids.copy()
...@@ -162,7 +203,7 @@ class TestSessionControl(unittest.TestCase): ...@@ -162,7 +203,7 @@ class TestSessionControl(unittest.TestCase):
"input_ids": input_ids_first_req, "input_ids": input_ids_first_req,
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": gen_len,
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
...@@ -176,6 +217,272 @@ class TestSessionControl(unittest.TestCase): ...@@ -176,6 +217,272 @@ class TestSessionControl(unittest.TestCase):
print(outputs_normal) print(outputs_normal)
assert outputs_from_session == outputs_normal assert outputs_from_session == outputs_normal
async def async_generate(self, payload):
url = self.base_url + "/generate"
async with aiohttp.ClientSession() as session:
async with session.post(url=url, json=payload) as response:
assert response.status == 200
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
if chunk == "[DONE]":
yield "", None, ""
else:
data = json.loads(chunk)
finish_reason = (
data["meta_info"]["finish_reason"]["type"]
if data["meta_info"]["finish_reason"]
else ""
)
yield data["text"], data["meta_info"]["id"], finish_reason
async def run_session_control_backtrack_with_abort(self, replace):
chunks = [
"Let me tell you something about France.",
"The capital of France is",
]
tokenizer = get_tokenizer(self.model)
chunks_ids = [tokenizer.encode(x) for x in chunks]
for i in range(1, len(chunks_ids)):
if chunks_ids[i][0] == tokenizer.bos_token_id:
chunks_ids[i] = chunks_ids[i][1:]
# 1. using session control
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
).json()
rid = None
payload = {
"input_ids": chunks_ids[0],
"session_params": {
"id": session_id,
"rid": rid,
"offset": -1,
"replace": True,
},
"sampling_params": {
"temperature": 0,
"max_new_tokens": 100,
"no_stop_trim": True,
"skip_special_tokens": False,
"ignore_eos": True,
},
"stream": True,
}
gen_so_far = ""
finish_reason = ""
second_output = ""
async for chunk, rid, finish_reason_chunk in self.async_generate(payload):
gen_so_far += chunk
if finish_reason == "":
finish_reason = finish_reason_chunk
if len(gen_so_far) > 50 and second_output == "":
payload2 = {
"input_ids": chunks_ids[1],
"session_params": {
"id": session_id,
"rid": rid,
"offset": 50,
"replace": replace,
},
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
"no_stop_trim": True,
"skip_special_tokens": False,
},
"stream": False,
"stream_output": True,
}
response = requests.post(
url=self.base_url + "/generate", json=payload2
).json()
second_output = response["text"]
if replace:
assert finish_reason == "abort"
print("first request output:")
print(gen_so_far)
print("second request output:")
print(second_output)
# close the session
ret = requests.post(
self.base_url + "/close_session",
json={"session_id": session_id},
)
assert ret.status_code == 200
if not replace:
assert response["meta_info"]["finish_reason"]["type"] == "abort"
else:
# 2. not using session control
output_ids = tokenizer.encode(gen_so_far)
if output_ids[0] == tokenizer.bos_token_id:
output_ids = output_ids[1:]
input_ids = chunks_ids[0] + output_ids
input_ids = input_ids[:50] + chunks_ids[1]
payload = {
"input_ids": input_ids,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
"no_stop_trim": True,
"skip_special_tokens": False,
},
"stream": False,
"stream_output": True,
}
response = requests.post(
url=self.base_url + "/generate", json=payload
).json()
output_no_session = response["text"]
print("second request output without session:")
print(output_no_session)
assert second_output == output_no_session
def test_session_control_backtrack_with_abort(self):
asyncio.run(self.run_session_control_backtrack_with_abort(replace=True))
asyncio.run(self.run_session_control_backtrack_with_abort(replace=False))
def run_session_control_with_branching(
self, root_prompt, chunks_per_step, gen_len=16
):
for x in chunks_per_step:
assert len(x) == len(chunks_per_step[0])
# 1. using session control
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
).json()
outputs_from_session = []
# send the root prompt
response = requests.post(
self.base_url + "/generate",
json={
"text": root_prompt,
"session_params": {
"id": session_id,
"rid": None,
"offset": 0,
"replace": False,
},
"sampling_params": {
"temperature": 0,
"max_new_tokens": gen_len,
"no_stop_trim": True,
"skip_special_tokens": False,
},
},
).json()
rid_per_branch = [response["meta_info"]["id"]] * len(chunks_per_step[0])
outputs_from_session.append(response["text"])
# send the prompts in branches
for chunks_for_branches in chunks_per_step:
for j, chunk in enumerate(chunks_for_branches):
response = requests.post(
self.base_url + "/generate",
json={
"text": chunk,
"session_params": {
"id": session_id,
"rid": rid_per_branch[j],
"offset": 0,
"replace": False,
},
"sampling_params": {
"temperature": 0,
"max_new_tokens": gen_len,
"no_stop_trim": True,
"skip_special_tokens": False,
},
},
).json()
rid = response["meta_info"]["id"]
rid_per_branch[j] = rid
outputs_from_session.append(response["text"])
# close the session
ret = requests.post(
self.base_url + "/close_session",
json={"session_id": session_id},
)
assert ret.status_code == 200
# 2. not use session control
requests.post(self.base_url + "/flush_cache")
outputs_normal = []
input_texts = [root_prompt] * len(chunks_per_step[0])
# send the root prompt
response = requests.post(
self.base_url + "/generate",
json={
"text": root_prompt,
"sampling_params": {
"temperature": 0,
"max_new_tokens": gen_len,
"no_stop_trim": True,
"skip_special_tokens": False,
},
},
).json()
outputs_normal.append(response["text"])
input_texts = [x + response["text"] for x in input_texts]
# send the prompts in branches
for chunks_for_branches in chunks_per_step:
for j, chunk in enumerate(chunks_for_branches):
input_texts[j] += chunk
response = requests.post(
self.base_url + "/generate",
json={
"text": input_texts[j],
"sampling_params": {
"temperature": 0,
"max_new_tokens": gen_len,
"no_stop_trim": True,
"skip_special_tokens": False,
},
},
).json()
outputs_normal.append(response["text"])
input_texts[j] += response["text"]
print("====== outputs from chunked queries with session control: =======")
print(outputs_from_session)
print("====== outputs from normal queries: =======")
print(outputs_normal)
assert outputs_from_session == outputs_normal
def test_session_control_with_branching(self):
root_prompt = "First, let me explain in one sentence about AI"
chunks_per_step = [
[
"Then, briefly, the positive side of AI is",
"But, briefly, AI could be harmful to human",
],
["For example", "For example"],
]
self.run_session_control_with_branching(
root_prompt=root_prompt, chunks_per_step=chunks_per_step, gen_len=8
)
root_prompt = "I have three apples."
chunks_per_step = [
["I then give one apple to my friend", "My friend give me another apple."],
["I still have", "I now have"],
]
self.run_session_control_with_branching(
root_prompt=root_prompt, chunks_per_step=chunks_per_step, gen_len=8
)
class TestSessionControlVision(unittest.TestCase): class TestSessionControlVision(unittest.TestCase):
@classmethod @classmethod
...@@ -197,17 +504,25 @@ class TestSessionControlVision(unittest.TestCase): ...@@ -197,17 +504,25 @@ class TestSessionControlVision(unittest.TestCase):
text_chunks = [ text_chunks = [
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n", "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
"<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n<|im_start|>assistant\n", "<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image>\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n", "<|im_start|>user\n<image>\nIs this image same with one of the previous images?<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\n<image>\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n", "<|im_start|>user\n<image>\nIs this image same with one of the previous images?<|im_end|>\n<|im_start|>assistant\n",
"<|im_start|>user\nDescribe this image in a very short sentence.<|im_end|>\nassistant:",
] ]
image_chunks = [ image_chunks = [
"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png",
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png",
] ]
assert len(text_chunks) == len(image_chunks) + 1
assert (
len(text_chunks) == len(image_chunks) + 2
) # the first and the last prompt does not contain images
tokenizer = get_tokenizer(self.model) tokenizer = get_tokenizer(self.model)
text_input_ids = [tokenizer.encode(x) for x in text_chunks] text_input_ids = [tokenizer.encode(x) for x in text_chunks]
for i in range(1, len(text_input_ids)):
if text_input_ids[i][0] == tokenizer.bos_token_id:
text_input_ids[i] = text_input_ids[i][1:]
gen_len = 32
# 1. using session control # 1. using session control
session_id = requests.post( session_id = requests.post(
...@@ -216,20 +531,32 @@ class TestSessionControlVision(unittest.TestCase): ...@@ -216,20 +531,32 @@ class TestSessionControlVision(unittest.TestCase):
).json() ).json()
rid = None rid = None
# open an existing session, should get session_id as None
response = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000, "session_id": session_id},
).json()
assert isinstance(response, dict) and "error" in response
first_rid = None first_rid = None
outputs_from_session = [] outputs_from_session = []
for i in range(len(text_input_ids)): for i in range(len(text_input_ids[:-1])):
response = requests.post( response = requests.post(
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
"input_ids": text_input_ids[i], "input_ids": text_input_ids[i],
"image_data": image_chunks[i - 1] if i > 0 else None, "image_data": image_chunks[i - 1] if i > 0 else None,
"modalities": ["multi-images"], "modalities": ["multi-images"],
"session": [session_id, rid], "session_params": {
"id": session_id,
"rid": rid,
"offset": 0,
"replace": True,
},
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": ( "max_new_tokens": (
16 if i > 0 else 0 gen_len if i > 0 else 0
), # prefill only for the first chunk ), # prefill only for the first chunk
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
...@@ -247,12 +574,15 @@ class TestSessionControlVision(unittest.TestCase): ...@@ -247,12 +574,15 @@ class TestSessionControlVision(unittest.TestCase):
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
"input_ids": text_input_ids[-1], "input_ids": text_input_ids[-1],
"image_data": image_chunks[-1:], "session_params": {
"modalities": ["multi-images"], "id": session_id,
"session": [session_id, first_rid], "rid": first_rid,
"offset": 0,
"replace": True,
},
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": gen_len,
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
...@@ -265,12 +595,15 @@ class TestSessionControlVision(unittest.TestCase): ...@@ -265,12 +595,15 @@ class TestSessionControlVision(unittest.TestCase):
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
"input_ids": text_input_ids[-1], "input_ids": text_input_ids[-1],
"image_data": image_chunks[-1:], "session_params": {
"modalities": ["multi-images"], "id": session_id,
"session": [session_id, rid], "rid": rid,
"offset": 0,
"replace": True,
},
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": gen_len,
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
...@@ -289,10 +622,15 @@ class TestSessionControlVision(unittest.TestCase): ...@@ -289,10 +622,15 @@ class TestSessionControlVision(unittest.TestCase):
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
"input_ids": text_input_ids[-1], "input_ids": text_input_ids[-1],
"session": [session_id, first_rid], "session_params": {
"id": session_id,
"rid": first_rid,
"offset": 0,
"replace": True,
},
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": gen_len,
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
...@@ -306,7 +644,7 @@ class TestSessionControlVision(unittest.TestCase): ...@@ -306,7 +644,7 @@ class TestSessionControlVision(unittest.TestCase):
input_ids_first_req = None input_ids_first_req = None
input_ids = [] input_ids = []
outputs_normal = [] outputs_normal = []
for i in range(len(text_input_ids)): for i in range(len(text_input_ids[:-1])):
input_ids += text_input_ids[i] input_ids += text_input_ids[i]
image_data = image_chunks[:i] if i > 0 else None image_data = image_chunks[:i] if i > 0 else None
response = requests.post( response = requests.post(
...@@ -318,7 +656,7 @@ class TestSessionControlVision(unittest.TestCase): ...@@ -318,7 +656,7 @@ class TestSessionControlVision(unittest.TestCase):
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": ( "max_new_tokens": (
16 if i > 0 else 0 gen_len if i > 0 else 0
), # prefill only for the first chunk ), # prefill only for the first chunk
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
...@@ -339,11 +677,9 @@ class TestSessionControlVision(unittest.TestCase): ...@@ -339,11 +677,9 @@ class TestSessionControlVision(unittest.TestCase):
self.base_url + "/generate", self.base_url + "/generate",
json={ json={
"input_ids": input_ids_first_req, "input_ids": input_ids_first_req,
"image_data": image_chunks[-1:],
"modalities": ["multi-images"],
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": gen_len,
"no_stop_trim": True, "no_stop_trim": True,
"skip_special_tokens": False, "skip_special_tokens": False,
}, },
......