# opt_fastapi.py: FastAPI HTTP server for OPT text generation backed by an EnergonAI engine.
import argparse
import logging
import random
from typing import Optional

import uvicorn
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
from energonai import QueueFullError, launch_engine
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from transformers import GPT2Tokenizer


class GenerationTaskReq(BaseModel):
    """JSON request body accepted by ``POST /generation``.

    Constraints are declared with pydantic ``Field``; the ``example`` values
    are illustrative only.
    """

    # Number of new tokens to generate: 1..256 inclusive.
    max_tokens: int = Field(gt=0, le=256, example=64)
    # Prompt text to continue; must be non-empty.
    prompt: str = Field(
        min_length=1,
        example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:",
    )
    # Optional sampling controls; ``None`` is forwarded to the engine unchanged.
    # Note the exclusive upper bounds: top_p and temperature must be strictly
    # below 1.0, so a value of exactly 1.0 is rejected.
    top_k: Optional[int] = Field(default=None, gt=0, example=50)
    top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
    temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)


# ASGI application; the HTTP route and the shutdown hook are registered on it below.
app = FastAPI()


@app.post("/generation")
async def generate(data: GenerationTaskReq, request: Request):
    """Serve one text-generation request, consulting the response cache first.

    The cache key is ``(prompt, max_tokens)``. The sampling parameters
    (``top_k``, ``top_p``, ``temperature``) are not part of the key, so a cached
    completion may be returned regardless of the sampling settings requested.

    Reads the module-level globals ``logger``, ``cache``, ``tokenizer`` and
    ``engine`` that are created in the ``__main__`` section.

    Returns:
        ``{"text": <completion>}``.

    Raises:
        HTTPException: status 406 when the engine raises ``QueueFullError``.
    """
    # ``request.client`` may be None (depends on the transport), so do not
    # dereference it unconditionally just to build a log line.
    client = request.client
    peer = f"{client.host}:{client.port}" if client is not None else "unknown"
    logger.info(f'{peer} - "{request.method} {request.url.path}" - {data}')
    key = (data.prompt, data.max_tokens)

    # Cache lookup: a miss is signalled by MissCacheError. With caching
    # disabled (cache is None) every request goes straight to the engine.
    if cache is not None:
        try:
            outputs = cache.get(key)
        except MissCacheError:
            pass
        else:
            # The cache returns a sequence of completions for the key; pick one.
            output = random.choice(outputs)
            logger.info("Cache hit")
            return {"text": output}

    # Cache miss (or cache disabled). This path runs outside the
    # `except MissCacheError` clause so that errors raised here are not
    # reported as happening "during handling of" the cache miss.
    inputs = tokenizer(data.prompt, truncation=True, max_length=512)  # prompt truncated to 512 tokens
    inputs["max_tokens"] = data.max_tokens
    inputs["top_k"] = data.top_k
    inputs["top_p"] = data.top_p
    inputs["temperature"] = data.temperature
    # id() is unique among objects alive at the same time, and `data` stays
    # referenced by this coroutine until it returns, so concurrent in-flight
    # requests get distinct ids.
    uid = id(data)
    try:
        engine.submit(uid, inputs)
        output = await engine.wait(uid)
        output = tokenizer.decode(output, skip_special_tokens=True)
        if cache is not None:
            cache.add(key, output)
    except QueueFullError as e:
        # NOTE(review): 406 Not Acceptable is an unusual status for back-pressure
        # (503/429 are conventional); kept as-is because clients may rely on it.
        detail = e.args[0] if e.args else "engine queue is full"
        raise HTTPException(status_code=406, detail=detail) from e

    return {"text": output}


@app.on_event("shutdown")
async def shutdown(*_):
    """Stop the generation engine and the uvicorn server when the app shuts down.

    Uses the module-level ``engine`` and ``server`` globals created in the
    ``__main__`` section. Extra positional arguments are accepted and ignored.
    """
    engine.shutdown()
    # Both exit flags are set BEFORE awaiting server.shutdown(). Presumably
    # force_exit stops uvicorn from re-running the lifespan shutdown (and thus
    # re-entering this handler). NOTE(review): confirm against the installed
    # uvicorn version before reordering or removing these lines.
    server.should_exit = True
    server.force_exit = True
    await server.shutdown()


def get_model_fn(model_name: str):
    """Return the EnergonAI model factory for a command-line model name.

    Args:
        model_name: One of ``"opt-125m"``, ``"opt-6.7b"``, ``"opt-30b"``,
            ``"opt-175b"``.

    Raises:
        KeyError: If ``model_name`` is not supported; the message lists the
            accepted names.
    """
    model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B}
    try:
        return model_map[model_name]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError(f"unknown model {model_name!r}; expected one of {sorted(model_map)}") from None


def print_args(args: argparse.Namespace):
    """Print a header line, then every parsed argument as ``name = value``.

    Arguments are printed in the order they are stored on ``args``.
    """
    print("\n==> Args:")
    for k, v in vars(args).items():
        print(f"{k} = {v}")


# Keys passed to ListCache as `fixed_keys`. Each entry is (prompt, max_tokens),
# the same shape as the cache key built by the /generation handler.
# Presumably these entries are pinned in the cache rather than evicted;
# NOTE(review): confirm against cache.ListCache.
FIXED_CACHE_KEYS = [
    (
        "Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:",
        64,
    ),
    (
        "A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.",
        64,
    ),
    (
        "English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:",
        64,
    ),
]

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"])
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--master_host", default="localhost")
    parser.add_argument("--master_port", type=int, default=19990)
    parser.add_argument("--rpc_port", type=int, default=19980)
    parser.add_argument("--max_batch_size", type=int, default=8)
    parser.add_argument("--pipe_size", type=int, default=1)
    parser.add_argument("--queue_size", type=int, default=0)
    parser.add_argument("--http_host", default="0.0.0.0")
    parser.add_argument("--http_port", type=int, default=7070)
    parser.add_argument("--checkpoint", default=None)
    parser.add_argument("--cache_size", type=int, default=0)
    parser.add_argument("--cache_list_size", type=int, default=1)
    args = parser.parse_args()
    print_args(args)

    # Only pass `checkpoint` through to launch_engine when one was given.
    model_kwargs = {}
    if args.checkpoint is not None:
        model_kwargs["checkpoint"] = args.checkpoint

    # The names bound below (logger, tokenizer, cache, engine, server) become
    # module globals that the request and shutdown handlers read.
    logger = logging.getLogger(__name__)
    # NOTE(review): nothing in this file configures logging handlers or levels.
    # Unless energonai or uvicorn configure the root logger, the logger.info
    # calls in the request handler fall below the default WARNING level and
    # are dropped. Confirm.
    # A single tokenizer is used for every model size (presumably shared across
    # the OPT family; confirm).
    tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b")

    # A non-positive --cache_size disables response caching.
    if args.cache_size > 0:
        cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
    else:
        cache = None

    engine = launch_engine(
        args.tp,
        1,  # NOTE(review): meaning of this positional argument is not visible here; see launch_engine
        args.master_host,
        args.master_port,
        args.rpc_port,
        get_model_fn(args.model),
        batch_manager=BatchManagerForGeneration(
            max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id
        ),
        pipe_size=args.pipe_size,
        queue_size=args.queue_size,
        **model_kwargs,
    )

    # The Server object is kept in a global so the shutdown hook can stop it.
    config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
    server = uvicorn.Server(config=config)
    server.run()