Unverified Commit e1e595d7 authored by Ying Sheng, committed by GitHub

[feat] Refactor session control interface and add CI (#2173)

parent 5ada33ff
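
In short: the separate `session_id` / `session_rid` fields of `/generate` are folded into one `session` tuple, and responses no longer carry a renewed `session_id`. A rough before/after sketch of the request payload (all values below are placeholders, not taken from this diff):

chunk_ids = [1, 2, 3]           # placeholder token ids
session_id = "some-session-id"  # as returned by /open_session
rid = None                      # or the meta_info["id"] of an earlier request

# before this commit: two separate fields, and every response carried a renewed session_id
payload_old = {
    "input_ids": chunk_ids,
    "session_id": session_id,
    "session_rid": rid,
    "sampling_params": {"temperature": 0, "max_new_tokens": 16},
}

# after this commit: one tuple; rid=None (re)starts the session, while a concrete
# rid continues from (and backtracks to) that earlier request
payload_new = {
    "input_ids": chunk_ids,
    "session": [session_id, rid],
    "sampling_params": {"temperature": 0, "max_new_tokens": 16},
}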
......@@ -173,7 +173,6 @@ class DetokenizerManager:
output_strs=output_strs,
meta_info=recv_obj.meta_info,
finished_reason=recv_obj.finished_reason,
session_ids=recv_obj.session_ids,
)
)
......
......@@ -19,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
......@@ -55,8 +55,9 @@ class GenerateReqInput:
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
# Session id info for continual prompting
session_id: Optional[Union[List[str], str]] = None
session_rid: Optional[Union[List[str], str]] = None
session: Optional[
Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
] = None
def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
......@@ -203,7 +204,7 @@ class TokenizedGenerateReqInput:
lora_path: Optional[str] = None # None means just use the base model
# Session id info for continual prompting
session_id: Optional[int] = None
session_id: Optional[str] = None
session_rid: Optional[str] = None
......@@ -299,8 +300,6 @@ class BatchTokenIDOut:
meta_info: List[Dict]
finished_reason: List[BaseFinishReason]
no_stop_trim: List[bool]
# The updated session unique id
session_ids: List[str]
@dataclass
......@@ -313,8 +312,6 @@ class BatchStrOut:
meta_info: List[Dict]
# The finish reason
finished_reason: List[BaseFinishReason]
# The update session unique id
session_ids: List[str]
@dataclass
......
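
Server-side, the refactored dataclass is populated the same way; a minimal sketch with placeholder ids, assuming the remaining fields keep their defaults:

from sglang.srt.managers.io_struct import GenerateReqInput

input_ids = [1, 2, 3]  # placeholder token ids

# start (or reset) the session: no request id to continue from
req = GenerateReqInput(input_ids=input_ids, session=("some-session-id", None))

# continue from an earlier request; anything generated after that rid is discarded
req = GenerateReqInput(input_ids=input_ids, session=("some-session-id", "some-rid"))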
......@@ -542,9 +542,7 @@ class Scheduler:
else:
# Handle sessions
session = self.sessions[recv_req.session_id]
req, new_session_id = session.create_req(recv_req, self.tokenizer)
del self.sessions[recv_req.session_id]
self.sessions[new_session_id] = session
req = session.create_req(recv_req, self.tokenizer)
if isinstance(req.finished_reason, FINISH_ABORT):
self.waiting_queue.append(req)
return
......@@ -1188,7 +1186,6 @@ class Scheduler:
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_no_stop_trim = []
output_session_ids = []
else: # embedding or reward model
output_embeddings = []
......@@ -1216,7 +1213,6 @@ class Scheduler:
req.sampling_params.spaces_between_special_tokens
)
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
output_session_ids.append(req.session_id)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
......@@ -1267,7 +1263,6 @@ class Scheduler:
output_meta_info,
output_finished_reason,
output_no_stop_trim,
output_session_ids,
)
)
else: # embedding or reward model
......
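
With the renewal gone, a session keeps the id issued by `/open_session` for its whole lifetime; when that id is unknown, or the given rid was already discarded, `create_req` returns a request pre-finished with `FINISH_ABORT`, which the client sees as an abort finish reason. A sketch of the client-visible check, assuming a locally running server and placeholder ids:

import requests

session_id = "some-session-id"  # placeholder: id from /open_session
stale_rid = "some-old-rid"      # placeholder: a rid dropped by an earlier backtrack

response = requests.post(
    "http://localhost:30000/generate",  # assumed local server
    json={
        "text": "hello",
        "session": [session_id, stale_rid],
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
    },
).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort"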
......@@ -26,13 +26,13 @@ class Session:
self.reqs: List[Req] = []
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
# renew session id
self.session_id = uuid.uuid4().hex
if req.session_rid is not None:
while len(self.reqs) > 0:
if self.reqs[-1].rid == req.session_rid:
break
self.reqs = self.reqs[:-1]
else:
self.reqs = []
if len(self.reqs) > 0:
input_ids = (
self.reqs[-1].origin_input_ids
......@@ -58,4 +58,4 @@ class Session:
)
else:
self.reqs.append(new_req)
return new_req, self.session_id
return new_req
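
The truncation above amounts to popping from the tail of the session's request list until the given rid is found, emptying the list when the rid is absent or None; a standalone sketch of that invariant with hypothetical rids:

def backtrack(rids, session_rid):
    # mirror of Session.create_req: keep requests up to and including session_rid
    if session_rid is None:
        return []
    while rids and rids[-1] != session_rid:
        rids = rids[:-1]
    return rids

assert backtrack(["r1", "r2", "r3"], "r2") == ["r1", "r2"]
assert backtrack(["r1", "r2", "r3"], None) == []
assert backtrack(["r1", "r2", "r3"], "rX") == []  # unknown rid empties the session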
......@@ -216,8 +216,8 @@ class TokenizerManager:
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
session_id = obj.session_id
session_rid = obj.session_rid
session_id = obj.session[0] if obj.session else None
session_rid = obj.session[1] if obj.session else None
if len(input_ids) >= self.context_len:
raise ValueError(
......@@ -570,13 +570,11 @@ class TokenizerManager:
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
"session_id": recv_obj.session_ids[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
"session_id": recv_obj.session_ids[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
......
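
End to end, the refactored client flow looks roughly like this (a minimal sketch against an assumed local server; the CI test below exercises the same sequence with tokenized inputs):

import requests

url = "http://localhost:30000"  # assumed local sglang server

session_id = requests.post(
    url + "/open_session", json={"capacity_of_str_len": 1000}
).json()

rid = None  # None starts the session from scratch
for text in ["chunk 1", "chunk 2", "chunk 3"]:
    response = requests.post(
        url + "/generate",
        json={
            "text": text,
            "session": [session_id, rid],
            "sampling_params": {"temperature": 0, "max_new_tokens": 16},
        },
    ).json()
    rid = response["meta_info"]["id"]  # thread the rid forward to continue the session

requests.post(url + "/close_session", json={"session_id": session_id})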
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# FIXME: Make it a CI test

import requests

from sglang.srt.hf_transformers_utils import get_tokenizer

url = "http://localhost:30000"
# Open a session
response = requests.post(
url + "/open_session",
json={"capacity_of_str_len": 1000},
)
session_id = response.json()
print("session_id", session_id, "\n")
# Prefill only
prompt = "chunk 1"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 0,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
# Generate
prompt = "Chunk 2"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
rid = response.json()["meta_info"]["id"]
# Generate
prompt = "Chunk 3"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 2,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
rid_to_del = response.json()["meta_info"]["id"]
# Interrupt and re-generate
prompt = "Chunk 4"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"session_rid": rid,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
# Query a session based on a deleted request, should see finish reason abort
prompt = "Chunk 4"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"session_rid": rid_to_del,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
# Close session
ret = requests.post(
url + "/close_session",
json={"session_id": session_id},
)
print(ret, "\n")
# Query a deleted session, should see finish reason abort
prompt = "chunk 1"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 0,
},
},
)
print(response.json(), "\n")
......@@ -34,6 +34,7 @@ suites = {
"test_triton_attention_backend.py",
"test_update_weights.py",
"test_vision_openai_server.py",
"test_session_control.py",
],
"sampling/penaltylib": glob.glob(
"sampling/penaltylib/**/test_*.py", recursive=True
......
"""
Usage:
python3 -m unittest test_session_control.TestSessionControl.test_session_control
python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm
"""

import unittest

import requests

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestSessionControl(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid, include_self=True)

    def test_session_control(self):
chunks = [
"Let me tell you something about France.",
"The capital of France is",
"A brief history about that city is",
"To plan a travel, the budget is",
]
tokenizer = get_tokenizer(self.model)
chunks_ids = [tokenizer.encode(x) for x in chunks]

        # 1. using session control
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
).json()
rid = None
first_rid = None
outputs_from_session = []
for i, chunk_ids in enumerate(chunks_ids):
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": chunk_ids,
"session": [session_id, rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": (
16 if i > 0 else 0
), # prefill only for the first chunk
},
},
).json()
rid = response["meta_info"]["id"]
if i == 0:
first_rid = rid
if i > 0:
outputs_from_session.append(response["text"])

        # backtrack to the first request and regenerate
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": chunks_ids[-1],
"session": [session_id, first_rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
).json()
outputs_from_session.append(response["text"])

        # query with a non-existing rid (the last one disappeared because of the backtrack), should see abort
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": chunks_ids[-1],
"session": [session_id, rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort"
ret = requests.post(
self.base_url + "/close_session",
json={"session_id": session_id},
)
assert ret.status_code == 200

        # send a request to a closed session, should see abort
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": chunks_ids[-1],
"session": [session_id, first_rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort"

        # 2. not using session control
input_ids_first_req = None
input_ids = []
outputs_normal = []
for i, chunk_ids in enumerate(chunks_ids):
input_ids += chunk_ids
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0,
"max_new_tokens": (
16 if i > 0 else 0
), # prefill only for the first chunk
},
},
).json()
if i > 0:
input_ids += tokenizer.encode(response["text"])[
1:
] # drop the bos token
outputs_normal.append(response["text"])
if i == 0:
input_ids_first_req = input_ids.copy()
input_ids_first_req += chunks_ids[-1]
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids_first_req,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
).json()
outputs_normal.append(response["text"])
print("outputs from chunked queries with session control:")
print(outputs_from_session)
print("outputs from normal queries:")
print(outputs_normal)
assert outputs_from_session == outputs_normal

if __name__ == "__main__":
    unittest.main()