import asyncio
from jiuge import JiugeForCauslLM
from libinfinicore_infer import DeviceType
from infer_task import InferTask
from kvcache_pool import KVCachePool

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
import anyio
import uvicorn
import time
import uuid
import sys
import signal
import json

# CLI flag -> backend device type (replaces a long elif chain and keeps the
# supported-flags list in one place).
_DEVICE_FLAGS = {
    "--cpu": DeviceType.DEVICE_TYPE_CPU,
    "--nvidia": DeviceType.DEVICE_TYPE_NVIDIA,
    "--cambricon": DeviceType.DEVICE_TYPE_CAMBRICON,
    "--ascend": DeviceType.DEVICE_TYPE_ASCEND,
    "--metax": DeviceType.DEVICE_TYPE_METAX,
    "--moore": DeviceType.DEVICE_TYPE_MOORE,
}

# Single copy of the usage message (was duplicated verbatim in two branches).
_USAGE = (
    "Usage: python launch_server.py "
    "[--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] "
    "<path/to/model_dir> [n_device]"
)

if len(sys.argv) < 3 or sys.argv[1] not in _DEVICE_FLAGS:
    print(_USAGE)
    sys.exit(1)

device_type = _DEVICE_FLAGS[sys.argv[1]]
model_path = sys.argv[2]
try:
    # Optional third argument: number of devices (defaults to 1).
    ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
except ValueError:
    # Bug fix: a non-integer n_device used to crash with an uncaught
    # ValueError traceback instead of showing usage.
    print(_USAGE)
    sys.exit(1)

# Instantiate the model once at module import time; it is shared by every
# request handler below.
MODEL = JiugeForCauslLM(model_path, device_type, ndev)

App = FastAPI()
@App.on_event("startup")
async def setup():
    """Create the shared KV cache pool when the server starts.

    Pool size is 1 — presumably this serializes concurrent requests on the
    single cache slot; confirm against KVCachePool semantics.
    NOTE(review): ``@App.on_event`` is deprecated in recent FastAPI in favor
    of lifespan handlers.
    """
    App.state.kv_cache_pool = KVCachePool(MODEL, 1)

async def handle_shutdown():
    """Release the KV cache pool and the model, then terminate the process.

    NOTE(review): raising SystemExit from inside an event-loop task relies on
    the loop / uvicorn propagating it to actually stop the process — verify.
    """
    await App.state.kv_cache_pool.finalize()
    MODEL.destroy_model_instance()
    sys.exit(0)

def signal_handler(sig, frame):
    """Handle SIGINT/SIGTERM by scheduling async cleanup on the event loop.

    Bug fix: the original called ``asyncio.create_task`` unconditionally,
    which raises RuntimeError when no event loop is running in this thread
    (signal delivered before the server started or after it stopped); the
    uncaught exception left the process alive and uncleaned. We now fall back
    to exiting directly in that case.
    """
    print(f"Received signal {sig}, cleaning up...")
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running loop: nothing async to clean up through; just exit.
        sys.exit(0)
    loop.create_task(handle_shutdown())


signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler)  # Handle docker stop / system shutdown


def chunk_json(id_, content=None, role=None, finish_reason=None):
    """Build one OpenAI-style ``chat.completion.chunk`` payload.

    Args:
        id_: Completion id echoed in every chunk.
        content: Delta text; included whenever it is not None. An empty
            string is a legitimate delta (the role-announcement chunk) and
            must not be dropped.
        role: Role for the first chunk (e.g. "assistant"); included when
            not None.
        finish_reason: e.g. "stop" on the terminating chunk, else None.

    Returns:
        Dict matching the OpenAI streaming-chunk schema.
    """
    delta = {}
    # Bug fix: the original tested truthiness (`if content:`), which silently
    # dropped the empty-string content passed for the first streamed chunk.
    if content is not None:
        delta["content"] = content
    if role is not None:
        delta["role"] = role
    return {
        "id": id_,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": "jiuge",
        "system_fingerprint": None,
        "choices": [
            {
                "index": 0,
                "delta": delta,
                "logprobs": None,
                "finish_reason": finish_reason,
            }
        ],
    }

