Unverified Commit 084fa54d authored by yichuan~ Committed by GitHub

Add support for OpenAI API: offline batch (file) processing (#699)


Co-authored-by: hnyls2002 <hnyls2002@gmail.com>
parent eba458bd
......@@ -4,6 +4,6 @@ repos:
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: stable
rev: 24.4.2
hooks:
- id: black
import json
import os
import time
import openai
from openai import OpenAI
class OpenAIBatchProcessor:
def __init__(self, api_key):
# client = OpenAI(api_key=api_key)
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
self.client = client
def process_batch(self, input_file_path, endpoint, completion_window):
# Upload the input file
with open(input_file_path, "rb") as file:
uploaded_file = self.client.files.create(file=file, purpose="batch")
# Create the batch job
batch_job = self.client.batches.create(
input_file_id=uploaded_file.id,
endpoint=endpoint,
completion_window=completion_window,
)
# Monitor the batch job status
while batch_job.status not in ["completed", "failed", "cancelled"]:
time.sleep(3) # Wait for 3 seconds before checking the status again
print(
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
batch_job = self.client.batches.retrieve(batch_job.id)
# Check the batch job status and errors
if batch_job.status == "failed":
print(f"Batch job failed with status: {batch_job.status}")
print(f"Batch job errors: {batch_job.errors}")
return None
# If the batch job is completed, process the results
if batch_job.status == "completed":
# print result of batch job
print("batch", batch_job.request_counts)
result_file_id = batch_job.output_file_id
# Retrieve the file content from the server
file_response = self.client.files.content(result_file_id)
result_content = file_response.read() # Read the content of the file
# Save the content to a local file
result_file_name = "batch_job_chat_results.jsonl"
with open(result_file_name, "wb") as file:
file.write(result_content) # Write the binary content to the file
# Load data from the saved JSONL file
results = []
with open(result_file_name, "r", encoding="utf-8") as file:
for line in file:
json_object = json.loads(
line.strip()
) # Parse each line as a JSON object
results.append(json_object)
return results
else:
print(f"Batch job failed with status: {batch_job.status}")
return None
# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)
# Process the batch job
input_file_path = "input.jsonl"
endpoint = "/v1/chat/completions"
completion_window = "24h"
# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)
# Print the results
print(results)
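For reference (not part of this commit), a minimal sketch of what the input.jsonl consumed by the script above might look like, assuming the OpenAI batch input format of one JSON request object per line with a custom_id, method, url, and body:

# Illustrative sketch only: build a small chat-completions batch input file.
import json

requests_to_run = [
    {
        "custom_id": f"request-{i}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",
            "messages": [
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": f"Tell me a fact about the number {i}."},
            ],
            "max_tokens": 64,
        },
    }
    for i in range(3)
]

# Write one JSON object per line, which is what the batch job expects
with open("input.jsonl", "w", encoding="utf-8") as f:
    for req in requests_to_run:
        f.write(json.dumps(req) + "\n")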
import json
import os
import time
import openai
from openai import OpenAI
class OpenAIBatchProcessor:
def __init__(self, api_key):
# client = OpenAI(api_key=api_key)
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
self.client = client
def process_batch(self, input_file_path, endpoint, completion_window):
# Upload the input file
with open(input_file_path, "rb") as file:
uploaded_file = self.client.files.create(file=file, purpose="batch")
# Create the batch job
batch_job = self.client.batches.create(
input_file_id=uploaded_file.id,
endpoint=endpoint,
completion_window=completion_window,
)
# Monitor the batch job status
while batch_job.status not in ["completed", "failed", "cancelled"]:
time.sleep(3) # Wait for 3 seconds before checking the status again
print(
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
batch_job = self.client.batches.retrieve(batch_job.id)
# Check the batch job status and errors
if batch_job.status == "failed":
print(f"Batch job failed with status: {batch_job.status}")
print(f"Batch job errors: {batch_job.errors}")
return None
# If the batch job is completed, process the results
if batch_job.status == "completed":
# print result of batch job
print("batch", batch_job.request_counts)
result_file_id = batch_job.output_file_id
# Retrieve the file content from the server
file_response = self.client.files.content(result_file_id)
result_content = file_response.read() # Read the content of the file
# Save the content to a local file
result_file_name = "batch_job_complete_results.jsonl"
with open(result_file_name, "wb") as file:
file.write(result_content) # Write the binary content to the file
# Load data from the saved JSONL file
results = []
with open(result_file_name, "r", encoding="utf-8") as file:
for line in file:
json_object = json.loads(
line.strip()
) # Parse each line as a JSON object
results.append(json_object)
return results
else:
print(f"Batch job failed with status: {batch_job.status}")
return None
# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)
# Process the batch job
input_file_path = "input_complete.jsonl"
endpoint = "/v1/completions"
completion_window = "24h"
# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)
# Print the results
print(results)
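The second script targets /v1/completions, so its input_complete.jsonl lines would carry a prompt in the body rather than messages. Another hedged sketch, not taken from the commit:

# Illustrative sketch only: one text-completion request line for the batch file.
import json

line = {
    "custom_id": "request-0",
    "method": "POST",
    "url": "/v1/completions",
    "body": {
        "model": "default",
        "prompt": "Once upon a time, there was a little",
        "max_tokens": 32,
    },
}

with open("input_complete.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(line) + "\n")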
......@@ -13,6 +13,17 @@ response = client.completions.create(
print(response)
# Text completion
response = client.completions.create(
model="default",
prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
n=1,
temperature=0.8,
max_tokens=32,
)
print(response)
# Text completion
response = client.completions.create(
model="default",
......@@ -24,6 +35,17 @@ response = client.completions.create(
print(response)
# Text completion
response = client.completions.create(
model="default",
prompt=["The name of the famous soccer player is"],
n=1,
temperature=0.8,
max_tokens=128,
)
print(response)
# Text completion
response = client.completions.create(
model="default",
......@@ -60,6 +82,21 @@ response = client.completions.create(
)
print(response)
# Chat completion
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0.8,
max_tokens=64,
logprobs=True,
n=1,
)
print(response)
# Chat completion
response = client.chat.completions.create(
model="default",
......
......@@ -79,8 +79,26 @@ class GenerateReqInput:
if self.top_logprobs_num is None:
self.top_logprobs_num = 0
else:
parallel_sample_num = self.sampling_params.get("n", 1)
parallel_sample_num_list = []
if isinstance(self.sampling_params, dict):
parallel_sample_num = self.sampling_params.get("n", 1)
elif isinstance(self.sampling_params, list):
for sp in self.sampling_params:
parallel_sample_num = sp.get("n", 1)
parallel_sample_num_list.append(parallel_sample_num)
parallel_sample_num = max(parallel_sample_num_list)
all_equal = all(
element == parallel_sample_num
for element in parallel_sample_num_list
)
if parallel_sample_num > 1 and (not all_equal):
## TODO cope with the case that the parallel_sample_num is different for different samples
raise ValueError(
"The parallel_sample_num should be the same for all samples in the sampling params."
)
else:
parallel_sample_num = 1
self.parallel_sample_num = parallel_sample_num
if parallel_sample_num != 1:
# parallel sampling +1 represents the original prefill stage
......
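The hunk above derives a single parallel_sample_num from either a dict or a list of sampling params and rejects lists whose n values differ. A standalone sketch of that rule (a hypothetical helper that mirrors, but does not import, the code above):

from typing import Union

def resolve_parallel_sample_num(sampling_params: Union[dict, list]) -> int:
    """Return the shared n across sampling params, or raise if they differ."""
    if isinstance(sampling_params, dict):
        return sampling_params.get("n", 1)
    nums = [sp.get("n", 1) for sp in sampling_params]
    n = max(nums)
    if n > 1 and any(x != n for x in nums):
        raise ValueError(
            "The parallel_sample_num should be the same for all samples in the sampling params."
        )
    return n

# resolve_parallel_sample_num([{"n": 2}, {"n": 2}])  -> 2
# resolve_parallel_sample_num([{"n": 2}, {"n": 1}])  -> ValueError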
......@@ -84,6 +84,7 @@ class TokenizerManager:
trust_remote_code=server_args.trust_remote_code,
model_overide_args=model_overide_args,
)
if server_args.context_length is not None:
self.context_len = server_args.context_length
else:
......@@ -152,31 +153,33 @@ class TokenizerManager:
self, obj, request, index=None, is_cache_for_prefill=False
):
if not is_cache_for_prefill:
rid = obj.rid if index is None else obj.rid[index]
input_text = obj.text if index is None else obj.text[index]
not_use_index = not (index is not None)
rid = obj.rid if not_use_index else obj.rid[index]
input_text = obj.text if not_use_index else obj.text[index]
input_ids = (
self.tokenizer.encode(input_text)
if obj.input_ids is None
else obj.input_ids
)
if index is not None and obj.input_ids:
if not not_use_index and obj.input_ids:
input_ids = obj.input_ids[index]
self._validate_input_length(input_ids)
sampling_params = self._get_sampling_params(
obj.sampling_params if index is None else obj.sampling_params[index]
obj.sampling_params if not_use_index else obj.sampling_params[index]
)
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data if index is None else obj.image_data[index]
obj.image_data if not_use_index else obj.image_data[index]
)
return_logprob = (
obj.return_logprob if index is None else obj.return_logprob[index]
obj.return_logprob if not_use_index else obj.return_logprob[index]
)
logprob_start_len = (
obj.logprob_start_len if index is None else obj.logprob_start_len[index]
obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
)
top_logprobs_num = (
obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
)
else:
if isinstance(obj.text, list):
......@@ -224,7 +227,7 @@ class TokenizerManager:
async def _handle_batch_request(self, obj: GenerateReqInput, request):
batch_size = obj.batch_size
parallel_sample_num = obj.sampling_params[0].get("n", 1)
parallel_sample_num = obj.parallel_sample_num
if parallel_sample_num != 1:
# Send prefill requests to cache the common input
......@@ -241,7 +244,6 @@ class TokenizerManager:
obj.input_ids = input_id_result
elif input_id_result is not None:
obj.input_ids = input_id_result[0]
# First send out all requests
for i in range(batch_size):
for j in range(parallel_sample_num):
......@@ -249,7 +251,7 @@ class TokenizerManager:
continue
index = i * parallel_sample_num + j
if parallel_sample_num != 1:
# Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
# Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
index += batch_size - 1 - i
rid = obj.rid[index]
if parallel_sample_num == 1:
......
......@@ -60,6 +60,55 @@ class UsageInfo(BaseModel):
completion_tokens: Optional[int] = 0
class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded
purpose: str = (
"batch" # The intended purpose of the uploaded file, default is "batch"
)
class FileResponse(BaseModel):
id: str
object: str = "file"
bytes: int
created_at: int
filename: str
purpose: str
class BatchRequest(BaseModel):
input_file_id: (
str # The ID of an uploaded file that contains requests for the new batch
)
endpoint: str # The endpoint to be used for all requests in the batch
completion_window: str # The time frame within which the batch should be processed
metadata: Optional[dict] = None # Optional custom metadata for the batch
class BatchResponse(BaseModel):
id: str
object: str = "batch"
endpoint: str
errors: Optional[dict] = None
input_file_id: str
completion_window: str
status: str = "validating"
output_file_id: Optional[str] = None
error_file_id: Optional[str] = None
created_at: int
in_progress_at: Optional[int] = None
expires_at: Optional[int] = None
finalizing_at: Optional[int] = None
completed_at: Optional[int] = None
failed_at: Optional[int] = None
expired_at: Optional[int] = None
cancelling_at: Optional[int] = None
cancelled_at: Optional[int] = None
request_counts: dict = {"total": 0, "completed": 0, "failed": 0}
metadata: Optional[dict] = None
class CompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
......
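For orientation, a hedged sketch of constructing the new batch models, assuming they live in sglang.srt.openai_api.protocol alongside ModelCard; the ids and metadata below are made up:

# Illustrative only: exercising the request/response models added above.
import time
from sglang.srt.openai_api.protocol import BatchRequest, BatchResponse

req = BatchRequest(
    input_file_id="file-abc123",  # hypothetical id returned by /v1/files
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={"job": "demo"},
)

resp = BatchResponse(
    id="batch_1",  # hypothetical batch id
    endpoint=req.endpoint,
    input_file_id=req.input_file_id,
    completion_window=req.completion_window,
    created_at=int(time.time()),
)
print(resp.status)  # "validating" by default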
......@@ -38,7 +38,7 @@ import psutil
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.responses import JSONResponse, Response, StreamingResponse
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
......@@ -56,8 +56,13 @@ from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
load_chat_template_for_openai_api,
v1_batches,
v1_chat_completions,
v1_completions,
v1_files_create,
v1_retrieve_batch,
v1_retrieve_file,
v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
......@@ -152,6 +157,35 @@ async def openai_v1_chat_completions(raw_request: Request):
return await v1_chat_completions(tokenizer_manager, raw_request)
@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
return await v1_files_create(
file, purpose, tokenizer_manager.server_args.file_storage_pth
)
@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
return await v1_batches(tokenizer_manager, raw_request)
@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
return await v1_retrieve_batch(batch_id)
@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
# https://platform.openai.com/docs/api-reference/files/retrieve
return await v1_retrieve_file(file_id)
@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
# https://platform.openai.com/docs/api-reference/files/retrieve-contents
return await v1_retrieve_file_content(file_id)
@app.get("/v1/models")
def available_models():
"""Show available models."""
......
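The routes above expose OpenAI-style file and batch endpoints on the local server. A hedged sketch of exercising them directly over HTTP with the requests library, assuming a server at http://127.0.0.1:30000 and an existing input.jsonl:

# Illustrative sketch only: raw HTTP calls against the new endpoints.
import time
import requests

base = "http://127.0.0.1:30000"  # assumed local sglang server address

# Upload a batch input file (multipart form, matching the /v1/files route)
with open("input.jsonl", "rb") as f:
    uploaded = requests.post(
        f"{base}/v1/files",
        files={"file": f},
        data={"purpose": "batch"},
    ).json()

# Create the batch job
batch = requests.post(
    f"{base}/v1/batches",
    json={
        "input_file_id": uploaded["id"],
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    },
).json()

# Poll until the job leaves the in-flight states, then fetch the output file
while True:
    batch = requests.get(f"{base}/v1/batches/{batch['id']}").json()
    if batch["status"] in ("completed", "failed", "cancelled"):
        break
    time.sleep(3)

if batch.get("output_file_id"):
    content = requests.get(f"{base}/v1/files/{batch['output_file_id']}/content").content
    print(content.decode("utf-8"))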
......@@ -60,6 +60,7 @@ class ServerArgs:
# Other
api_key: str = ""
file_storage_pth: str = "SGlang_storage"
# Data parallelism
dp_size: int = 1
......@@ -290,6 +291,12 @@ class ServerArgs:
default=ServerArgs.api_key,
help="Set API key of the server.",
)
parser.add_argument(
"--file-storage-pth",
type=str,
default=ServerArgs.file_storage_pth,
help="The path of the file storage in backend.",
)
# Data parallelism
parser.add_argument(
......