Unverified commit 084fa54d authored by yichuan~, committed by GitHub

Add support for OpenAI API: offline batch (file) processing (#699)


Co-authored-by: hnyls2002 <hnyls2002@gmail.com>
parent eba458bd
@@ -4,6 +4,6 @@ repos:
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: 24.4.2
    hooks:
      - id: black
import json
import os
import time

import openai
from openai import OpenAI


class OpenAIBatchProcessor:
    def __init__(self, api_key):
        # client = OpenAI(api_key=api_key)
        client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

        self.client = client

    def process_batch(self, input_file_path, endpoint, completion_window):
        # Upload the input file
        with open(input_file_path, "rb") as file:
            uploaded_file = self.client.files.create(file=file, purpose="batch")

        # Create the batch job
        batch_job = self.client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        # Monitor the batch job status
        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)  # Wait for 3 seconds before checking the status again
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = self.client.batches.retrieve(batch_job.id)

        # Check the batch job status and errors
        if batch_job.status == "failed":
            print(f"Batch job failed with status: {batch_job.status}")
            print(f"Batch job errors: {batch_job.errors}")
            return None

        # If the batch job is completed, process the results
        if batch_job.status == "completed":
            # Print the request counts of the batch job
            print("batch", batch_job.request_counts)

            result_file_id = batch_job.output_file_id
            # Retrieve the file content from the server
            file_response = self.client.files.content(result_file_id)
            result_content = file_response.read()  # Read the content of the file

            # Save the content to a local file
            result_file_name = "batch_job_chat_results.jsonl"
            with open(result_file_name, "wb") as file:
                file.write(result_content)  # Write the binary content to the file

            # Load data from the saved JSONL file
            results = []
            with open(result_file_name, "r", encoding="utf-8") as file:
                for line in file:
                    json_object = json.loads(
                        line.strip()
                    )  # Parse each line as a JSON object
                    results.append(json_object)

            return results
        else:
            print(f"Batch job failed with status: {batch_job.status}")
            return None


# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)

# Batch job parameters
input_file_path = "input.jsonl"
endpoint = "/v1/chat/completions"
completion_window = "24h"

# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)

# Print the results
print(results)
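
The script above reads its requests from input.jsonl, which is not part of this diff. Assuming the server follows the OpenAI Batch input format, each line is a JSON object carrying a custom_id, an HTTP method, the target URL, and the request body. A minimal sketch of generating such a file (the file name, prompts, and body fields are illustrative, not part of the commit):

import json

# Hypothetical helper: write two chat requests in the OpenAI batch-input format
# (custom_id, method, url, body). The exact body fields the SGLang endpoint
# accepts may differ from this sketch.
requests_lines = [
    {
        "custom_id": f"request-{i}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 32,
        },
    }
    for i, prompt in enumerate(["What is an LLM?", "List 3 countries and their capitals."])
]

with open("input.jsonl", "w", encoding="utf-8") as f:
    for req in requests_lines:
        f.write(json.dumps(req) + "\n")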
import json
import os
import time

import openai
from openai import OpenAI


class OpenAIBatchProcessor:
    def __init__(self, api_key):
        # client = OpenAI(api_key=api_key)
        client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

        self.client = client

    def process_batch(self, input_file_path, endpoint, completion_window):
        # Upload the input file
        with open(input_file_path, "rb") as file:
            uploaded_file = self.client.files.create(file=file, purpose="batch")

        # Create the batch job
        batch_job = self.client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        # Monitor the batch job status
        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)  # Wait for 3 seconds before checking the status again
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = self.client.batches.retrieve(batch_job.id)

        # Check the batch job status and errors
        if batch_job.status == "failed":
            print(f"Batch job failed with status: {batch_job.status}")
            print(f"Batch job errors: {batch_job.errors}")
            return None

        # If the batch job is completed, process the results
        if batch_job.status == "completed":
            # Print the request counts of the batch job
            print("batch", batch_job.request_counts)

            result_file_id = batch_job.output_file_id
            # Retrieve the file content from the server
            file_response = self.client.files.content(result_file_id)
            result_content = file_response.read()  # Read the content of the file

            # Save the content to a local file
            result_file_name = "batch_job_complete_results.jsonl"
            with open(result_file_name, "wb") as file:
                file.write(result_content)  # Write the binary content to the file

            # Load data from the saved JSONL file
            results = []
            with open(result_file_name, "r", encoding="utf-8") as file:
                for line in file:
                    json_object = json.loads(
                        line.strip()
                    )  # Parse each line as a JSON object
                    results.append(json_object)

            return results
        else:
            print(f"Batch job failed with status: {batch_job.status}")
            return None


# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)

# Batch job parameters
input_file_path = "input_complete.jsonl"
endpoint = "/v1/completions"
completion_window = "24h"

# Process the batch job
results = processor.process_batch(input_file_path, endpoint, completion_window)

# Print the results
print(results)
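
For the completions variant, input_complete.jsonl would use the same envelope with a /v1/completions body, i.e. a prompt instead of messages (same assumptions as the sketch above; fields are illustrative):

import json

# One /v1/completions request per line; body fields are illustrative.
line = {
    "custom_id": "request-0",
    "method": "POST",
    "url": "/v1/completions",
    "body": {"model": "default", "prompt": "Once upon a time, there was a little", "max_tokens": 32},
}

with open("input_complete.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(line) + "\n")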
@@ -13,6 +13,17 @@ response = client.completions.create(

print(response)

# Text completion
response = client.completions.create(
    model="default",
    prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
    n=1,
    temperature=0.8,
    max_tokens=32,
)
print(response)

# Text completion
response = client.completions.create(
    model="default",
@@ -24,6 +35,17 @@ response = client.completions.create(

print(response)

# Text completion
response = client.completions.create(
    model="default",
    prompt=["The name of the famous soccer player is"],
    n=1,
    temperature=0.8,
    max_tokens=128,
)
print(response)

# Text completion
response = client.completions.create(
    model="default",
@@ -60,6 +82,21 @@ response = client.completions.create(
)
print(response)

# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.8,
    max_tokens=64,
    logprobs=True,
    n=1,
)
print(response)

# Chat completion
response = client.chat.completions.create(
    model="default",
......
@@ -79,8 +79,26 @@ class GenerateReqInput:
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            parallel_sample_num_list = []
            if isinstance(self.sampling_params, dict):
                parallel_sample_num = self.sampling_params.get("n", 1)
            elif isinstance(self.sampling_params, list):
                for sp in self.sampling_params:
                    parallel_sample_num = sp.get("n", 1)
                    parallel_sample_num_list.append(parallel_sample_num)
                parallel_sample_num = max(parallel_sample_num_list)
                all_equal = all(
                    element == parallel_sample_num
                    for element in parallel_sample_num_list
                )
                if parallel_sample_num > 1 and (not all_equal):
                    # TODO: cope with the case where parallel_sample_num differs between samples
                    raise ValueError(
                        "The parallel_sample_num should be the same for all samples in sample params."
                    )
            else:
                parallel_sample_num = 1
            self.parallel_sample_num = parallel_sample_num

            if parallel_sample_num != 1:
                # parallel sampling +1 represents the original prefill stage
......
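
In effect, when a batch supplies per-request sampling parameters, every request must ask for the same number of parallel samples (n). A minimal standalone sketch of the same check (the function name is illustrative, not part of the commit):

def resolve_parallel_sample_num(sampling_params):
    # A dict applies to the whole batch; a list must agree on "n" across requests.
    if isinstance(sampling_params, dict):
        return sampling_params.get("n", 1)
    elif isinstance(sampling_params, list):
        ns = [sp.get("n", 1) for sp in sampling_params]
        n = max(ns, default=1)
        if n > 1 and any(x != n for x in ns):
            raise ValueError("parallel_sample_num must match across the batch")
        return n
    return 1


print(resolve_parallel_sample_num([{"n": 2}, {"n": 2}]))  # 2
# resolve_parallel_sample_num([{"n": 2}, {"n": 1}]) raises ValueError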
@@ -84,6 +84,7 @@ class TokenizerManager:
            trust_remote_code=server_args.trust_remote_code,
            model_overide_args=model_overide_args,
        )

        if server_args.context_length is not None:
            self.context_len = server_args.context_length
        else:
@@ -152,31 +153,33 @@ class TokenizerManager:
        self, obj, request, index=None, is_cache_for_prefill=False
    ):
        if not is_cache_for_prefill:
            not_use_index = not (index is not None)

            rid = obj.rid if not_use_index else obj.rid[index]
            input_text = obj.text if not_use_index else obj.text[index]
            input_ids = (
                self.tokenizer.encode(input_text)
                if obj.input_ids is None
                else obj.input_ids
            )
            if not not_use_index and obj.input_ids:
                input_ids = obj.input_ids[index]

            self._validate_input_length(input_ids)

            sampling_params = self._get_sampling_params(
                obj.sampling_params if not_use_index else obj.sampling_params[index]
            )
            pixel_values, image_hash, image_size = await self._get_pixel_values(
                obj.image_data if not_use_index else obj.image_data[index]
            )
            return_logprob = (
                obj.return_logprob if not_use_index else obj.return_logprob[index]
            )
            logprob_start_len = (
                obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
            )
            top_logprobs_num = (
                obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
            )
        else:
            if isinstance(obj.text, list):
@@ -224,7 +227,7 @@ class TokenizerManager:
    async def _handle_batch_request(self, obj: GenerateReqInput, request):
        batch_size = obj.batch_size
        parallel_sample_num = obj.parallel_sample_num

        if parallel_sample_num != 1:
            # Send prefill requests to cache the common input
@@ -241,7 +244,6 @@ class TokenizerManager:
                    obj.input_ids = input_id_result
                elif input_id_result is not None:
                    obj.input_ids = input_id_result[0]
        # First send out all requests
        for i in range(batch_size):
            for j in range(parallel_sample_num):
@@ -249,7 +251,7 @@ class TokenizerManager:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # When using parallel sampling, account for the prefill stage, so the index is: j + i * (parallel_sample_num - 1) + batch_size - 1
                    index += batch_size - 1 - i
                rid = obj.rid[index]
                if parallel_sample_num == 1:
......
This diff is collapsed.
@@ -60,6 +60,55 @@ class UsageInfo(BaseModel):
    completion_tokens: Optional[int] = 0


class FileRequest(BaseModel):
    # https://platform.openai.com/docs/api-reference/files/create
    file: bytes  # The File object (not file name) to be uploaded
    purpose: str = (
        "batch"  # The intended purpose of the uploaded file, default is "batch"
    )


class FileResponse(BaseModel):
    id: str
    object: str = "file"
    bytes: int
    created_at: int
    filename: str
    purpose: str


class BatchRequest(BaseModel):
    input_file_id: (
        str  # The ID of an uploaded file that contains requests for the new batch
    )
    endpoint: str  # The endpoint to be used for all requests in the batch
    completion_window: str  # The time frame within which the batch should be processed
    metadata: Optional[dict] = None  # Optional custom metadata for the batch


class BatchResponse(BaseModel):
    id: str
    object: str = "batch"
    endpoint: str
    errors: Optional[dict] = None
    input_file_id: str
    completion_window: str
    status: str = "validating"
    output_file_id: Optional[str] = None
    error_file_id: Optional[str] = None
    created_at: int
    in_progress_at: Optional[int] = None
    expires_at: Optional[int] = None
    finalizing_at: Optional[int] = None
    completed_at: Optional[int] = None
    failed_at: Optional[int] = None
    expired_at: Optional[int] = None
    cancelling_at: Optional[int] = None
    cancelled_at: Optional[int] = None
    request_counts: dict = {"total": 0, "completed": 0, "failed": 0}
    metadata: Optional[dict] = None


class CompletionRequest(BaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
......
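
For reference, the payload accepted by POST /v1/batches maps directly onto BatchRequest, and the job object returned to the client follows BatchResponse. A small sketch of validating such payloads with these models (assumes a package built from this commit; the IDs are illustrative):

import time

from sglang.srt.openai_api.protocol import BatchRequest, BatchResponse

# Validate an incoming /v1/batches payload.
req = BatchRequest(
    input_file_id="file-abc123",  # illustrative ID previously returned by /v1/files
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# A freshly created batch job starts in the "validating" state.
resp = BatchResponse(
    id="batch-abc123",  # illustrative job ID
    endpoint=req.endpoint,
    input_file_id=req.input_file_id,
    completion_window=req.completion_window,
    created_at=int(time.time()),
)
print(resp.status)  # "validating"
print(resp.request_counts)  # {"total": 0, "completed": 0, "failed": 0}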
@@ -38,7 +38,7 @@ import psutil
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.responses import JSONResponse, Response, StreamingResponse

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
@@ -56,8 +56,13 @@ from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_chat_completions,
    v1_completions,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -152,6 +157,35 @@ async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, tokenizer_manager.server_args.file_storage_pth
    )


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


@app.get("/v1/models")
def available_models():
    """Show available models."""
......
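
These routes mirror the OpenAI Files and Batches REST API, so they can also be exercised with plain HTTP rather than the openai client. A minimal sketch, assuming a server is already running on port 30000; the URLs come from the routes above and the payload and response fields follow the FileResponse/BatchRequest/BatchResponse models:

import json
import time

import requests

base = "http://127.0.0.1:30000/v1"

# Upload the batch input file (multipart form, as expected by POST /v1/files).
with open("input.jsonl", "rb") as f:
    uploaded = requests.post(
        f"{base}/files", files={"file": f}, data={"purpose": "batch"}
    ).json()

# Create the batch job against the chat completions endpoint.
batch = requests.post(
    f"{base}/batches",
    json={
        "input_file_id": uploaded["id"],
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    },
).json()

# Poll until the job reaches a terminal state, then fetch the results file.
while batch["status"] not in ("completed", "failed", "cancelled"):
    time.sleep(3)
    batch = requests.get(f"{base}/batches/{batch['id']}").json()

if batch["status"] == "completed":
    content = requests.get(f"{base}/files/{batch['output_file_id']}/content").text
    results = [json.loads(line) for line in content.splitlines() if line]
    print(results)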
@@ -60,6 +60,7 @@ class ServerArgs:
    # Other
    api_key: str = ""
    file_storage_pth: str = "SGlang_storage"

    # Data parallelism
    dp_size: int = 1
@@ -290,6 +291,12 @@ class ServerArgs:
            default=ServerArgs.api_key,
            help="Set API key of the server.",
        )
        parser.add_argument(
            "--file-storage-pth",
            type=str,
            default=ServerArgs.file_storage_pth,
            help="The path of the file storage in backend.",
        )

        # Data parallelism
        parser.add_argument(
......