"docker/vscode:/vscode.git/clone" did not exist on "ecb9fa14e6baaedf255bfbc065e2a58a76efb1df"
Unverified Commit 72676cd6 authored by Chang Su, committed by GitHub

feat(oai refactor): Replace `openai_api` with `entrypoints/openai` (#7351)


Co-authored-by: Jin Pan <jpan236@wisc.edu>
parent 02bf31ef
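
At its core this refactor is a module move: code that previously lived under `sglang.srt.openai_api` now lives under `sglang.srt.entrypoints.openai`. Call sites only need their imports updated, as the diffs below show; the shape of the change is simply:

```python
# Old module path, removed by this commit:
# from sglang.srt.openai_api.protocol import ChatCompletionRequest

# New module path:
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
```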
@@ -235,6 +235,7 @@ class TestOpenAIServer(CustomTestCase):
         )
         is_firsts = {}
+        is_finished = {}
         for response in generator:
             usage = response.usage
             if usage is not None:
@@ -244,6 +245,10 @@ class TestOpenAIServer(CustomTestCase):
                 continue
             index = response.choices[0].index
+            finish_reason = response.choices[0].finish_reason
+            if finish_reason is not None:
+                is_finished[index] = True
+
             data = response.choices[0].delta
             if is_firsts.get(index, True):
@@ -253,7 +258,7 @@ class TestOpenAIServer(CustomTestCase):
                 is_firsts[index] = False
                 continue
-            if logprobs:
+            if logprobs and not is_finished.get(index, False):
                 assert response.choices[0].logprobs, f"logprobs was not returned"
                 assert isinstance(
                     response.choices[0].logprobs.content[0].top_logprobs[0].token, str
@@ -271,7 +276,7 @@ class TestOpenAIServer(CustomTestCase):
             assert (
                 isinstance(data.content, str)
                 or isinstance(data.reasoning_content, str)
-                or len(data.tool_calls) > 0
+                or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
                 or response.choices[0].finish_reason
             )
             assert response.id
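
The hunks above encode two streaming edge cases: the chunk carrying a non-null `finish_reason` may arrive without `logprobs`, and `delta.tool_calls` may be `None` rather than a list. A minimal consumer loop following the same pattern (a sketch, assuming a local sglang server at a typical test address):

```python
import openai

client = openai.Client(api_key="sk-test", base_url="http://localhost:30000/v1")  # assumed address
stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},
    logprobs=True,
    top_logprobs=5,
)

is_finished = {}  # choice index -> whether a finish_reason has been seen
for chunk in stream:
    if not chunk.choices:
        continue  # usage-only chunk at the end of the stream
    choice = chunk.choices[0]
    if choice.finish_reason is not None:
        is_finished[choice.index] = True
    # Only expect logprobs on chunks emitted before the choice finished.
    if not is_finished.get(choice.index, False) and choice.logprobs:
        print(choice.logprobs.content[0].top_logprobs[0].token)
    # Guard against tool_calls being None on content-only deltas.
    if isinstance(choice.delta.tool_calls, list) and choice.delta.tool_calls:
        print(choice.delta.tool_calls[0].function)
```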
@@ -282,152 +287,6 @@ class TestOpenAIServer(CustomTestCase):
                 index, True
             ), f"index {index} is not found in the response"
 
-    def _create_batch(self, mode, client):
-        if mode == "completion":
-            input_file_path = "complete_input.jsonl"
-            # write content to input file
-            content = [
-                {
-                    "custom_id": "request-1",
-                    "method": "POST",
-                    "url": "/v1/completions",
-                    "body": {
-                        "model": "gpt-3.5-turbo-instruct",
-                        "prompt": "List 3 names of famous soccer player: ",
-                        "max_tokens": 20,
-                    },
-                },
-                {
-                    "custom_id": "request-2",
-                    "method": "POST",
-                    "url": "/v1/completions",
-                    "body": {
-                        "model": "gpt-3.5-turbo-instruct",
-                        "prompt": "List 6 names of famous basketball player: ",
-                        "max_tokens": 40,
-                    },
-                },
-                {
-                    "custom_id": "request-3",
-                    "method": "POST",
-                    "url": "/v1/completions",
-                    "body": {
-                        "model": "gpt-3.5-turbo-instruct",
-                        "prompt": "List 6 names of famous tenniss player: ",
-                        "max_tokens": 40,
-                    },
-                },
-            ]
-        else:
-            input_file_path = "chat_input.jsonl"
-            content = [
-                {
-                    "custom_id": "request-1",
-                    "method": "POST",
-                    "url": "/v1/chat/completions",
-                    "body": {
-                        "model": "gpt-3.5-turbo-0125",
-                        "messages": [
-                            {
-                                "role": "system",
-                                "content": "You are a helpful assistant.",
-                            },
-                            {
-                                "role": "user",
-                                "content": "Hello! List 3 NBA players and tell a story",
-                            },
-                        ],
-                        "max_tokens": 30,
-                    },
-                },
-                {
-                    "custom_id": "request-2",
-                    "method": "POST",
-                    "url": "/v1/chat/completions",
-                    "body": {
-                        "model": "gpt-3.5-turbo-0125",
-                        "messages": [
-                            {"role": "system", "content": "You are an assistant. "},
-                            {
-                                "role": "user",
-                                "content": "Hello! List three capital and tell a story",
-                            },
-                        ],
-                        "max_tokens": 50,
-                    },
-                },
-            ]
-
-        with open(input_file_path, "w") as file:
-            for line in content:
-                file.write(json.dumps(line) + "\n")
-        with open(input_file_path, "rb") as file:
-            uploaded_file = client.files.create(file=file, purpose="batch")
-        if mode == "completion":
-            endpoint = "/v1/completions"
-        elif mode == "chat":
-            endpoint = "/v1/chat/completions"
-        completion_window = "24h"
-        batch_job = client.batches.create(
-            input_file_id=uploaded_file.id,
-            endpoint=endpoint,
-            completion_window=completion_window,
-        )
-        return batch_job, content, uploaded_file
-
-    def run_batch(self, mode):
-        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
-        batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client)
-        while batch_job.status not in ["completed", "failed", "cancelled"]:
-            time.sleep(3)
-            print(
-                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
-            )
-            batch_job = client.batches.retrieve(batch_job.id)
-        assert (
-            batch_job.status == "completed"
-        ), f"Batch job status is not completed: {batch_job.status}"
-        assert batch_job.request_counts.completed == len(content)
-        assert batch_job.request_counts.failed == 0
-        assert batch_job.request_counts.total == len(content)
-
-        result_file_id = batch_job.output_file_id
-        file_response = client.files.content(result_file_id)
-        result_content = file_response.read().decode("utf-8")  # Decode bytes to string
-        results = [
-            json.loads(line)
-            for line in result_content.split("\n")
-            if line.strip() != ""
-        ]
-        assert len(results) == len(content)
-
-        for delete_fid in [uploaded_file.id, result_file_id]:
-            del_pesponse = client.files.delete(delete_fid)
-            assert del_pesponse.deleted
-
-    def run_cancel_batch(self, mode):
-        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
-        batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client)
-
-        assert batch_job.status not in ["cancelling", "cancelled"]
-
-        batch_job = client.batches.cancel(batch_id=batch_job.id)
-        assert batch_job.status == "cancelling"
-
-        while batch_job.status not in ["failed", "cancelled"]:
-            batch_job = client.batches.retrieve(batch_job.id)
-            print(
-                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
-            )
-            time.sleep(3)
-
-        assert batch_job.status == "cancelled"
-
-        del_response = client.files.delete(uploaded_file.id)
-        assert del_response.deleted
-
     def test_completion(self):
         for echo in [False, True]:
             for logprobs in [None, 5]:
@@ -467,14 +326,6 @@ class TestOpenAIServer(CustomTestCase):
             for parallel_sample_num in [1, 2]:
                 self.run_chat_completion_stream(logprobs, parallel_sample_num)
 
-    def test_batch(self):
-        for mode in ["completion", "chat"]:
-            self.run_batch(mode)
-
-    def test_cancel_batch(self):
-        for mode in ["completion", "chat"]:
-            self.run_cancel_batch(mode)
-
     def test_regex(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
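
The deleted helpers above exercised the OpenAI-compatible batch lifecycle end to end. Condensed to its essential calls, the flow they covered was (a sketch against an assumed local server address, not part of this commit):

```python
import json
import time

import openai

client = openai.Client(api_key="sk-test", base_url="http://localhost:30000/v1")  # assumed address

# 1. Upload a JSONL file with one {"custom_id", "method", "url", "body"} request per line.
with open("batch_input.jsonl", "rb") as f:
    uploaded = client.files.create(file=f, purpose="batch")

# 2. Create the batch job against the target endpoint.
job = client.batches.create(
    input_file_id=uploaded.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# 3. Poll until a terminal state, then read per-request results from the output file.
while job.status not in ("completed", "failed", "cancelled"):
    time.sleep(3)
    job = client.batches.retrieve(job.id)
raw = client.files.content(job.output_file_id).read().decode("utf-8")
results = [json.loads(line) for line in raw.splitlines() if line.strip()]
```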
@@ -559,6 +410,18 @@ The SmartHome Mini is a compact smart home assistant available in black or white
         assert len(models) == 1
         assert isinstance(getattr(models[0], "max_model_len", None), int)
 
+    def test_retrieve_model(self):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+
+        # Test retrieving an existing model
+        retrieved_model = client.models.retrieve(self.model)
+        self.assertEqual(retrieved_model.id, self.model)
+        self.assertEqual(retrieved_model.root, self.model)
+
+        # Test retrieving a non-existent model
+        with self.assertRaises(openai.NotFoundError):
+            client.models.retrieve("non-existent-model")
+
 
 # -------------------------------------------------------------------------
 # EBNF Test Class: TestOpenAIServerEBNF
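
The new `test_retrieve_model` covers single-model retrieval. Through the SDK this is `client.models.retrieve(...)`, which maps to a GET on `/v1/models/{model}` in the OpenAI wire protocol; a quick usage sketch (server address assumed):

```python
import openai

client = openai.Client(api_key="sk-test", base_url="http://localhost:30000/v1")  # assumed address

model = client.models.retrieve("default")  # GET /v1/models/default
print(model.id, model.root)                # both echo the served model name

try:
    client.models.retrieve("non-existent-model")
except openai.NotFoundError as e:          # server responds with 404
    print("not found:", e.status_code)
```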
@@ -684,6 +547,31 @@ class TestOpenAIEmbedding(CustomTestCase):
         self.assertTrue(len(response.data[0].embedding) > 0)
         self.assertTrue(len(response.data[1].embedding) > 0)
 
+    def test_embedding_single_batch_str(self):
+        """Test embedding with a List[str] of length 1"""
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        response = client.embeddings.create(model=self.model, input=["Hello world"])
+        self.assertEqual(len(response.data), 1)
+        self.assertTrue(len(response.data[0].embedding) > 0)
+
+    def test_embedding_single_int_list(self):
+        """Test embedding with a List[int] or List[List[int]]"""
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        response = client.embeddings.create(
+            model=self.model,
+            input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
+        )
+        self.assertEqual(len(response.data), 1)
+        self.assertTrue(len(response.data[0].embedding) > 0)
+
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        response = client.embeddings.create(
+            model=self.model,
+            input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
+        )
+        self.assertEqual(len(response.data), 1)
+        self.assertTrue(len(response.data[0].embedding) > 0)
+
     def test_empty_string_embedding(self):
         """Test embedding an empty string."""
...
@@ -21,6 +21,7 @@ from transformers import (
 from sglang import Engine
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.conversation import generate_chat_conv
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
 from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
@@ -31,7 +32,6 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalInputs,
 )
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.openai_api.protocol import ChatCompletionRequest
 from sglang.srt.server_args import ServerArgs
...
@@ -15,7 +15,7 @@ from transformers import (
 from sglang import Engine
 from sglang.srt.conversation import generate_chat_conv
-from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
 
 TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
...
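
Downstream projects that must work across sglang versions on both sides of this rename can bridge it with a guarded import. This shim is a hypothetical convenience, not something the commit ships:

```python
try:
    # sglang at or after this commit
    from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
except ImportError:
    # older releases with the previous module layout
    from sglang.srt.openai_api.protocol import ChatCompletionRequest
```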