Unverified Commit 2f1d9283 authored by caiyueliang, committed by GitHub

[FEAT] Support batches cancel (#1222)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent c61a1b6f
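For context, this change wires batch cancellation through the OpenAI-compatible API: a POST to /v1/batches/{batch_id}/cancel marks the batch "cancelling", aborts the queued requests by their custom_id, and finally sets the status to "cancelled". A minimal client-side sketch of the flow, assuming a running SGLang server; the base URL, model name, and file name below are illustrative placeholders, not part of this commit:

import json
import time

import openai

# Placeholder values; adjust to your deployment.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

# Prepare a small JSONL batch input (one /v1/chat/completions request per line).
requests_content = [
    {
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",  # placeholder model name
            "messages": [{"role": "user", "content": "Tell me a long story."}],
            "max_tokens": 1024,
        },
    }
]
with open("batch_input.jsonl", "w") as f:
    for line in requests_content:
        f.write(json.dumps(line) + "\n")

with open("batch_input.jsonl", "rb") as f:
    uploaded_file = client.files.create(file=f, purpose="batch")

batch_job = client.batches.create(
    input_file_id=uploaded_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# Cancel the batch; the server flips the status to "cancelling", aborts the
# underlying requests by custom_id, and eventually reports "cancelled".
batch_job = client.batches.cancel(batch_id=batch_job.id)
while batch_job.status not in ["failed", "cancelled"]:
    time.sleep(3)
    batch_job = client.batches.retrieve(batch_job.id)
print(batch_job.status)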
@@ -275,10 +275,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
        end_point = batch_storage[batch_id].endpoint
        file_request_list = []
        all_requests = []
        request_ids = []
        for line in lines:
            request_data = json.loads(line)
            file_request_list.append(request_data)
            body = request_data["body"]
            request_ids.append(request_data["custom_id"])

            # Although streaming is supported for standalone completions, it is not supported in
            # batch mode (multiple completions in single request).
@@ -289,12 +291,16 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
                all_requests.append(ChatCompletionRequest(**body))
            elif end_point == "/v1/completions":
                all_requests.append(CompletionRequest(**body))

        if end_point == "/v1/chat/completions":
            adapted_request, request = v1_chat_generate_request(
                all_requests, tokenizer_manager, request_ids=request_ids
            )
        elif end_point == "/v1/completions":
            adapted_request, request = v1_generate_request(
                all_requests, request_ids=request_ids
            )

        try:
            ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
            if not isinstance(ret, list):
@@ -326,6 +332,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
                }
                all_ret.append(response_json)
                completed_requests += 1

        # Write results to a new file
        output_file_id = f"backend_result_file-{uuid.uuid4()}"
        global storage_dir
@@ -372,6 +379,72 @@ async def v1_retrieve_batch(batch_id: str):
    return batch_response


async def v1_cancel_batch(tokenizer_manager, batch_id: str):
    # Retrieve the batch job from the in-memory storage
    batch_response = batch_storage.get(batch_id)
    if batch_response is None:
        raise HTTPException(status_code=404, detail="Batch not found")

    # Only cancel when the status is "validating" or "in_progress"
    if batch_response.status in ["validating", "in_progress"]:
        # Start cancelling the batch asynchronously
        asyncio.create_task(
            cancel_batch(
                tokenizer_manager=tokenizer_manager,
                batch_id=batch_id,
                input_file_id=batch_response.input_file_id,
            )
        )

        # Update batch status to "cancelling"
        batch_response.status = "cancelling"

        return batch_response
    else:
        raise HTTPException(
            status_code=500,
            detail=f"Current status is {batch_response.status}, no need to cancel",
        )


async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
    try:
        # Update the batch status to "cancelling"
        batch_storage[batch_id].status = "cancelling"

        # Retrieve the input file content
        input_file_request = file_id_request.get(input_file_id)
        if not input_file_request:
            raise ValueError("Input file not found")

        # Parse the JSONL file and collect the request ids
        input_file_path = file_id_storage.get(input_file_id)
        with open(input_file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        file_request_list = []
        request_ids = []
        for line in lines:
            request_data = json.loads(line)
            file_request_list.append(request_data)
            request_ids.append(request_data["custom_id"])

        # Cancel requests by request_ids
        for rid in request_ids:
            tokenizer_manager.abort_request(rid=rid)

        retrieve_batch = batch_storage[batch_id]
        retrieve_batch.status = "cancelled"

    except Exception as e:
        logger.error(f"error in SGLang: {e}")
        # Update batch status to "failed"
        retrieve_batch = batch_storage[batch_id]
        retrieve_batch.status = "failed"
        retrieve_batch.failed_at = int(time.time())
        retrieve_batch.errors = {"message": str(e)}


async def v1_retrieve_file(file_id: str):
    # Retrieve the batch job from the in-memory storage
    file_response = file_id_response.get(file_id)
@@ -392,7 +465,9 @@ async def v1_retrieve_file_content(file_id: str):
    return StreamingResponse(iter_file(), media_type="application/octet-stream")


def v1_generate_request(
    all_requests: List[CompletionRequest], request_ids: List[str] = None
):
    prompts = []
    sampling_params_list = []
    return_logprobs = []
@@ -464,6 +539,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]):
        logprob_start_len=logprob_start_lens,
        return_text_in_logprobs=True,
        stream=all_requests[0].stream,
        rid=request_ids,
    )

    if len(all_requests) == 1:
@@ -746,7 +822,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
def v1_chat_generate_request(
    all_requests: List[ChatCompletionRequest],
    tokenizer_manager,
    request_ids: List[str] = None,
):
    input_ids = []
    sampling_params_list = []
@@ -834,6 +912,7 @@ def v1_chat_generate_request(
        top_logprobs_num=top_logprobs_nums,
        stream=all_requests[0].stream,
        return_text_in_logprobs=True,
        rid=request_ids,
    )

    if len(all_requests) == 1:
        return adapted_request, all_requests[0]
...
@@ -59,6 +59,7 @@ from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_cancel_batch,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
@@ -246,6 +247,12 @@ async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
    # https://platform.openai.com/docs/api-reference/batch/cancel
    return await v1_cancel_batch(tokenizer_manager, batch_id)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)
...
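The new route takes no request body and returns the updated batch object, so it can also be exercised directly over HTTP. A quick sketch using requests, assuming a locally running server; host, port, and batch id are placeholders:

import requests

base_url = "http://127.0.0.1:30000"  # placeholder address of a running SGLang server
batch_id = "batch_xxx"               # placeholder id of a previously created batch

resp = requests.post(f"{base_url}/v1/batches/{batch_id}/cancel")
resp.raise_for_status()
print(resp.json()["status"])  # expected to report "cancelling"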
...@@ -256,8 +256,7 @@ class TestOpenAIServer(unittest.TestCase): ...@@ -256,8 +256,7 @@ class TestOpenAIServer(unittest.TestCase):
index, True index, True
), f"index {index} is not found in the response" ), f"index {index} is not found in the response"
def run_batch(self, mode): def _create_batch(self, mode, client):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
if mode == "completion": if mode == "completion":
input_file_path = "complete_input.jsonl" input_file_path = "complete_input.jsonl"
# write content to input file # write content to input file
...@@ -333,9 +332,11 @@ class TestOpenAIServer(unittest.TestCase): ...@@ -333,9 +332,11 @@ class TestOpenAIServer(unittest.TestCase):
}, },
}, },
] ]
with open(input_file_path, "w") as file: with open(input_file_path, "w") as file:
for line in content: for line in content:
file.write(json.dumps(line) + "\n") file.write(json.dumps(line) + "\n")
with open(input_file_path, "rb") as file: with open(input_file_path, "rb") as file:
uploaded_file = client.files.create(file=file, purpose="batch") uploaded_file = client.files.create(file=file, purpose="batch")
if mode == "completion": if mode == "completion":
...@@ -348,6 +349,13 @@ class TestOpenAIServer(unittest.TestCase): ...@@ -348,6 +349,13 @@ class TestOpenAIServer(unittest.TestCase):
endpoint=endpoint, endpoint=endpoint,
completion_window=completion_window, completion_window=completion_window,
) )
return batch_job, content
def run_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
batch_job, content = self._create_batch(mode=mode, client=client)
while batch_job.status not in ["completed", "failed", "cancelled"]: while batch_job.status not in ["completed", "failed", "cancelled"]:
time.sleep(3) time.sleep(3)
print( print(
...@@ -371,6 +379,24 @@ class TestOpenAIServer(unittest.TestCase): ...@@ -371,6 +379,24 @@ class TestOpenAIServer(unittest.TestCase):
] ]
assert len(results) == len(content) assert len(results) == len(content)
def run_cancel_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
batch_job, _ = self._create_batch(mode=mode, client=client)
assert batch_job.status not in ["cancelling", "cancelled"]
batch_job = client.batches.cancel(batch_id=batch_job.id)
assert batch_job.status == "cancelling"
while batch_job.status not in ["failed", "cancelled"]:
batch_job = client.batches.retrieve(batch_job.id)
print(
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
time.sleep(3)
assert batch_job.status == "cancelled"
def test_completion(self): def test_completion(self):
for echo in [False, True]: for echo in [False, True]:
for logprobs in [None, 5]: for logprobs in [None, 5]:
...@@ -414,6 +440,10 @@ class TestOpenAIServer(unittest.TestCase): ...@@ -414,6 +440,10 @@ class TestOpenAIServer(unittest.TestCase):
for mode in ["completion", "chat"]: for mode in ["completion", "chat"]:
self.run_batch(mode) self.run_batch(mode)

    def test_cancel_batch(self):
        for mode in ["completion", "chat"]:
            self.run_cancel_batch(mode)

    def test_regex(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
...