Commit 51b1aade authored by Pan Zezhong

add server script

parent fabfa2e2
@@ -5,6 +5,7 @@ import safetensors
import sys
import time
import json
import asyncio
from libinfinicore_infer import (
JiugeMeta,
@@ -358,9 +359,122 @@ class JiugeForCauslLM:
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
def create_kv_cache(self):
return create_kv_cache(self.model_instance)
def drop_kv_cache(self, kv_cache):
drop_kv_cache(self.model_instance, kv_cache)
def chat(self, request, kv_cache):
messages = request.get("messages", [])
temperature = request.get("temperature", 1.0)
topk = request.get("top_k", 1)
topp = request.get("top_p", 1.0)
max_tokens = request.get("max_tokens", 512)
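# Build the prompt string from the message list with the tokenizer's chat template.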
input_content = self.tokenizer.apply_chat_template(
conversation=messages,
add_generation_prompt=True,
tokenize=False,
)
tokens = self.tokenizer.encode(input_content)
ntok = len(tokens)
nreq = 1
output_content = ""
tokens = (c_uint * ntok)(*tokens)
req_lens = (c_uint * nreq)(*[ntok])
req_pos = (c_uint * nreq)(*[0])
kv_caches = (POINTER(KVCache) * nreq)(*[kv_cache])
ans = (c_uint * nreq)()
steps = 0
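# Generation loop: the first infer_batch call prefills the whole prompt; later iterations decode one token per step.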
for step_i in range(max_tokens):
infer_batch(
self.model_instance,
tokens,
ntok,
req_lens,
nreq,
req_pos,
kv_caches,
ans,
temperature,
topk,
topp,
)
steps += 1
output_tokens = list(ans)
output_str = (
self.tokenizer._tokenizer.id_to_token(output_tokens[0])
.replace("▁", " ")
.replace("<0x0A>", "\n")
)
output_content += output_str
if output_tokens[0] in self.eos_token_id:
break
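# Advance the cache position past the tokens just processed, then feed back only the sampled token.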
req_pos[0] = req_pos[0] + ntok
ntok = 1
tokens = (c_uint * ntok)(*output_tokens)
req_lens = (c_uint * nreq)(*[ntok])
return output_content
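# Streaming variant of chat(): an async generator that yields each decoded token as soon as it is sampled.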
async def chat_stream_async(self, request, kv_cache):
messages = request.get("messages", [])
temperature = request.get("temperature", 1.0)
topk = request.get("top_k", 1)
topp = request.get("top_p", 1.0)
max_tokens = request.get("max_tokens", 512)
input_content = self.tokenizer.apply_chat_template(
conversation=messages,
add_generation_prompt=True,
tokenize=False,
)
tokens = self.tokenizer.encode(input_content)
ntok = len(tokens)
nreq = 1
tokens = (c_uint * ntok)(*tokens)
req_lens = (c_uint * nreq)(*[ntok])
req_pos = (c_uint * nreq)(*[0])
kv_caches = (POINTER(KVCache) * nreq)(*[kv_cache])
ans = (c_uint * nreq)()
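# Same prefill-then-decode loop as chat(), but tokens are yielded incrementally.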
for step_i in range(max_tokens):
infer_batch(
self.model_instance,
tokens,
ntok,
req_lens,
nreq,
req_pos,
kv_caches,
ans,
temperature,
topk,
topp,
)
output_tokens = list(ans)
output_str = (
self.tokenizer._tokenizer.id_to_token(output_tokens[0])
.replace("▁", " ")
.replace("<0x0A>", "\n")
)
yield output_str # Yield each token as it's produced
await asyncio.sleep(0) # Let event loop breathe
if output_tokens[0] in self.eos_token_id:
break
req_pos[0] += ntok
ntok = 1
tokens = (c_uint * ntok)(*output_tokens)
req_lens = (c_uint * nreq)(*[ntok])
def generate(self, input_content, max_steps, topp=1.0, topk=1, temperature=1.0):
kv_cache = create_kv_cache(self.model_instance)
@@ -433,7 +547,7 @@ class JiugeForCauslLM:
def test():
if len(sys.argv) < 3:
print(
"Usage: python jiuge.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
)
sys.exit(1)
model_path = sys.argv[2]
@@ -452,7 +566,7 @@ def test():
device_type = DeviceType.DEVICE_TYPE_MOORE
else:
print(
"Usage: python jiuge.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
)
sys.exit(1)
...

launch_server.py
from jiuge import JiugeForCauslLM
from libinfinicore_infer import DeviceType
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, JSONResponse
import asyncio
import uvicorn
import time
import uuid
import sys
import signal
import json
if len(sys.argv) < 3:
print(
"Usage: python launch_server.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
)
sys.exit(1)
model_path = sys.argv[2]
device_type = DeviceType.DEVICE_TYPE_CPU
if sys.argv[1] == "--cpu":
device_type = DeviceType.DEVICE_TYPE_CPU
elif sys.argv[1] == "--nvidia":
device_type = DeviceType.DEVICE_TYPE_NVIDIA
elif sys.argv[1] == "--cambricon":
device_type = DeviceType.DEVICE_TYPE_CAMBRICON
elif sys.argv[1] == "--ascend":
device_type = DeviceType.DEVICE_TYPE_ASCEND
elif sys.argv[1] == "--metax":
device_type = DeviceType.DEVICE_TYPE_METAX
elif sys.argv[1] == "--moore":
device_type = DeviceType.DEVICE_TYPE_MOORE
else:
print(
"Usage: python launch_server.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
)
sys.exit(1)
ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
model = JiugeForCauslLM(model_path, device_type, ndev)
kv_cache = model.create_kv_cache()
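# A single KV cache shared by every request: each call re-prefills from position 0, so this simple server handles one conversation at a time.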
def signal_handler(sig, frame):
print(f"Received signal {sig}, cleaning up...")
model.drop_kv_cache(kv_cache)
model.destroy_model_instance()
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle docker stop / system shutdown
app = FastAPI()
async def chat_stream(id_, request_data, request: Request):
try:
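# The opening chunk announces the assistant role with empty content, mirroring the OpenAI streaming format.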
chunk = json.dumps(
{
"id": id_,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": "jiuge",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": ""},
"logprobs": None,
"finish_reason": None,
}
],
},
ensure_ascii=False,
)
yield f"{chunk}\n\n"
async for token in model.chat_stream_async(request_data, kv_cache):
if await request.is_disconnected():
print("Client disconnected. Aborting stream.")
break
chunk = json.dumps(
{
"id": id_,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": "jiuge",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"delta": {"content": token},
"logprobs": None,
"finish_reason": None,
}
],
},
ensure_ascii=False,
)
yield f"{chunk}\n\n"
finally:
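# Always emit a closing chunk with finish_reason "stop", even if the client disconnected mid-stream.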
chunk = json.dumps(
{
"id": id_,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": "jiuge",
"system_fingerprint": None,
"choices": [
{"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
],
},
ensure_ascii=False,
)
yield f"{chunk}\n\n"
def chat(id_, request_data):
output_text = model.chat(
request_data,
kv_cache,
)
response = {
"id": id_,
"object": "chat.completion",
"created": int(time.time()),
"model": "jiuge",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": output_text.strip()},
"finish_reason": "stop",
}
],
}
return JSONResponse(response)
@app.post("/jiuge/chat/completions")
async def chat_completions(request: Request):
data = await request.json()
if not data.get("messages"):
return JSONResponse(content={"error": "No message provided"}, status_code=400)
stream = data.get("stream", False)
id_ = f"cmpl-{uuid.uuid4().hex}"
if stream:
return StreamingResponse(
chat_stream(id_, data, request), media_type="text/event-stream"
)
else:
return chat(id_, data)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
"""
curl -N -H "Content-Type: application/json" \
-X POST http://127.0.0.1:8000/jiuge/chat/completions \
-d '{
"model": "jiuge",
"messages": [
{"role": "user", "content": "山东最高的山是?"}
],
"temperature": 1.0,
"top_k": 50,
"top_p": 0.8,
"max_tokens": 512,
"stream": true
}'
"""