server.py 7.05 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
2
3
import argparse
import os
from threading import Lock
4
from typing import Generator, List, Optional
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
5
6
7

import torch
import uvicorn
8
9
from coati.quant import llama_load_quant, low_resource_init
from fastapi import FastAPI, Request
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
10
11
12
13
14
15
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
16
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
17
from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
18

19
# Instruction context string handed to ChatPromptProcessor at startup.
CONTEXT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request. "
    "Do not generate new instructions."
)
# Length limit handed to ChatPromptProcessor together with CONTEXT
# (presumably the prompt token budget — confirm in ChatPromptProcessor).
MAX_LEN: int = 512

# Guards model access: held around model.generate in the /generate handler and
# given to LockedIterator for the streaming path.
running_lock = Lock()


class GenerationTaskReq(BaseModel):
    """Request body accepted by both ``/generate`` and ``/generate/stream``.

    Sampling fields are optional and default to ``None``; all bounds are
    validated by pydantic when the request body is parsed.
    """

    # Generation budget: 1..512 new tokens.
    max_new_tokens: int = Field(example=64, gt=0, le=512)
    # Conversation so far; must contain at least one turn.
    history: List[Dialogue] = Field(min_items=1)
    top_k: Optional[int] = Field(None, example=50, gt=0)
    # NOTE(review): the upper bounds on top_p and temperature are exclusive
    # (lt=1.0), so a value of exactly 1.0 is rejected — confirm intended.
    top_p: Optional[float] = Field(None, example=0.5, gt=0.0, lt=1.0)
    temperature: Optional[float] = Field(None, example=0.7, gt=0.0, lt=1.0)
    # Only forwarded by the non-streaming endpoint; streaming generation does
    # not support repetition_penalty yet.
    repetition_penalty: Optional[float] = Field(None, example=1.2, gt=1.0)


limiter = Limiter(key_func=get_remote_address)
app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)


def _parse_cors_origins(spec: Optional[str]) -> List[str]:
    """Turn a comma-separated ``CORS_ORIGIN`` value into a list of allowed origins.

    Args:
        spec: Raw value of the ``CORS_ORIGIN`` environment variable, or ``None``
            when it is unset.

    Returns:
        ``["*"]`` (allow every origin) when ``spec`` is ``None``; otherwise the
        comma-separated entries with surrounding whitespace stripped and empty
        entries dropped, so ``"https://a.example, https://b.example"`` yields
        two origins that can actually match a request's Origin header.
    """
    if spec is None:
        # allow CORS from all origins
        return ["*"]
    # allow CORS from the specified origins
    return [origin.strip() for origin in spec.split(",") if origin.strip()]


# set CORS
origin_spec_from_env = os.environ.get("CORS_ORIGIN", None)
origins = _parse_cors_origins(origin_spec_from_env)

# NOTE(review): the default ["*"] combined with allow_credentials=True may let any
# site make credentialed requests if the middleware reflects the request origin —
# confirm this is intended for deployments without CORS_ORIGIN set.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
    """Generate a completion for ``prompt`` and yield it incrementally as text.

    Args:
        prompt: Fully formatted prompt string.
        max_new_tokens: Forwarded to ``sample_streamingly`` as ``max_generate_tokens``.
        top_k: Sampling parameter, forwarded unchanged (may be ``None``).
        top_p: Sampling parameter, forwarded unchanged (may be ``None``).
        temperature: Sampling parameter, forwarded unchanged (may be ``None``).

    Yields:
        Decoded text fragments, one per sampling step that produced at least one
        non-special token. The first fragment has leading whitespace stripped;
        later fragments whose first token starts with the SentencePiece
        word-boundary marker ``"▁"`` get an explicit leading space.

    Sampling steps are pulled through ``LockedIterator`` with ``running_lock``
    (presumably serializing them with other model users — see utils).
    """
    inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    # TODO(ver217): streaming generation does not support repetition_penalty now
    model_kwargs = {
        "max_generate_tokens": max_new_tokens,
        "early_stopping": True,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "prepare_inputs_fn": model.prepare_inputs_for_generation,
        "update_model_kwargs_fn": update_model_kwargs_fn,
    }
    # The special-token list does not change during generation; build the lookup
    # set once instead of re-reading the tokenizer property for every token.
    special_tokens = set(tokenizer.all_special_tokens)
    is_first_word = True
    generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
    for output in generator:
        output = output.cpu()
        tokens = tokenizer.convert_ids_to_tokens(output, skip_special_tokens=True)
        # Defensive second filter on top of skip_special_tokens.
        current_sub_tokens = [token for token in tokens if token not in special_tokens]
        if not current_sub_tokens:
            continue
        out_string = tokenizer.sp_model.decode(current_sub_tokens)
        if is_first_word:
            out_string = out_string.lstrip()
            is_first_word = False
        elif current_sub_tokens[0].startswith("▁"):
            # This fragment starts a new word; decoding it alone loses the separating
            # space, so add it back explicitly (whitespace will be ignored by the frontend).
            out_string = " " + out_string
        yield out_string


