"projects/Couplets/configs/config.py" did not exist on "fd158e88e82c3fa848017c62a7eccb49a5c64f78"
app.py 1.22 KB
Newer Older
ACzhangchao's avatar
ACzhangchao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

app = FastAPI()

# Load the tokenizer and model once at module import time so every request
# served by this process reuses the same instances.
# NOTE(review): the path is hard-coded; change it to your local model directory.
model_id = "/workspace/jiutian/JIUTIAN-139MoE-chat"  # change to your model path
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# device_map="auto" lets accelerate place the model across available devices;
# bfloat16 halves memory versus float32. trust_remote_code is needed because
# this checkpoint ships custom model/tokenizer code.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16,
                                             trust_remote_code=True)


# Request body schema
class ModelInput(BaseModel):
    """Request body for ``POST /predict/``."""

    # Raw user prompt; `predict` wraps it in a Human/Assistant chat template.
    text: str


@app.post("/predict/")
async def predict(request: Request, model_input: ModelInput):
    """Generate a chat completion for the submitted text.

    Args:
        request: Raw FastAPI request (unused; kept for interface stability).
        model_input: JSON body whose ``text`` field holds the user's prompt.

    Returns:
        ``{"response": <generated text>}`` containing only the newly generated
        tokens — the prompt is no longer echoed back to the client.
    """
    # Wrap the raw text in the Human/Assistant template this model expects.
    prompt = "Human:\n" + model_input.text + "\n\nAssistant:\n"

    # Tokenize; add_special_tokens=False because the template already supplies
    # the structure. (The original also passed padding_side/truncation_side
    # here, but for a single input with padding/truncation disabled they are
    # no-ops, so they are dropped.)
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    # Bug fix: with device_map="auto" the model may live on a GPU while the
    # tokenizer returns CPU tensors — move the inputs to the model's device.
    inputs = inputs.to(model.device)

    # Greedy decoding (do_sample=False), capped at 64 new tokens.
    # NOTE(review): eos_token_id=0 is checkpoint-specific — verify it matches
    # this model's end-of-text token.
    outputs = model.generate(**inputs, max_new_tokens=64, repetition_penalty=1.03, do_sample=False, eos_token_id=0)

    # Bug fix: decode only the generated continuation; the original decoded
    # the full sequence, which echoed the prompt in the response.
    prompt_len = inputs["input_ids"].shape[1]
    response_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    return {"response": response_text}


if __name__ == "__main__":
    # Run a development server when executed directly (rather than under an
    # external ASGI server). Binds to all interfaces on port 8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)