text_generation_server.py 9.96 KB
Newer Older
wangxj's avatar
wangxj committed
1
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
3
import datetime
import json
wangxj's avatar
wangxj committed
4
5
import os
import sys
xingjinliang's avatar
xingjinliang committed
6
7

from flask import Flask, request, jsonify
rprenger's avatar
rprenger committed
8
from flask_restful import Resource, Api
mshoeybi's avatar
mshoeybi committed
9

wangxj's avatar
wangxj committed
10
from megatron.core.inference.sampling_params import SamplingParams
xingjinliang's avatar
xingjinliang committed
11
12
from megatron.inference.endpoints.common import send_do_generate, send_do_beam_search, LOCK
from megatron.inference.endpoints.completions import MegatronCompletions
wangxj's avatar
wangxj committed
13
14
15
16
17
from megatron.inference.text_generation import beam_search_and_post_process

# Make the repository root (two levels above this file) importable.
# NOTE(review): this append executes *after* the `megatron.*` imports above,
# so it cannot help resolve them — presumably it is for imports performed
# later at runtime; confirm the intended ordering.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
)
rprenger's avatar
rprenger committed
18
19
20


class MegatronGenerate(Resource):
    """REST resource serving text generation and beam-search requests.

    Accepts a PUT with a JSON body; ``prompts`` (list of strings) is the only
    required field.  All other fields are optional sampling / decoding knobs
    that are validated up front; any validation failure returns
    ``(message, 400)`` before any cross-rank coordination happens.
    """

    def __init__(self, engine, args):
        # Inference engine used for the sampling path; `args` is the
        # server-level configuration object (stored for downstream use).
        self.engine = engine
        self.args = args

    def put(self):
        """Validate the JSON request, then run generation or beam search.

        Returns:
            A ``(message, 400)`` tuple on validation failure, otherwise a
            JSON payload containing generated text (plus segments / scores /
            logprobs where requested).
        """
        # Parse the body once instead of calling request.get_json() per field.
        req = request.get_json()

        if "prompts" not in req:
            return "prompts argument required", 400

        if "max_len" in req:
            return "max_len is no longer used.  Replace with tokens_to_generate", 400

        if "sentences" in req:
            return "sentences is no longer used.  Replace with prompts", 400

        prompts = req["prompts"]
        if not isinstance(prompts, list):
            return "prompts is not a list of strings", 400

        if len(prompts) == 0:
            return "prompts is empty", 400

        if len(prompts) > 128:
            return "Maximum number of prompts is 128", 400

        tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
        if "tokens_to_generate" in req:
            tokens_to_generate = req["tokens_to_generate"]
            # These two returns previously lacked the 400 status (served as
            # HTTP 200) and the first message contradicted the >=0 check.
            if not isinstance(tokens_to_generate, int):
                return "tokens_to_generate must be an integer greater than or equal to 0", 400
            if tokens_to_generate < 0:
                return "tokens_to_generate must be an integer greater than or equal to 0", 400

        logprobs = False
        if "logprobs" in req:
            logprobs = req["logprobs"]
            if not isinstance(logprobs, bool):
                return "logprobs must be a boolean value", 400

        # With zero new tokens the only useful output is the prompt logprobs.
        if tokens_to_generate == 0 and not logprobs:
            return "tokens_to_generate=0 implies logprobs should be True", 400

        temperature = 1.0
        if "temperature" in req:
            temperature = req["temperature"]
            # Message previously said 1000.0 while the range check below uses
            # 100.0; the message now matches the check.
            if not isinstance(temperature, (int, float)):
                return "temperature must be a positive number less than or equal to 100.0", 400
            if not (0.0 < temperature <= 100.0):
                return "temperature must be a positive number less than or equal to 100.0", 400

        top_k = 0
        if "top_k" in req:
            top_k = req["top_k"]
            if not isinstance(top_k, int):
                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000", 400
            if not (0 <= top_k <= 1000):
                return "top_k must be equal to or greater than 0 and less than or equal to 1000", 400

        top_p = 0.0
        if "top_p" in req:
            top_p = req["top_p"]
            if not isinstance(top_p, float):
                return "top_p must be a positive float less than or equal to 1.0", 400
            if top_p > 0.0 and top_k > 0:
                return "cannot set both top-k and top-p samplings.", 400
            if not (0 <= top_p <= 1.0):
                return "top_p must be less than or equal to 1.0", 400

        # NOTE(review): top_p_decay / top_p_bound are validated but never
        # forwarded to SamplingParams below — confirm whether they should be.
        top_p_decay = 0.0
        if "top_p_decay" in req:
            top_p_decay = req["top_p_decay"]
            if not isinstance(top_p_decay, float):
                return "top_p_decay must be a positive float less than or equal to 1.0", 400
            if top_p == 0.0:
                return "top_p_decay cannot be set without top_p", 400
            if not (0 <= top_p_decay <= 1.0):
                return "top_p_decay must be less than or equal to 1.0", 400

        top_p_bound = 0.0
        if "top_p_bound" in req:
            top_p_bound = req["top_p_bound"]
            if not isinstance(top_p_bound, float):
                return "top_p_bound must be a positive float less than or equal to top_p", 400
            if top_p == 0.0:
                return "top_p_bound cannot be set without top_p", 400
            if not (0.0 < top_p_bound <= top_p):
                return "top_p_bound must be greater than 0 and less than top_p", 400

        add_BOS = False
        if "add_BOS" in req:
            add_BOS = req["add_BOS"]
            if not isinstance(add_BOS, bool):
                return "add_BOS must be a boolean value", 400

        # An empty prompt yields zero tokens unless a BOS token is prepended.
        if any(len(prompt) == 0 for prompt in prompts) and not add_BOS:
            return "Empty prompts require add_BOS=true", 400

        # NOTE(review): stop_on_double_eol / stop_on_eol are validated but not
        # used by either path below — confirm intended wiring.
        stop_on_double_eol = False
        if "stop_on_double_eol" in req:
            stop_on_double_eol = req["stop_on_double_eol"]
            if not isinstance(stop_on_double_eol, bool):
                return "stop_on_double_eol must be a boolean value", 400

        stop_on_eol = False
        if "stop_on_eol" in req:
            stop_on_eol = req["stop_on_eol"]
            if not isinstance(stop_on_eol, bool):
                return "stop_on_eol must be a boolean value", 400

        prevent_newline_after_colon = False
        if "prevent_newline_after_colon" in req:
            prevent_newline_after_colon = req["prevent_newline_after_colon"]
            if not isinstance(prevent_newline_after_colon, bool):
                return "prevent_newline_after_colon must be a boolean value", 400

        random_seed = -1  # -1 means "no explicit seed requested"
        if "random_seed" in req:
            random_seed = req["random_seed"]
            if not isinstance(random_seed, int):
                return "random_seed must be integer", 400
            if random_seed < 0:
                return "random_seed must be a non-negative integer", 400

        no_log = False
        if "no_log" in req:
            no_log = req["no_log"]
            if not isinstance(no_log, bool):
                return "no_log must be a boolean value", 400

        beam_width = None
        if "beam_width" in req:
            beam_width = req["beam_width"]
            if not isinstance(beam_width, int):
                return "beam_width must be integer", 400
            if beam_width < 1:
                # Message previously said "> 1" although a width of 1 passes
                # this check; it now matches the condition.
                return "beam_width must be an integer >= 1", 400
            if len(prompts) > 1:
                return "When doing beam_search, batch size must be 1", 400

        stop_token = 50256  # presumably GPT-2 <|endoftext|>; only the beam path uses it
        if "stop_token" in req:
            stop_token = req["stop_token"]
            if not isinstance(stop_token, int):
                return "stop_token must be an integer", 400

        length_penalty = 1
        if "length_penalty" in req:
            length_penalty = req["length_penalty"]
            if not isinstance(length_penalty, float):
                return "length_penalty must be a float", 400

        with LOCK:  # Need to get lock to keep multiple threads from hitting code
            if not no_log:
                print("request IP: " + str(request.remote_addr))
                print(json.dumps(req), flush=True)
                print("start time: ", datetime.datetime.now())

            try:
                if beam_width is not None:
                    send_do_beam_search()  # Tell other ranks we're doing beam_search
                    # NOTE(review): this previously passed `self.model`, an
                    # attribute that is never set (__init__ only stores
                    # engine/args), so the beam path always raised
                    # AttributeError.  `self.engine` is the only handle held
                    # here — confirm beam_search_and_post_process accepts it.
                    response, response_seg, response_scores = beam_search_and_post_process(
                        self.engine,
                        prompts=prompts,
                        tokens_to_generate=tokens_to_generate,
                        beam_size=beam_width,
                        add_BOS=add_BOS,
                        stop_token=stop_token,
                        num_return_gen=beam_width,  # Returning whole beam
                        length_penalty=length_penalty,
                        prevent_newline_after_colon=prevent_newline_after_colon,
                    )

                    return jsonify(
                        {"text": response, "segments": response_seg, "scores": response_scores}
                    )
                else:
                    send_do_generate()  # Tell other ranks we're doing generate

                    sampling_params = SamplingParams(
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p,
                        return_segments=True,
                        return_log_probs=logprobs,
                        num_tokens_to_generate=tokens_to_generate,
                    )

                    result = list(
                        self.engine.generate(
                            prompts=prompts, common_inference_params=sampling_params
                        )
                    )
                    # Echo the prompt followed by the continuation, per request.
                    response_dict = {"text": [x.prompt + x.generated_text for x in result]}
                    if sampling_params.return_log_probs:
                        response_dict["logprobs"] = [
                            x.prompt_log_probs + x.generated_log_probs for x in result
                        ]
                    if sampling_params.return_segments:
                        response_dict["segments"] = [x.segments for x in result]

                    return jsonify(response_dict)

            except ValueError as ve:
                # Surface engine-side validation errors to the client as text.
                return ve.args[0]
class MegatronServer(object):
    """Thin Flask wrapper that exposes the inference endpoints over HTTP."""

    def __init__(self, model, args=None):
        """Build the Flask app and register the two REST resources."""
        flask_app = Flask(__name__, static_url_path='')
        rest_api = Api(flask_app)
        # '/api' serves raw generation/beam-search; '/completions' serves the
        # completions-style interface.  Both resources receive the model.
        rest_api.add_resource(MegatronGenerate, '/api', resource_class_args=[model, args])
        rest_api.add_resource(MegatronCompletions, '/completions', resource_class_args=[model])
        self.app = flask_app

    def run(self, url, port):
        """Start the threaded Flask development server on the given host/port."""
        self.app.run(url, threaded=True, debug=False, port=port)