async def event_generator(request: Request, generator: Generator):
    """Adapt a blocking text iterator into server-sent-event dicts.

    Emits ``{"event": "generate", "data": chunk}`` for each item produced by
    ``generator`` and a final ``{"event": "end", "data": ""}`` once it is
    exhausted. Stops early, without the ``end`` event, when the client has
    disconnected.

    Each ``next()`` call runs in the event loop's default executor. For the
    iterator built by ``generate_streamingly`` a single step runs the model and
    may wait on ``running_lock``. Doing that directly inside this coroutine would
    block the event loop, and with it every other request, for the duration.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    # Use a sentinel instead of catching StopIteration: StopIteration cannot be
    # propagated through an asyncio Future.
    exhausted = object()
    while True:
        if await request.is_disconnected():
            break
        chunk = await loop.run_in_executor(None, next, generator, exhausted)
        if chunk is exhausted:
            yield {"event": "end", "data": ""}
            break
        yield {"event": "generate", "data": chunk}


@app.post("/generate/stream")
@limiter.limit("1/second")
def generate(data: GenerationTaskReq, request: Request):
    """Stream a chat completion to the client as server-sent events.

    The prompt is built from ``data.history`` the same way as for ``/generate``.
    If the prompt contains censored words the model is not invoked. Instead the
    stream carries ``prompt_processor.SAFE_RESPONSE`` as its only ``generate``
    event, followed by the usual ``end`` event. This matches the prompt guard
    in ``/generate``.

    ``data.repetition_penalty`` is not forwarded: streaming generation does not
    support it yet.
    """
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    if prompt_processor.has_censored_words(prompt):
        # Apply the same prompt guard as /generate before touching the model.
        chunks = iter([prompt_processor.SAFE_RESPONSE])
    else:
        chunks = generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)
    event_source = event_generator(request, chunks)
    return EventSourceResponse(event_source)


@app.post("/generate")
@limiter.limit("1/second")
def generate_no_stream(data: GenerationTaskReq, request: Request):
    """Generate a complete chat reply and return it as a single string.

    Returns ``prompt_processor.SAFE_RESPONSE`` without calling the model when the
    prompt contains censored words. It also returns it in place of the reply when
    the post-processed output contains censored words.

    Every request field except ``history`` is forwarded to ``model.generate``.
    Sampling is switched on whenever the client supplies ``top_k``, ``top_p`` or
    ``temperature``.
    """
    prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
    if prompt_processor.has_censored_words(prompt):
        return prompt_processor.SAFE_RESPONSE
    inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
    generate_kwargs = data.dict(exclude={"history"})
    # transformers' generate() ignores top_k/top_p/temperature unless sampling is
    # enabled; honour the client's sampling parameters whenever any is supplied.
    if any(generate_kwargs.get(name) is not None for name in ("top_k", "top_p", "temperature")):
        generate_kwargs["do_sample"] = True
    with running_lock:
        output = model.generate(**inputs, **generate_kwargs)
    output = output.cpu()
    # The returned sequence starts with the prompt; keep only the new tokens.
    prompt_len = inputs["input_ids"].size(1)
    response = output[0, prompt_len:]
    out_string = tokenizer.decode(response, skip_special_tokens=True)
    out_string = prompt_processor.postprocess_output(out_string)
    if prompt_processor.has_censored_words(out_string):
        return prompt_processor.SAFE_RESPONSE
    return out_string


if __name__ == "__main__":
    # Script entry point: parse CLI options, load tokenizer/model into the module
    # globals that the endpoint handlers read (tokenizer, prompt_processor, model),
    # then serve the app with uvicorn.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "pretrained",
        help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
    )
    parser.add_argument(
        "--quant",
        choices=["8bit", "4bit"],
        default=None,
        help="Quantization mode. Default: None (no quantization, fp16).",
    )
    parser.add_argument(
        "--gptq_checkpoint",
        default=None,
        help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
    )
    parser.add_argument(
        "--gptq_group_size",
        type=int,
        default=128,
        help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
    )
    parser.add_argument("--http_host", default="0.0.0.0")
    parser.add_argument("--http_port", type=int, default=7070)
    parser.add_argument(
        "--profanity_file",
        default=None,
        help="Path to profanity words list. It should be a JSON file containing a list of words.",
    )
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are stripped under `python -O`.
    if args.quant == "4bit" and args.gptq_checkpoint is None:
        parser.error("Please specify a GPTQ checkpoint.")

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)

    if args.profanity_file is not None:
        censored_words = load_json(args.profanity_file)
    else:
        censored_words = []
    prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)

    if args.quant == "4bit":
        # Build the architecture under low_resource_init (presumably avoiding full
        # weight initialization), then load the 4-bit GPTQ checkpoint into it.
        with low_resource_init():
            llama_config = LlamaConfig.from_pretrained(args.pretrained)
            model = LlamaForCausalLM(llama_config)
        model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
        model.cuda()
    else:
        model = LlamaForCausalLM.from_pretrained(
            args.pretrained,
            load_in_8bit=(args.quant == "8bit"),
            torch_dtype=torch.float16,
            device_map="auto",
        )
        if args.quant != "8bit":
            model.half()  # seems to fix bugs for some users.
    # Inference only: put the model in eval mode on every loading path,
    # including the 4-bit one.
    model.eval()

    server_config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=server_config)
    server.run()