# coding=utf-8
# Implements API for ChatGLM3-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8100/docs for documents.

import time
import json
import torch
import uvicorn
import argparse
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
# from transformers import AutoTokenizer, AutoModel
from sse_starlette.sse import ServerSentEvent, EventSourceResponse
from fastllm_pytools import llm


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Release cached GPU memory when the application shuts down.
    yield
    global device_map
    if torch.cuda.is_available():
        for device in device_map:
            with torch.cuda.device(device):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class Usage(BaseModel):
    prompt_tokens: Optional[int] = None
    total_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    usage: Optional[Usage] = None


@app.get("/v1/models", response_model=ModelList)
def list_models():
    # Report the model names registered at startup.
    global model_list
    return ModelList(data=[ModelCard(id=model) for model in model_list])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
def create_chat_completion(request: ChatCompletionRequest):
    if request.model not in model_list:
        raise HTTPException(status_code=400, detail="Invalid Model Name")
    global model
    id = "chatcmpl-A"

    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    # Fall back to defaults for any sampling parameter the request omits.
    max_length = request.max_length if request.max_length is not None else 1024
    temperature = request.temperature if request.temperature is not None else 0.1
    top_p = request.top_p if request.top_p is not None else 0.8

    prev_messages = request.messages[:-1]
    # Prepend the system prompt (if any) to the current query.
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query

    # Rebuild the conversation history as [user, assistant] pairs.
    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
                history.append([prev_messages[i].content, prev_messages[i + 1].content])

    if request.stream:
        generate = predict(id=id, query=query, history=history, max_length=max_length,
                           top_p=top_p, temperature=temperature, model_id=request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response = model.response(query=query, history=history, max_length=max_length,
                              top_p=top_p, temperature=temperature)
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )
    prompt_tokens = len(model.tokenizer_encode_string(query))
    completion_tokens = len(model.tokenizer_encode_string(response))
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )

    return ChatCompletionResponse(id=id, model=request.model, choices=[choice_data],
                                  object="chat.completion", usage=usage)


def predict(id: str, query: str, history: List[List[str]], model_id: str,
            max_length: int, top_p: float, temperature: float):
    global model
    create_time = int(time.time())

    # First chunk only announces the assistant role.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(id=id, created=create_time, model=model_id,
                                   choices=[choice_data], object="chat.completion.chunk")
    # Newer pydantic versions no longer accept dumps_kwargs such as ensure_ascii in .json()
    # (see https://github.com/THUDM/ChatGLM2-6B/issues/308), so serialize via model_dump + json.dumps.
    yield json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)

    # Forward incremental completions as they are generated.
    for new_response in model.stream_response(query=query, history=history, max_length=max_length,
                                              top_p=top_p, temperature=temperature):
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_response),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(id=id, created=create_time, model=model_id,
                                       choices=[choice_data], object="chat.completion.chunk")
        yield json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)

    # Final chunk marks the end of generation.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(id=id, created=create_time, model=model_id,
                                   choices=[choice_data], object="chat.completion.chunk")
    yield json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)
    yield '[DONE]'


def args_parser():
    parser = argparse.ArgumentParser(description='chatglm3_chat_demo')
    parser.add_argument('-p', '--path', type=str, default="/model",
                        help='path to the model files')
    parser.add_argument('-g', '--gpus', type=str, default="0",
                        help='GPUs to run on, e.g. "0,1"')
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = args_parser()
    model_list = ["chatglm3-6b-fastllm"]
    device_map = ["cuda:" + num for num in args.gpus.split(',')]
    llm.set_device_map(device_map)
    model = llm.model(args.path)
    uvicorn.run(app, host='127.0.0.1', port=8100)
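

# Example client call (a minimal sketch, assuming the server is running with the
# defaults above -- host 127.0.0.1, port 8100, model "chatglm3-6b-fastllm" -- and
# that the third-party "requests" package is installed):
#
#   import requests
#
#   payload = {
#       "model": "chatglm3-6b-fastllm",
#       "messages": [{"role": "user", "content": "Hello"}],
#   }
#   r = requests.post("http://127.0.0.1:8100/v1/chat/completions", json=payload)
#   print(r.json()["choices"][0]["message"]["content"])
#
#   # Streaming: set "stream": true and read the Server-Sent Events line by line;
#   # each "data:" line carries a chat.completion.chunk, and the stream ends with
#   # "data: [DONE]".
#   payload["stream"] = True
#   with requests.post("http://127.0.0.1:8100/v1/chat/completions",
#                      json=payload, stream=True) as r:
#       for line in r.iter_lines():
#           if line:
#               print(line.decode("utf-8"))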