async def chat_stream(id_, request_data, request: Request):
    """Async generator yielding OpenAI-style streaming chunks for one request.

    Acquires a KV cache slot for the task, emits a role-announcement chunk,
    then one chunk per generated token; the slot is always released and a
    final ``finish_reason="stop"`` chunk is emitted in ``finally``.
    """
    # Bug fix: if InferTask() or acquire() raises, `infer_task` used to be
    # unbound in the `finally` block, so release() raised NameError and
    # masked the original error. Initialize it and guard the release.
    infer_task = None
    try:
        infer_task = InferTask(id_, MODEL.tokenizer, request_data)
        await App.state.kv_cache_pool.acquire(infer_task)

        # First chunk announces the assistant role with empty content.
        chunk = json.dumps(
            chunk_json(id_, content="", role="assistant"),
            ensure_ascii=False,
        )
        yield f"{chunk}\n\n"

        async for token in MODEL.chat_stream_async(
            infer_task.request,
            infer_task.kvcache(),
        ):
            # Stop generating for clients that went away.
            if await request.is_disconnected():
                print("Client disconnected. Aborting stream.")
                break
            chunk = json.dumps(
                chunk_json(id_, content=token),
                ensure_ascii=False,
            )
            yield f"{chunk}\n\n"
    finally:
        if infer_task is not None:
            await App.state.kv_cache_pool.release(infer_task)
        # NOTE(review): yielding inside `finally` fails if the generator is
        # being closed (GeneratorExit); this relies on the disconnect check
        # above breaking out first — verify under abrupt client disconnects.
        chunk = json.dumps(
            chunk_json(id_, finish_reason="stop"),
            ensure_ascii=False,
        )
        yield f"{chunk}\n\n"


async def chat(id_, request_data):
    """Run one non-streaming chat completion and return it as JSON.

    Acquires a KV cache slot, generates the full reply, and returns a single
    OpenAI-style completion payload.
    """
    infer_task = InferTask(id_, MODEL.tokenizer, request_data)
    await App.state.kv_cache_pool.acquire(infer_task)
    try:
        # NOTE(review): MODEL.chat is a synchronous call and blocks the event
        # loop for the whole generation; consider offloading to a worker
        # thread (the file already imports anyio) — verify thread safety.
        output_text = MODEL.chat(
            infer_task.request,
            infer_task.kvcache(),
        )
        response = chunk_json(
            id_, content=output_text.strip(), role="assistant", finish_reason="stop"
        )
    finally:
        # Bug fix: the original released the KV cache only on success, so an
        # exception in MODEL.chat leaked the pool slot permanently.
        await App.state.kv_cache_pool.release(infer_task)
    return JSONResponse(response)


@App.post("/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible ``/chat/completions`` endpoint.

    Dispatches to the streaming generator when ``stream`` is true, otherwise
    to the one-shot ``chat`` coroutine. Rejects requests without messages.
    """
    data = await request.json()

    if not data.get("messages"):
        return JSONResponse(content={"error": "No message provided"}, status_code=400)

    stream = data.get("stream", False)
    id_ = f"cmpl-{uuid.uuid4().hex}"
    if stream:
        return StreamingResponse(
            chat_stream(id_, data, request), media_type="text/event-stream"
        )
    # Bug fix: `chat` is a coroutine function; the original returned the
    # un-awaited coroutine object, which FastAPI cannot serialize as a
    # response. It must be awaited here.
    return await chat(id_, data)


if __name__ == "__main__":
    # Serve on all interfaces, port 8000.
    uvicorn.run(App, host="0.0.0.0", port=8000)

"""
Example request:

curl -N -H "Content-Type: application/json" \
     -X POST http://127.0.0.1:8000/chat/completions \
     -d '{
       "model": "jiuge",
       "messages": [
         {"role": "user", "content": "山东最高的山是?"}
       ],
       "temperature": 1.0,
       "top_k": 50,
       "top_p": 0.8,
       "max_tokens": 512,
       "stream": true
     }'
"""