Unverified Commit d1112d85 authored by Rin Intachuen's avatar Rin Intachuen Committed by GitHub
Browse files

Add endpoint for file support, purely to speed up processing of input_embeds. (#2797)

parent 48efec7b
...@@ -19,6 +19,7 @@ This file implements HTTP APIs for the inference engine via fastapi. ...@@ -19,6 +19,7 @@ This file implements HTTP APIs for the inference engine via fastapi.
import asyncio import asyncio
import dataclasses import dataclasses
import json
import logging import logging
import multiprocessing as multiprocessing import multiprocessing as multiprocessing
import os import os
...@@ -259,6 +260,29 @@ async def generate_request(obj: GenerateReqInput, request: Request): ...@@ -259,6 +260,29 @@ async def generate_request(obj: GenerateReqInput, request: Request):
return _create_error_response(e) return _create_error_response(e)
@app.api_route("/generate_from_file", methods=["POST"])
async def generate_from_file_request(file: UploadFile, request: Request):
    """Handle a generate request whose ``input_embeds`` are uploaded as a file.

    This endpoint exists purely to speed up transfer of large embedding
    payloads: the client uploads them as a JSON file instead of inlining
    them in the request body.

    Args:
        file: Uploaded JSON file containing the input embeddings.
        request: The incoming FastAPI request (forwarded for disconnects).

    Returns:
        The first generation result, or an error response on failure.
    """
    content = await file.read()
    try:
        # json.loads accepts bytes directly (UTF-8/16/32 are auto-detected),
        # so no explicit decode step is needed. Guarding here turns a
        # malformed upload into a proper error response instead of a 500.
        input_embeds = json.loads(content)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        logger.error(f"Invalid JSON in uploaded file: {e}")
        return _create_error_response(e)

    obj = GenerateReqInput(
        input_embeds=input_embeds,
        # NOTE(review): sampling params are hard-coded for this endpoint's
        # single use case — confirm before reusing it more broadly.
        sampling_params={
            "repetition_penalty": 1.2,
            "temperature": 0.2,
            "max_new_tokens": 512,
        },
    )
    try:
        ret = await _global_state.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        logger.error(f"Error: {e}")
        return _create_error_response(e)
@app.api_route("/encode", methods=["POST", "PUT"]) @app.api_route("/encode", methods=["POST", "PUT"])
async def encode_request(obj: EmbeddingReqInput, request: Request): async def encode_request(obj: EmbeddingReqInput, request: Request):
"""Handle an embedding request.""" """Handle an embedding request."""
......
import json import json
import os
import tempfile
import unittest import unittest
import requests import requests
...@@ -38,7 +40,7 @@ class TestInputEmbeds(unittest.TestCase): ...@@ -38,7 +40,7 @@ class TestInputEmbeds(unittest.TestCase):
return embeddings.squeeze().tolist() # Convert tensor to a list for API use return embeddings.squeeze().tolist() # Convert tensor to a list for API use
def send_request(self, payload): def send_request(self, payload):
"""Send a POST request to the API and return the response.""" """Send a POST request to the /generate endpoint and return the response."""
response = requests.post( response = requests.post(
self.base_url + "/generate", self.base_url + "/generate",
json=payload, json=payload,
...@@ -50,8 +52,22 @@ class TestInputEmbeds(unittest.TestCase): ...@@ -50,8 +52,22 @@ class TestInputEmbeds(unittest.TestCase):
"error": f"Request failed with status {response.status_code}: {response.text}" "error": f"Request failed with status {response.status_code}: {response.text}"
} }
def send_file_request(self, file_path):
    """POST a file to the /generate_from_file endpoint.

    Returns the decoded JSON body on HTTP 200; otherwise a dict with an
    ``error`` key describing the failure.
    """
    url = self.base_url + "/generate_from_file"
    with open(file_path, "rb") as handle:
        response = requests.post(
            url,
            files={"file": handle},
            timeout=30,  # bound how long we wait on the API
        )
    if response.status_code != 200:
        return {
            "error": f"Request failed with status {response.status_code}: {response.text}"
        }
    return response.json()
def test_text_based_response(self): def test_text_based_response(self):
"""Print API response using text-based input.""" """Test and print API responses using text-based input."""
for text in self.texts: for text in self.texts:
payload = { payload = {
"model": self.model, "model": self.model,
...@@ -64,7 +80,7 @@ class TestInputEmbeds(unittest.TestCase): ...@@ -64,7 +80,7 @@ class TestInputEmbeds(unittest.TestCase):
) )
def test_embedding_based_response(self): def test_embedding_based_response(self):
"""Print API response using input embeddings.""" """Test and print API responses using input embeddings."""
for text in self.texts: for text in self.texts:
embeddings = self.generate_input_embeddings(text) embeddings = self.generate_input_embeddings(text)
payload = { payload = {
...@@ -78,7 +94,7 @@ class TestInputEmbeds(unittest.TestCase): ...@@ -78,7 +94,7 @@ class TestInputEmbeds(unittest.TestCase):
) )
def test_compare_text_vs_embedding(self): def test_compare_text_vs_embedding(self):
"""Print responses for both text-based and embedding-based inputs.""" """Test and compare responses for text-based and embedding-based inputs."""
for text in self.texts: for text in self.texts:
# Text-based payload # Text-based payload
text_payload = { text_payload = {
...@@ -106,6 +122,25 @@ class TestInputEmbeds(unittest.TestCase): ...@@ -106,6 +122,25 @@ class TestInputEmbeds(unittest.TestCase):
# This is flaky, so we skip this temporarily # This is flaky, so we skip this temporarily
# self.assertEqual(text_response["text"], embed_response["text"]) # self.assertEqual(text_response["text"], embed_response["text"])
def test_generate_from_file(self):
    """Exercise /generate_from_file by uploading embeddings as a JSON file."""
    for text in self.texts:
        embeddings = self.generate_input_embeddings(text)
        # delete=False so the file survives the `with` block and can be
        # re-opened by the request helper; removed explicitly below.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as handle:
            json.dump(embeddings, handle)
            tmp_path = handle.name
        try:
            result = self.send_file_request(tmp_path)
            print(
                f"Text Input: {text}\nResponse from /generate_from_file: {json.dumps(result, indent=2)}\n{'-' * 80}"
            )
        finally:
            # Always clean up the temporary file.
            os.remove(tmp_path)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_process_tree(cls.process.pid) kill_process_tree(cls.process.pid)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment