fastapi_engine_inference.py 6.19 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
FastAPI server example for text generation using SGLang Engine and demonstrating client usage.

Starts the server, sends requests to it, and prints responses.

Usage:
python fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1 --host 127.0.0.1 --port 8000
"""

import os
import subprocess
import sys
import time
from contextlib import asynccontextmanager

import requests
from fastapi import FastAPI, HTTPException, Request

import sglang as sgl
from sglang.utils import terminate_process

# Global engine handle; populated by the lifespan hook when the server starts.
engine = None


# Use FastAPI's lifespan manager to initialize/shutdown the engine
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the SGLang engine lifecycle around the server's lifetime.

    Reads MODEL_PATH and TP_SIZE from the environment (set by the parent
    process in __main__) and initializes the engine on startup; yields while
    the server runs, then attempts cleanup on shutdown.
    """
    global engine
    print("Loading SGLang engine...")
    # Fall back to the script's CLI defaults so the app is also runnable
    # directly under uvicorn; int(os.getenv(...)) with no default would
    # raise TypeError when TP_SIZE is unset.
    model_path = os.getenv("MODEL_PATH", "Qwen/Qwen2.5-0.5B-Instruct")
    tp_size = int(os.getenv("TP_SIZE", "1"))
    engine = sgl.Engine(model_path=model_path, tp_size=tp_size)
    print("SGLang engine loaded.")
    yield
    # Clean up engine resources when the server stops (optional, depends on engine needs)
    print("Shutting down SGLang engine...")
    # NOTE(review): shutdown() availability depends on the sglang version —
    # call it only if present so older versions keep working.
    shutdown = getattr(engine, "shutdown", None)
    if callable(shutdown):
        shutdown()
    engine = None
    print("SGLang engine shutdown.")


# Module-level app object so uvicorn can import "fastapi_engine_inference:app".
app = FastAPI(lifespan=lifespan)


@app.post("/generate")
async def generate_text(request: Request):
    """FastAPI endpoint to handle text generation requests.

    Expects a JSON body with a required "prompt" and optional
    "max_new_tokens" (default 128) and "temperature" (default 0.7).
    Returns ``{"generated_text": ...}`` on success.

    Raises:
        HTTPException: 503 if the engine is not ready, 400 if the prompt
            is missing, 500 if generation fails.
    """
    global engine
    if not engine:
        # Returning a (dict, int) tuple does NOT set the status code in
        # FastAPI — it would be serialized as a JSON array with HTTP 200.
        # Raise HTTPException so clients see a real 503.
        raise HTTPException(status_code=503, detail="Engine not initialized")

    data = await request.json()
    prompt = data.get("prompt")
    max_new_tokens = data.get("max_new_tokens", 128)
    temperature = data.get("temperature", 0.7)

    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    try:
        # Use async_generate so the event loop is not blocked during decoding.
        state = await engine.async_generate(
            prompt,
            sampling_params={
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
            },
            # Add other parameters like stop, top_p etc. as needed
        )
    except Exception as e:
        # Surface engine failures as a proper HTTP 500 instead of a 200 body.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"generated_text": state["text"]}


# Helper function to start the server
def start_server(args, timeout=60):
    """Start the Uvicorn server as a subprocess and wait for it to be ready.

    Polls FastAPI's auto-generated /docs endpoint until it answers 200 or
    the timeout elapses.

    Args:
        args: Namespace with ``host`` and ``port`` attributes.
        timeout: Maximum seconds to wait for the server to come up.

    Returns:
        The running ``subprocess.Popen`` handle.

    Raises:
        TimeoutError: If the server does not become ready in time (the
            child process is terminated first).
    """
    base_url = f"http://{args.host}:{args.port}"
    command = [
        # Use the same interpreter (and venv) that is running this script —
        # a bare "python" on PATH may lack sglang/uvicorn entirely.
        sys.executable,
        "-m",
        "uvicorn",
        "fastapi_engine_inference:app",
        f"--host={args.host}",
        f"--port={args.port}",
    ]

    # Inherit stdout/stderr so engine startup logs stay visible.
    process = subprocess.Popen(command, stdout=None, stderr=None)

    start_time = time.time()
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            try:
                # /docs is served by FastAPI by default, so it doubles as a
                # health check; cap each probe with its own request timeout.
                response = session.get(f"{base_url}/docs", timeout=5)
                if response.status_code == 200:
                    print(f"Server {base_url} is ready (responded on /docs)")
                    return process
            except requests.ConnectionError:
                # Server is not accepting connections yet; keep polling.
                pass
            except requests.Timeout:
                print(f"Health check to {base_url}/docs timed out, retrying...")
            except requests.RequestException as e:
                print(f"Health check request error: {e}, retrying...")
            # Short sleep keeps startup detection responsive.
            time.sleep(1)

    # Startup failed: clean up the child process before reporting the error.
    print("Server failed to start within timeout, attempting to terminate process...")
    terminate_process(process)
    raise TimeoutError(
        f"Server failed to start at {base_url} within the timeout period."
    )


def send_requests(server_url, prompts, max_new_tokens, temperature):
    """Send a /generate request to the server for each prompt and print results.

    Args:
        server_url: Base URL of the running server, e.g. "http://127.0.0.1:8000".
        prompts: List of prompt strings to send, one request each.
        max_new_tokens: Generation length cap forwarded to the server.
        temperature: Sampling temperature forwarded to the server.
    """
    for i, prompt in enumerate(prompts):
        print(f"\n[{i+1}/{len(prompts)}] Sending prompt: '{prompt}'")
        payload = {
            "prompt": prompt,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }

        try:
            response = requests.post(f"{server_url}/generate", json=payload, timeout=60)
            result = response.json()
            # The endpoint may return an error payload instead of generated
            # text; report it rather than crashing on a missing key.
            if "generated_text" in result:
                print(f"Prompt: {prompt}\nResponse: {result['generated_text']}")
            else:
                print(f"  Error response for prompt '{prompt}': {result}")
        except ValueError:
            # response.json() raises ValueError on non-JSON bodies.
            print(
                f"  Error: Non-JSON response for prompt '{prompt}' "
                f"(HTTP {response.status_code})"
            )
        except requests.exceptions.Timeout:
            print(f"  Error: Request timed out for prompt '{prompt}'")
        except requests.exceptions.RequestException as e:
            print(f"  Error sending request for prompt '{prompt}': {e}")


if __name__ == "__main__":
    # Main entry point: parse CLI args, launch the server subprocess, run the
    # demo requests, and always tear the server down afterwards.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-0.5B-Instruct")
    parser.add_argument("--tp_size", type=int, default=1)
    args = parser.parse_args()

    # Pass the model to the child uvicorn process via an env var
    os.environ["MODEL_PATH"] = args.model_path
    os.environ["TP_SIZE"] = str(args.tp_size)

    # Start the server
    process = start_server(args)

    # Define the prompts and sampling parameters
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    max_new_tokens = 64
    temperature = 0.1

    # Define server url
    server_url = f"http://{args.host}:{args.port}"

    try:
        # Send requests to the server
        send_requests(server_url, prompts, max_new_tokens, temperature)
    finally:
        # Terminate the uvicorn subprocess even if the requests above raise,
        # so a failed demo run does not leave an orphaned server behind.
        terminate_process(process)