vllm_fastapi.py
# FastAPI service that serves the Marco-o1 model through vLLM.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

app = FastAPI()


class ChatRequest(BaseModel):
    user_input: str  # latest user message; 'q'/'quit' exits, 'c' clears the history
    history: list    # prior turns as {"role": ..., "content": ...} dicts


tokenizer = None
model = None


@app.on_event("startup")
def load_model_and_tokenizer():
    """Load the tokenizer and the vLLM engine once when the server starts."""
    global tokenizer, model
    path = "AIDC-AI/Marco-o1"
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    # tensor_parallel_size=4 shards the model across 4 GPUs; adjust to your hardware.
    model = LLM(model=path, tensor_parallel_size=4)


def generate_response(model, text, max_new_tokens=4096):
    """Generate a reply one token at a time, stopping at the closing '</Output>' tag."""
    new_output = ''
    # Greedy decoding: temperature=0, one token per vLLM call.
    sampling_params = SamplingParams(
        max_tokens=1,
        temperature=0,
        top_p=0.9
    )
    with torch.inference_mode():
        for _ in range(max_new_tokens):
            # Re-run the prompt plus everything generated so far and append the next token.
            outputs = model.generate(
                [f'{text}{new_output}'],
                sampling_params=sampling_params,
                use_tqdm=False
            )
            new_output += outputs[0].outputs[0].text
            # Marco-o1 wraps its final answer in <Output>...</Output>; stop once it closes.
            if new_output.endswith('</Output>'):
                break
    return new_output


@app.post("/chat/")
async def chat(request: ChatRequest):
    if not request.user_input:
        raise HTTPException(status_code=400, detail="Input cannot be empty.")

    # Control commands: 'q'/'quit' ends the session, 'c' clears the history.
    if request.user_input.lower() in ['q', 'quit']:
        return {"response": "Exiting chat."}

    if request.user_input.lower() == 'c':
        request.history.clear()
        return {"response": "Clearing chat history."}

    # Build the prompt from the full conversation and generate the assistant reply.
    request.history.append({"role": "user", "content": request.user_input})
    text = tokenizer.apply_chat_template(request.history, tokenize=False, add_generation_prompt=True)
    response = generate_response(model, text)
    request.history.append({"role": "assistant", "content": response})

    return {"response": response, "history": request.history}