Unverified Commit e1e595d7 authored by Ying Sheng, committed by GitHub

[feat] Refactor session control interface and add CI (#2173)

parent 5ada33ff
@@ -173,7 +173,6 @@ class DetokenizerManager:
                     output_strs=output_strs,
                     meta_info=recv_obj.meta_info,
                     finished_reason=recv_obj.finished_reason,
-                    session_ids=recv_obj.session_ids,
                 )
             )
......
@@ -19,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -55,8 +55,9 @@ class GenerateReqInput:
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     # Session id info for continual prompting
-    session_id: Optional[Union[List[str], str]] = None
-    session_rid: Optional[Union[List[str], str]] = None
+    session: Optional[
+        Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
+    ] = None
 
     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
@@ -203,7 +204,7 @@ class TokenizedGenerateReqInput:
     lora_path: Optional[str] = None  # None means just use the base model
     # Session id info for continual prompting
-    session_id: Optional[int] = None
+    session_id: Optional[str] = None
     session_rid: Optional[str] = None
@@ -299,8 +300,6 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
     no_stop_trim: List[bool]
-    # The updated session unique id
-    session_ids: List[str]
 
 @dataclass
@@ -313,8 +312,6 @@ class BatchStrOut:
     meta_info: List[Dict]
     # The finish reason
     finished_reason: List[BaseFinishReason]
-    # The update session unique id
-    session_ids: List[str]
 
 @dataclass
......
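Interface note: the refactor collapses the separate session_id / session_rid request fields into a single "session" pair of (session_id, rid). A minimal client-side sketch of the new payload shape, mirroring the CI test added below (the URL, prompt, and sampling values are illustrative):

import requests

url = "http://localhost:30000"

# Open a session; with this refactor the returned id stays stable for the
# whole session instead of being renewed on every request.
session_id = requests.post(
    url + "/open_session",
    json={"capacity_of_str_len": 1000},
).json()

# "session" is a (session_id, rid) pair; rid=None continues from the
# latest request in the session.
response = requests.post(
    url + "/generate",
    json={
        "text": "chunk 1",
        "session": [session_id, None],
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
)
print(response.json())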
@@ -542,9 +542,7 @@ class Scheduler:
         else:
             # Handle sessions
             session = self.sessions[recv_req.session_id]
-            req, new_session_id = session.create_req(recv_req, self.tokenizer)
-            del self.sessions[recv_req.session_id]
-            self.sessions[new_session_id] = session
+            req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
                 self.waiting_queue.append(req)
                 return
@@ -1188,7 +1186,6 @@ class Scheduler:
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
-            output_session_ids = []
         else:  # embedding or reward model
             output_embeddings = []
@@ -1216,7 +1213,6 @@ class Scheduler:
                     req.sampling_params.spaces_between_special_tokens
                 )
                 output_no_stop_trim.append(req.sampling_params.no_stop_trim)
-                output_session_ids.append(req.session_id)
 
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
@@ -1267,7 +1263,6 @@ class Scheduler:
                         output_meta_info,
                         output_finished_reason,
                         output_no_stop_trim,
-                        output_session_ids,
                     )
                 )
         else:  # embedding or reward model
......
@@ -26,13 +26,13 @@ class Session:
         self.reqs: List[Req] = []
 
     def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
-        # renew session id
-        self.session_id = uuid.uuid4().hex
         if req.session_rid is not None:
             while len(self.reqs) > 0:
                 if self.reqs[-1].rid == req.session_rid:
                     break
                 self.reqs = self.reqs[:-1]
+        else:
+            self.reqs = []
         if len(self.reqs) > 0:
             input_ids = (
                 self.reqs[-1].origin_input_ids
@@ -58,4 +58,4 @@ class Session:
             )
         else:
             self.reqs.append(new_req)
-        return new_req, self.session_id
+        return new_req
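With the session id no longer renewed, the only up-front mutation in create_req is the backtrack: the loop pops requests off the end of self.reqs until the entry whose rid matches session_rid is last, and the new else branch restarts the history when no anchor rid is given. A standalone sketch of that trimming behavior (plain dicts stand in for Req objects):

# Standalone sketch of the backtracking in Session.create_req above;
# plain dicts stand in for Req objects.
def backtrack(reqs, session_rid):
    if session_rid is not None:
        # Pop from the end until the anchor rid is the last entry.
        while len(reqs) > 0 and reqs[-1]["rid"] != session_rid:
            reqs = reqs[:-1]
    else:
        # No anchor rid: restart the session history.
        reqs = []
    return reqs

history = [{"rid": "a"}, {"rid": "b"}, {"rid": "c"}]
assert backtrack(history, "b") == [{"rid": "a"}, {"rid": "b"}]  # drops "c"
assert backtrack(history, "zzz") == []  # stale rid empties the history

An emptied history for a stale rid is what ultimately surfaces as the abort finish reason that the CI test below asserts.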
@@ -216,8 +216,8 @@ class TokenizerManager:
         return_logprob = obj.return_logprob
         logprob_start_len = obj.logprob_start_len
         top_logprobs_num = obj.top_logprobs_num
-        session_id = obj.session_id
-        session_rid = obj.session_rid
+        session_id = obj.session[0] if obj.session else None
+        session_rid = obj.session[1] if obj.session else None
 
         if len(input_ids) >= self.context_len:
             raise ValueError(
@@ -570,13 +570,11 @@ class TokenizerManager:
             out_dict = {
                 "text": recv_obj.output_strs[i],
                 "meta_info": recv_obj.meta_info[i],
-                "session_id": recv_obj.session_ids[i],
             }
         elif isinstance(recv_obj, BatchTokenIDOut):
             out_dict = {
                 "token_ids": recv_obj.output_ids[i],
                 "meta_info": recv_obj.meta_info[i],
-                "session_id": recv_obj.session_ids[i],
             }
         else:
             assert isinstance(recv_obj, BatchEmbeddingOut)
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# FIXME: Make it a CI test
import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
url = "http://localhost:30000"
# Open a session
response = requests.post(
url + "/open_session",
json={"capacity_of_str_len": 1000},
)
session_id = response.json()
print("session_id", session_id, "\n")
# Prefill only
prompt = "chunk 1"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 0,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
# Generate
prompt = "Chunk 2"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
rid = response.json()["meta_info"]["id"]
# Generate
prompt = "Chunk 3"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 2,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
rid_to_del = response.json()["meta_info"]["id"]
# Interrupt and re-generate
prompt = "Chunk 4"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"session_rid": rid,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
session_id = response.json()["session_id"]
# Query a session based on a deleted request, should see finish reason abort
prompt = "Chunk 4"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"session_rid": rid_to_del,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
)
print(response.json(), "\n")
# Close session
ret = requests.post(
url + "/close_session",
json={"session_id": session_id},
)
print(ret, "\n")
# Query a deleted session, should see finish reason abort
prompt = "chunk 1"
response = requests.post(
url + "/generate",
json={
"text": prompt,
"session_id": session_id,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 0,
},
},
)
print(response.json(), "\n")
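Note: the manual script above assumes a server is already listening at http://localhost:30000, e.g. one started with "python -m sglang.launch_server --model-path <model> --port 30000" (model path up to the reader). It still speaks the pre-refactor session_id/session_rid interface; the CI test added below exercises the new session tuple instead.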
@@ -34,6 +34,7 @@ suites = {
         "test_triton_attention_backend.py",
         "test_update_weights.py",
         "test_vision_openai_server.py",
+        "test_session_control.py",
     ],
     "sampling/penaltylib": glob.glob(
         "sampling/penaltylib/**/test_*.py", recursive=True
......
"""
Usage:
python3 -m unittest test_session_control.TestSessionControl.test_session_control
python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm
"""
import unittest
import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestSessionControl(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True)
def test_session_control(self):
chunks = [
"Let me tell you something about France.",
"The capital of France is",
"A brief history about that city is",
"To plan a travel, the budget is",
]
tokenizer = get_tokenizer(self.model)
chunks_ids = [tokenizer.encode(x) for x in chunks]
# 1. using session control
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
).json()
rid = None
first_rid = None
outputs_from_session = []
for i, chunk_ids in enumerate(chunks_ids):
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": chunk_ids,
"session": [session_id, rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": (
16 if i > 0 else 0
), # prefill only for the first chunk
},
},
).json()
rid = response["meta_info"]["id"]
if i == 0:
first_rid = rid
if i > 0:
outputs_from_session.append(response["text"])
# backtrack to the first request and regenerate
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": chunks_ids[-1],
"session": [session_id, first_rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
).json()
outputs_from_session.append(response["text"])
# query with a non-existing rid (the last one should have disappeared because of the backtrack); should see abort
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": chunks_ids[-1],
"session": [session_id, rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort"
ret = requests.post(
self.base_url + "/close_session",
json={"session_id": session_id},
)
assert ret.status_code == 200
# send a request to a closed session, should see abort
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": chunks_ids[-1],
"session": [session_id, first_rid],
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort"
# 2. not use session control
input_ids_first_req = None
input_ids = []
outputs_normal = []
for i, chunk_ids in enumerate(chunks_ids):
input_ids += chunk_ids
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0,
"max_new_tokens": (
16 if i > 0 else 0
), # prefill only for the first chunk
},
},
).json()
if i > 0:
input_ids += tokenizer.encode(response["text"])[
1:
] # drop the bos token
outputs_normal.append(response["text"])
if i == 0:
input_ids_first_req = input_ids.copy()
input_ids_first_req += chunks_ids[-1]
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids_first_req,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 16,
},
},
).json()
outputs_normal.append(response["text"])
print("outputs from chunked queries with session control:")
print(outputs_from_session)
print("outputs from normal queries:")
print(outputs_normal)
assert outputs_from_session == outputs_normal
if __name__ == "__main__":
unittest.main()