"vscode:/vscode.git/clone" did not exist on "96ebfa652b342c327a335cda7e101959a39d6ef1"
Unverified commit 084fa54d authored by yichuan~, committed by GitHub

Add support for OpenAI API: offline batch (file) processing (#699)


Co-authored-by: hnyls2002 <hnyls2002@gmail.com>
parent eba458bd
......@@ -4,6 +4,6 @@ repos:
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: stable
rev: 24.4.2
hooks:
- id: black
import json
import os
import time
import openai
from openai import OpenAI
class OpenAIBatchProcessor:
def __init__(self, api_key):
# client = OpenAI(api_key=api_key)
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
self.client = client
def process_batch(self, input_file_path, endpoint, completion_window):
# Upload the input file
with open(input_file_path, "rb") as file:
uploaded_file = self.client.files.create(file=file, purpose="batch")
# Create the batch job
batch_job = self.client.batches.create(
input_file_id=uploaded_file.id,
endpoint=endpoint,
completion_window=completion_window,
)
# Monitor the batch job status
while batch_job.status not in ["completed", "failed", "cancelled"]:
time.sleep(3) # Wait for 3 seconds before checking the status again
print(
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
batch_job = self.client.batches.retrieve(batch_job.id)
# Check the batch job status and errors
if batch_job.status == "failed":
print(f"Batch job failed with status: {batch_job.status}")
print(f"Batch job errors: {batch_job.errors}")
return None
# If the batch job is completed, process the results
if batch_job.status == "completed":
# print result of batch job
print("batch", batch_job.request_counts)
result_file_id = batch_job.output_file_id
# Retrieve the file content from the server
file_response = self.client.files.content(result_file_id)
result_content = file_response.read() # Read the content of the file
# Save the content to a local file
result_file_name = "batch_job_chat_results.jsonl"
with open(result_file_name, "wb") as file:
file.write(result_content) # Write the binary content to the file
# Load data from the saved JSONL file
results = []
with open(result_file_name, "r", encoding="utf-8") as file:
for line in file:
json_object = json.loads(
line.strip()
) # Parse each line as a JSON object
results.append(json_object)
return results
else:
print(f"Batch job failed with status: {batch_job.status}")
return None
# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)
# Process the batch job
input_file_path = "input.jsonl"
endpoint = "/v1/chat/completions"
completion_window = "24h"
# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)
# Print the results
print(results)
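For reference, here is a minimal sketch of how the input.jsonl consumed by the example above could be produced. The batch adapter added in this PR only reads the custom_id and body fields from each line (the official OpenAI batch format also carries method and url fields, which are not required here); the model name and message contents are placeholders.
import json
batch_requests = [
    {
        "custom_id": "request-1",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": "Tell me a short joke."}],
            "max_tokens": 32,
        },
    },
    {
        "custom_id": "request-2",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": "List 3 countries and their capitals."}],
            "max_tokens": 64,
        },
    },
]
# Write one JSON object per line, as expected by the /v1/files + /v1/batches flow
with open("input.jsonl", "w", encoding="utf-8") as f:
    for req in batch_requests:
        f.write(json.dumps(req) + "\n")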
import json
import os
import time
import openai
from openai import OpenAI
class OpenAIBatchProcessor:
def __init__(self, api_key):
# client = OpenAI(api_key=api_key)
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
self.client = client
def process_batch(self, input_file_path, endpoint, completion_window):
# Upload the input file
with open(input_file_path, "rb") as file:
uploaded_file = self.client.files.create(file=file, purpose="batch")
# Create the batch job
batch_job = self.client.batches.create(
input_file_id=uploaded_file.id,
endpoint=endpoint,
completion_window=completion_window,
)
# Monitor the batch job status
while batch_job.status not in ["completed", "failed", "cancelled"]:
time.sleep(3) # Wait for 3 seconds before checking the status again
print(
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
batch_job = self.client.batches.retrieve(batch_job.id)
# Check the batch job status and errors
if batch_job.status == "failed":
print(f"Batch job failed with status: {batch_job.status}")
print(f"Batch job errors: {batch_job.errors}")
return None
# If the batch job is completed, process the results
if batch_job.status == "completed":
# print result of batch job
print("batch", batch_job.request_counts)
result_file_id = batch_job.output_file_id
# Retrieve the file content from the server
file_response = self.client.files.content(result_file_id)
result_content = file_response.read() # Read the content of the file
# Save the content to a local file
result_file_name = "batch_job_complete_results.jsonl"
with open(result_file_name, "wb") as file:
file.write(result_content) # Write the binary content to the file
# Load data from the saved JSONL file
results = []
with open(result_file_name, "r", encoding="utf-8") as file:
for line in file:
json_object = json.loads(
line.strip()
) # Parse each line as a JSON object
results.append(json_object)
return results
else:
print(f"Batch job failed with status: {batch_job.status}")
return None
# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)
# Process the batch job
input_file_path = "input_complete.jsonl"
endpoint = "/v1/completions"
completion_window = "24h"
# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)
# Print the results
print(results)
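The second example reads input_complete.jsonl and targets /v1/completions, so each body must be a valid completion request. A sketch of a compatible file, again with placeholder prompts and the "default" model name used throughout these examples:
import json
batch_requests = [
    {"custom_id": "request-1", "body": {"model": "default", "prompt": "Once upon a time, there was a little", "max_tokens": 32}},
    {"custom_id": "request-2", "body": {"model": "default", "prompt": "The name of the famous soccer player is", "max_tokens": 32}},
]
with open("input_complete.jsonl", "w", encoding="utf-8") as f:
    for req in batch_requests:
        f.write(json.dumps(req) + "\n")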
......@@ -13,6 +13,17 @@ response = client.completions.create(
print(response)
# Text completion
response = client.completions.create(
model="default",
prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
n=1,
temperature=0.8,
max_tokens=32,
)
print(response)
# Text completion
response = client.completions.create(
model="default",
......@@ -24,6 +35,17 @@ response = client.completions.create(
print(response)
# Text completion
response = client.completions.create(
model="default",
prompt=["The name of the famous soccer player is"],
n=1,
temperature=0.8,
max_tokens=128,
)
print(response)
# Text completion
response = client.completions.create(
model="default",
......@@ -60,6 +82,21 @@ response = client.completions.create(
)
print(response)
# Chat completion
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0.8,
max_tokens=64,
logprobs=True,
n=1,
)
print(response)
# Chat completion
response = client.chat.completions.create(
model="default",
......
......@@ -79,8 +79,26 @@ class GenerateReqInput:
if self.top_logprobs_num is None:
self.top_logprobs_num = 0
else:
parallel_sample_num = self.sampling_params.get("n", 1)
parallel_sample_num_list = []
if isinstance(self.sampling_params, dict):
parallel_sample_num = self.sampling_params.get("n", 1)
elif isinstance(self.sampling_params, list):
for sp in self.sampling_params:
parallel_sample_num = sp.get("n", 1)
parallel_sample_num_list.append(parallel_sample_num)
parallel_sample_num = max(parallel_sample_num_list)
all_equal = all(
element == parallel_sample_num
for element in parallel_sample_num_list
)
if parallel_sample_num > 1 and (not all_equal):
## TODO: handle the case where parallel_sample_num differs between samples
raise ValueError(
"The parallel_sample_num should be the same for all samples in sample params."
)
else:
parallel_sample_num = 1
self.parallel_sample_num = parallel_sample_num
if parallel_sample_num != 1:
# parallel sampling +1 represents the original prefill stage
......
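To illustrate the validation added above in isolation: when sampling_params arrives as a list (one dict per request in a file batch), every request must ask for the same n, and mixed values raise. A standalone sketch with made-up values:
sampling_params = [{"n": 2}, {"n": 2}, {"n": 3}]  # hypothetical per-request params
ns = [sp.get("n", 1) for sp in sampling_params]
parallel_sample_num = max(ns)
if parallel_sample_num > 1 and not all(n == parallel_sample_num for n in ns):
    # mirrors the ValueError raised in GenerateReqInput above
    raise ValueError(
        "The parallel_sample_num should be the same for all samples in sample params."
    )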
......@@ -84,6 +84,7 @@ class TokenizerManager:
trust_remote_code=server_args.trust_remote_code,
model_overide_args=model_overide_args,
)
if server_args.context_length is not None:
self.context_len = server_args.context_length
else:
......@@ -152,31 +153,33 @@ class TokenizerManager:
self, obj, request, index=None, is_cache_for_prefill=False
):
if not is_cache_for_prefill:
rid = obj.rid if index is None else obj.rid[index]
input_text = obj.text if index is None else obj.text[index]
not_use_index = not (index is not None)
rid = obj.rid if not_use_index else obj.rid[index]
input_text = obj.text if not_use_index else obj.text[index]
input_ids = (
self.tokenizer.encode(input_text)
if obj.input_ids is None
else obj.input_ids
)
if index is not None and obj.input_ids:
if not not_use_index and obj.input_ids:
input_ids = obj.input_ids[index]
self._validate_input_length(input_ids)
sampling_params = self._get_sampling_params(
obj.sampling_params if index is None else obj.sampling_params[index]
obj.sampling_params if not_use_index else obj.sampling_params[index]
)
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data if index is None else obj.image_data[index]
obj.image_data if not_use_index else obj.image_data[index]
)
return_logprob = (
obj.return_logprob if index is None else obj.return_logprob[index]
obj.return_logprob if not_use_index else obj.return_logprob[index]
)
logprob_start_len = (
obj.logprob_start_len if index is None else obj.logprob_start_len[index]
obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
)
top_logprobs_num = (
obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
)
else:
if isinstance(obj.text, list):
......@@ -224,7 +227,7 @@ class TokenizerManager:
async def _handle_batch_request(self, obj: GenerateReqInput, request):
batch_size = obj.batch_size
parallel_sample_num = obj.sampling_params[0].get("n", 1)
parallel_sample_num = obj.parallel_sample_num
if parallel_sample_num != 1:
# Send prefill requests to cache the common input
......@@ -241,7 +244,6 @@ class TokenizerManager:
obj.input_ids = input_id_result
elif input_id_result is not None:
obj.input_ids = input_id_result[0]
# First send out all requests
for i in range(batch_size):
for j in range(parallel_sample_num):
......@@ -249,7 +251,7 @@ class TokenizerManager:
continue
index = i * parallel_sample_num + j
if parallel_sample_num != 1:
# Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
# Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
index += batch_size - 1 - i
rid = obj.rid[index]
if parallel_sample_num == 1:
......
......@@ -18,10 +18,14 @@ limitations under the License.
import asyncio
import json
import os
import time
import uuid
from http import HTTPStatus
from typing import Dict, List, Optional
from fastapi import Request
from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import ValidationError
from sglang.srt.conversation import (
Conversation,
......@@ -32,6 +36,8 @@ from sglang.srt.conversation import (
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.openai_api.protocol import (
BatchRequest,
BatchResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
......@@ -45,6 +51,8 @@ from sglang.srt.openai_api.protocol import (
CompletionStreamResponse,
DeltaMessage,
ErrorResponse,
FileRequest,
FileResponse,
LogProbs,
UsageInfo,
)
......@@ -52,6 +60,24 @@ from sglang.srt.openai_api.protocol import (
chat_template_name = None
class FileMetadata:
def __init__(self, filename: str, purpose: str):
self.filename = filename
self.purpose = purpose
# In-memory storage for batch jobs and files
batch_storage: Dict[str, BatchResponse] = {}
file_id_request: Dict[str, FileMetadata] = {}
file_id_response: Dict[str, FileResponse] = {}
## map file id to file path in SGlang backend
file_id_storage: Dict[str, str] = {}
# backend storage directory
storage_dir = None
def create_error_response(
message: str,
err_type: str = "BadRequestError",
......@@ -106,33 +132,364 @@ def load_chat_template_for_openai_api(chat_template_arg):
chat_template_name = chat_template_arg
async def v1_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
request = CompletionRequest(**request_json)
prompt = request.prompt
if isinstance(prompt, str) or isinstance(prompt[0], str):
prompt_kwargs = {"text": prompt}
async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
try:
global storage_dir
if file_storage_pth:
storage_dir = file_storage_pth
# Read the file content
file_content = await file.read()
# Create an instance of RequestBody
request_body = FileRequest(file=file_content, purpose=purpose)
# Save the file to the sglang_oai_storage directory
os.makedirs(storage_dir, exist_ok=True)
file_id = f"backend_input_file-{uuid.uuid4()}"
filename = f"{file_id}.jsonl"
file_path = os.path.join(storage_dir, filename)
with open(file_path, "wb") as f:
f.write(request_body.file)
# add info to global file map
file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
file_id_storage[file_id] = file_path
# Return the response in the required format
response = FileResponse(
id=file_id,
bytes=len(request_body.file),
created_at=int(time.time()),
filename=file.filename,
purpose=request_body.purpose,
)
file_id_response[file_id] = response
return response
except ValidationError as e:
return {"error": "Invalid input", "details": e.errors()}
async def v1_batches(tokenizer_manager, raw_request: Request):
try:
body = await raw_request.json()
batch_request = BatchRequest(**body)
batch_id = f"batch_{uuid.uuid4()}"
# Create an instance of BatchResponse
batch_response = BatchResponse(
id=batch_id,
endpoint=batch_request.endpoint,
input_file_id=batch_request.input_file_id,
completion_window=batch_request.completion_window,
created_at=int(time.time()),
metadata=batch_request.metadata,
)
batch_storage[batch_id] = batch_response
# Start processing the batch asynchronously
asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
# Return the initial batch_response
return batch_response
except ValidationError as e:
return {"error": "Invalid input", "details": e.errors()}
except Exception as e:
return {"error": str(e)}
async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
try:
# Update the batch status to "in_progress"
batch_storage[batch_id].status = "in_progress"
batch_storage[batch_id].in_progress_at = int(time.time())
# Retrieve the input file content
input_file_request = file_id_request.get(batch_request.input_file_id)
if not input_file_request:
raise ValueError("Input file not found")
# Parse the JSONL file and process each request
input_file_path = file_id_storage.get(batch_request.input_file_id)
with open(input_file_path, "r", encoding="utf-8") as f:
lines = f.readlines()
total_requests = len(lines)
completed_requests = 0
failed_requests = 0
all_ret = []
end_point = batch_storage[batch_id].endpoint
file_request_list = []
all_requests = []
for line in lines:
request_data = json.loads(line)
file_request_list.append(request_data)
body = request_data["body"]
if end_point == "/v1/chat/completions":
all_requests.append(ChatCompletionRequest(**body))
elif end_point == "/v1/completions":
all_requests.append(CompletionRequest(**body))
if end_point == "/v1/chat/completions":
adapted_request, request = v1_chat_generate_request(
all_requests, tokenizer_manager
)
elif end_point == "/v1/completions":
adapted_request, request = v1_generate_request(all_requests)
try:
ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
if not isinstance(ret, list):
ret = [ret]
if end_point == "/v1/chat/completions":
responses = v1_chat_generate_response(request, ret, to_file=True)
else:
responses = v1_generate_response(request, ret, to_file=True)
except Exception as e:
error_json = {
"id": f"batch_req_{uuid.uuid4()}",
"custom_id": request_data.get("custom_id"),
"response": None,
"error": {"message": str(e)},
}
all_ret.append(error_json)
failed_requests += len(file_request_list)
for idx, response in enumerate(responses):
## the batch_req id here could instead be assigned at batch granularity
response_json = {
"id": f"batch_req_{uuid.uuid4()}",
"custom_id": file_request_list[idx].get("custom_id"),
"response": response,
"error": None,
}
all_ret.append(response_json)
completed_requests += 1
# Write results to a new file
output_file_id = f"backend_result_file-{uuid.uuid4()}"
global storage_dir
output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
with open(output_file_path, "w", encoding="utf-8") as f:
for ret in all_ret:
f.write(json.dumps(ret) + "\n")
# Update batch response with output file information
retrieve_batch = batch_storage[batch_id]
retrieve_batch.output_file_id = output_file_id
file_id_storage[output_file_id] = output_file_path
# Update batch status to "completed"
retrieve_batch.status = "completed"
retrieve_batch.completed_at = int(time.time())
retrieve_batch.request_counts = {
"total": total_requests,
"completed": completed_requests,
"failed": failed_requests,
}
except Exception as e:
print("error in SGlang:", e)
# Update batch status to "failed"
retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "failed"
retrieve_batch.failed_at = int(time.time())
retrieve_batch.errors = {"message": str(e)}
async def v1_retrieve_batch(batch_id: str):
# Retrieve the batch job from the in-memory storage
batch_response = batch_storage.get(batch_id)
if batch_response is None:
raise HTTPException(status_code=404, detail="Batch not found")
return batch_response
async def v1_retrieve_file(file_id: str):
# Retrieve the batch job from the in-memory storage
file_response = file_id_response.get(file_id)
if file_response is None:
raise HTTPException(status_code=404, detail="File not found")
return file_response
async def v1_retrieve_file_content(file_id: str):
file_pth = file_id_storage.get(file_id)
if not file_pth or not os.path.exists(file_pth):
raise HTTPException(status_code=404, detail="File not found")
def iter_file():
with open(file_pth, mode="rb") as file_like:
yield from file_like
return StreamingResponse(iter_file(), media_type="application/octet-stream")
def v1_generate_request(all_requests):
prompts = []
sampling_params_list = []
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
prompt = request.prompt
assert (
type(prompt) == first_prompt_type
), "All prompts must be of the same type in file input settings"
prompts.append(prompt)
sampling_params_list.append(
{
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"stop": request.stop,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
"n": request.n,
"ignore_eos": request.ignore_eos,
}
)
if len(all_requests) > 1 and request.n > 1:
raise ValueError(
"Batch operation is not supported for completions from files"
)
if len(all_requests) == 1:
prompt = prompts[0]
sampling_params_list = sampling_params_list[0]
if isinstance(prompts, str) or isinstance(prompts[0], str):
prompt_kwargs = {"text": prompt}
else:
prompt_kwargs = {"input_ids": prompt}
else:
prompt_kwargs = {"input_ids": prompt}
if isinstance(prompts[0], str):
prompt_kwargs = {"text": prompts}
else:
prompt_kwargs = {"input_ids": prompts}
adapted_request = GenerateReqInput(
**prompt_kwargs,
sampling_params={
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"stop": request.stop,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
"n": request.n,
"ignore_eos": request.ignore_eos,
},
return_logprob=request.logprobs is not None and request.logprobs > 0,
top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
sampling_params=sampling_params_list,
return_logprob=all_requests[0].logprobs is not None
and all_requests[0].logprobs > 0,
top_logprobs_num=(
all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
),
return_text_in_logprobs=True,
stream=request.stream,
stream=all_requests[0].stream,
)
if len(all_requests) == 1:
return adapted_request, all_requests[0]
return adapted_request, all_requests
def v1_generate_response(request, ret, to_file=False):
choices = []
echo = False
if (not isinstance(request, List)) and request.echo:
# TODO: handle the case where the prompt is token ids
if isinstance(request.prompt, list):
prompts = request.prompt
else:
prompts = [request.prompt]
echo = True
for idx, ret_item in enumerate(ret):
text = ret_item["text"]
if isinstance(request, List) and request[idx].echo:
echo = True
text = request[idx].prompt + text
if (not isinstance(request, List)) and echo:
text = prompts[idx] + text
logprobs = False
if isinstance(request, List) and request[idx].logprobs:
logprobs = True
elif (not isinstance(request, List)) and request.logprobs:
logprobs = True
if logprobs:
if echo:
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
else:
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
)
else:
logprobs = None
if to_file:
## to make the choice data json serializable
choice_data = {
"index": 0,
"text": text,
"logprobs": logprobs,
"finish_reason": ret_item["meta_info"]["finish_reason"],
}
else:
choice_data = CompletionResponseChoice(
index=idx,
text=text,
logprobs=logprobs,
finish_reason=ret_item["meta_info"]["finish_reason"],
)
choices.append(choice_data)
if to_file:
responses = []
for i, choice in enumerate(choices):
response = {
"status_code": 200,
"request_id": ret[i]["meta_info"]["id"],
"body": {
## kept the same for now, but this could be changed if needed
"id": ret[i]["meta_info"]["id"],
"object": "text_completion",
"created": int(time.time()),
"model": request[i].model,
"choices": choice,
"usage": {
"prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
"completion_tokens": ret[i]["meta_info"]["completion_tokens"],
"total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+ ret[i]["meta_info"]["completion_tokens"],
},
"system_fingerprint": None,
},
}
responses.append(response)
return responses
else:
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
response = CompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
choices=choices,
usage=UsageInfo(
prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
completion_tokens=completion_tokens,
total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
),
)
return response
async def v1_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
all_requests = [CompletionRequest(**request_json)]
adapted_request, request = v1_generate_request(all_requests)
if adapted_request.stream:
......@@ -223,109 +580,144 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
if not isinstance(ret, list):
ret = [ret]
if request.echo:
# TODO: handle the case where the prompt is token ids
if isinstance(request.prompt, list):
prompts = request.prompt
else:
prompts = [request.prompt]
choices = []
for idx, ret_item in enumerate(ret):
text = ret_item["text"]
response = v1_generate_response(request, ret)
return response
if request.echo:
text = prompts[idx] + text
if request.logprobs:
if request.echo:
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
def v1_chat_generate_request(all_requests, tokenizer_manager):
texts = []
sampling_params_list = []
image_data_list = []
for request in all_requests:
# Prep the data needed for the underlying GenerateReqInput:
# - prompt: The full prompt string.
# - stop: Custom stop tokens.
# - image_data: None or a list of image strings (URLs or base64 strings).
# None skips any image processing in GenerateReqInput.
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
if chat_template_name is None:
prompt = tokenizer_manager.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=True
)
stop = request.stop
image_data = None
else:
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
)
conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt()
image_data = conv.image_data
stop = conv.stop_str or []
if request.stop:
if isinstance(request.stop, str):
stop.append(request.stop)
else:
stop.extend(request.stop)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=idx,
text=text,
logprobs=logprobs,
finish_reason=ret_item["meta_info"]["finish_reason"],
# Use the raw prompt and stop strings if the messages is already a string.
prompt = request.messages
stop = request.stop
image_data = None
texts.append(prompt)
sampling_params_list.append(
{
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"stop": stop,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
"n": request.n,
}
)
choices.append(choice_data)
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
response = CompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
choices=choices,
usage=UsageInfo(
prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
completion_tokens=completion_tokens,
total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
),
image_data_list.append(image_data)
if len(all_requests) == 1:
texts = texts[0]
sampling_params_list = sampling_params_list[0]
image_data = image_data_list[0]
adapted_request = GenerateReqInput(
text=texts,
image_data=image_data,
sampling_params=sampling_params_list,
stream=request.stream,
)
if len(all_requests) == 1:
return adapted_request, all_requests[0]
return adapted_request, all_requests
return response
def v1_chat_generate_response(request, ret, to_file=False):
choices = []
total_prompt_tokens = 0
total_completion_tokens = 0
async def v1_chat_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
request = ChatCompletionRequest(**request_json)
# Prep the data needed for the underlying GenerateReqInput:
# - prompt: The full prompt string.
# - stop: Custom stop tokens.
# - image_data: None or a list of image strings (URLs or base64 strings).
# None skips any image processing in GenerateReqInput.
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
if chat_template_name is None:
prompt = tokenizer_manager.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=True
)
stop = request.stop
image_data = None
for idx, ret_item in enumerate(ret):
prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
completion_tokens = ret_item["meta_info"]["completion_tokens"]
if to_file:
## to make the choice data json serializable
choice_data = {
"index": 0,
"message": {"role": "assistant", "content": ret_item["text"]},
"logprobs": None,
"finish_reason": ret_item["meta_info"]["finish_reason"],
}
else:
conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt()
image_data = conv.image_data
stop = conv.stop_str or []
if request.stop:
if isinstance(request.stop, str):
stop.append(request.stop)
else:
stop.extend(request.stop)
choice_data = ChatCompletionResponseChoice(
index=idx,
message=ChatMessage(role="assistant", content=ret_item["text"]),
finish_reason=ret_item["meta_info"]["finish_reason"],
)
choices.append(choice_data)
total_prompt_tokens = prompt_tokens
total_completion_tokens += completion_tokens
if to_file:
responses = []
for i, choice in enumerate(choices):
response = {
"status_code": 200,
"request_id": ret[i]["meta_info"]["id"],
"body": {
## kept the same for now, but this could be changed if needed
"id": ret[i]["meta_info"]["id"],
"object": "chat.completion",
"created": int(time.time()),
"model": request[i].model,
"choices": choice,
"usage": {
"prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
"completion_tokens": ret[i]["meta_info"]["completion_tokens"],
"total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+ ret[i]["meta_info"]["completion_tokens"],
},
"system_fingerprint": None,
},
}
responses.append(response)
return responses
else:
# Use the raw prompt and stop strings if the messages is already a string.
prompt = request.messages
stop = request.stop
image_data = None
response = ChatCompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
choices=choices,
usage=UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
),
)
return response
adapted_request = GenerateReqInput(
text=prompt,
image_data=image_data,
sampling_params={
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"stop": stop,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
"n": request.n,
},
stream=request.stream,
)
async def v1_chat_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
all_requests = [ChatCompletionRequest(**request_json)]
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
if adapted_request.stream:
......@@ -387,34 +779,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
if not isinstance(ret, list):
ret = [ret]
choices = []
total_prompt_tokens = 0
total_completion_tokens = 0
for idx, ret_item in enumerate(ret):
prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
completion_tokens = ret_item["meta_info"]["completion_tokens"]
choice_data = ChatCompletionResponseChoice(
index=idx,
message=ChatMessage(role="assistant", content=ret_item["text"]),
finish_reason=ret_item["meta_info"]["finish_reason"],
)
choices.append(choice_data)
total_prompt_tokens = prompt_tokens
total_completion_tokens += completion_tokens
response = ChatCompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
choices=choices,
usage=UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
),
)
response = v1_chat_generate_response(request, ret)
return response
......
......@@ -60,6 +60,55 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0
class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded
purpose: str = (
"batch" # The intended purpose of the uploaded file, default is "batch"
)
class FileResponse(BaseModel):
id: str
object: str = "file"
bytes: int
created_at: int
filename: str
purpose: str
class BatchRequest(BaseModel):
input_file_id: (
str # The ID of an uploaded file that contains requests for the new batch
)
endpoint: str # The endpoint to be used for all requests in the batch
completion_window: str # The time frame within which the batch should be processed
metadata: Optional[dict] = None # Optional custom metadata for the batch
class BatchResponse(BaseModel):
id: str
object: str = "batch"
endpoint: str
errors: Optional[dict] = None
input_file_id: str
completion_window: str
status: str = "validating"
output_file_id: Optional[str] = None
error_file_id: Optional[str] = None
created_at: int
in_progress_at: Optional[int] = None
expires_at: Optional[int] = None
finalizing_at: Optional[int] = None
completed_at: Optional[int] = None
failed_at: Optional[int] = None
expired_at: Optional[int] = None
cancelling_at: Optional[int] = None
cancelled_at: Optional[int] = None
request_counts: dict = {"total": 0, "completed": 0, "failed": 0}
metadata: Optional[dict] = None
class CompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
......
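As a quick illustration of the new Pydantic models, this is roughly what a freshly created batch object looks like before processing starts; only the fields without defaults need to be supplied, and the IDs below are hypothetical:
import time
from sglang.srt.openai_api.protocol import BatchResponse
batch = BatchResponse(
    id="batch_123",
    endpoint="/v1/chat/completions",
    input_file_id="backend_input_file-abc",
    completion_window="24h",
    created_at=int(time.time()),
)
print(batch.status)          # "validating"
print(batch.request_counts)  # {'total': 0, 'completed': 0, 'failed': 0}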
......@@ -38,7 +38,7 @@ import psutil
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.responses import JSONResponse, Response, StreamingResponse
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
......@@ -56,8 +56,13 @@ from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
load_chat_template_for_openai_api,
v1_batches,
v1_chat_completions,
v1_completions,
v1_files_create,
v1_retrieve_batch,
v1_retrieve_file,
v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
......@@ -152,6 +157,35 @@ async def openai_v1_chat_completions(raw_request: Request):
return await v1_chat_completions(tokenizer_manager, raw_request)
@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
return await v1_files_create(
file, purpose, tokenizer_manager.server_args.file_storage_pth
)
@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
return await v1_batches(tokenizer_manager, raw_request)
@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
return await v1_retrieve_batch(batch_id)
@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
# https://platform.openai.com/docs/api-reference/files/retrieve
return await v1_retrieve_file(file_id)
@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
# https://platform.openai.com/docs/api-reference/files/retrieve-contents
return await v1_retrieve_file_content(file_id)
@app.get("/v1/models")
def available_models():
"""Show available models."""
......
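A hedged end-to-end sketch of driving the new routes with plain requests instead of the OpenAI client, assuming a local server at http://127.0.0.1:30000 as in the examples earlier in this commit:
import requests
base = "http://127.0.0.1:30000"
# Upload a JSONL request file (multipart form, matching POST /v1/files above)
with open("input.jsonl", "rb") as f:
    uploaded = requests.post(
        f"{base}/v1/files", files={"file": f}, data={"purpose": "batch"}
    ).json()
# Create a batch job over the uploaded file (POST /v1/batches)
batch = requests.post(
    f"{base}/v1/batches",
    json={
        "input_file_id": uploaded["id"],
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    },
).json()
# Check batch status and, once an output file exists, download its content
status = requests.get(f"{base}/v1/batches/{batch['id']}").json()
if status.get("output_file_id"):
    results = requests.get(
        f"{base}/v1/files/{status['output_file_id']}/content"
    ).content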
......@@ -60,6 +60,7 @@ class ServerArgs:
# Other
api_key: str = ""
file_storage_pth: str = "SGlang_storage"
# Data parallelism
dp_size: int = 1
......@@ -290,6 +291,12 @@ class ServerArgs:
default=ServerArgs.api_key,
help="Set API key of the server.",
)
parser.add_argument(
"--file-storage-pth",
type=str,
default=ServerArgs.file_storage_pth,
help="The path of the file storage in backend.",
)
# Data parallelism
parser.add_argument(
......
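The new flag only sets the directory that /v1/files writes uploaded JSONL files into (default SGlang_storage). Assuming the standard SGLang entry point, it can be overridden at launch with, for example, python -m sglang.launch_server --model-path <model> --port 30000 --file-storage-pth /tmp/sglang_oai_storage.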