# text_generation_server.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import torch
import json
import threading
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process


GENERATE_NUM = 0
BEAM_NUM = 1
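# GENERATE_NUM / BEAM_NUM are broadcast from rank 0 to the other
# model-parallel ranks to select which generation routine they enter; the
# lock below serializes requests, since the model runs one generation at a time.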
lock = threading.Lock()

class MegatronGenerate(Resource):
    def __init__(self, model):
        self.model = model

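    # Rank 0 runs this Flask server; the other model-parallel ranks wait on
    # torch.distributed.broadcast, so each request type is announced to them
    # before generation starts.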
    @staticmethod
    def send_do_generate():
        choice = torch.cuda.LongTensor([GENERATE_NUM])
        torch.distributed.broadcast(choice, 0)
     
    @staticmethod
    def send_do_beam_search():
        choice = torch.cuda.LongTensor([BEAM_NUM])
        torch.distributed.broadcast(choice, 0)
    
    def put(self):
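        # Expects a JSON body of the form {"prompts": [...]}; all other
        # fields are optional sampling controls validated below.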
        args = get_args()
       
        if not "prompts" in request.get_json():
            return "prompts argument required", 400
        
        if "max_len" in request.get_json():
            return "max_len is no longer used.  Replace with tokens_to_generate", 400
        
        if "sentences" in request.get_json():
            return "sentences is no longer used.  Replace with prompts", 400

        prompts = request.get_json()["prompts"]
        if len(prompts) > 128:
            return "Maximum number of prompts is 128", 400

        tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
        if "tokens_to_generate" in request.get_json():
            tokens_to_generate = request.get_json()["tokens_to_generate"]
            if not isinstance(tokens_to_generate, int):
                return "tokens_to_generate must be an integer greater than 0"
            if tokens_to_generate < 0:
                return "tokens_to_generate must be an integer greater than or equal to 0"

        logprobs = False
        if "logprobs" in request.get_json():
            logprobs = request.get_json()["logprobs"]
            if not isinstance(logprobs, bool):
                return "logprobs must be a boolean value"
        
        if tokens_to_generate == 0 and not logprobs:
            return "tokens_to_generate=0 implies logprobs should be True"
        
        temperature = 1.0
        if "temperature" in request.get_json():
            temperature = request.get_json()["temperature"]
            if not isinstance(temperature, (int, float)):
                return "temperature must be a positive number less than or equal to 100.0", 400
            if not (0.0 < temperature <= 100.0):
                return "temperature must be a positive number less than or equal to 100.0", 400
        
        top_k = 0
        if "top_k" in request.get_json():
            top_k = request.get_json()["top_k"]
            if not isinstance(top_k, int):
                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000", 400
            if not (0 <= top_k <= 1000):
                return "top_k must be equal to or greater than 0 and less than or equal to 1000", 400
        
        top_p = 0.0
        if "top_p" in request.get_json():
            top_p = request.get_json()["top_p"]
            if not isinstance(top_p, float):
                return "top_p must be a positive float less than or equal to 1.0", 400
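            # top-k and top-p are alternative truncation strategies; the
            # sampler accepts at most one of them per request.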
            if top_p > 0.0 and top_k > 0.0:
                return "cannot set both top-k and top-p samplings."
            if not (0 <= top_p <= 1.0):
                return "top_p must be less than or equal to 1.0", 400
        
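        # If set, the sampler decays top_p multiplicatively after each
        # generated token (floored at top_p_bound below).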
        top_p_decay = 0.0
        if "top_p_decay" in request.get_json():
            top_p_decay = request.get_json()["top_p_decay"]
            if not isinstance(top_p_decay, float):
                return "top_p_decay must be a positive float less than or equal to 1.0", 400
            if top_p == 0.0:
                return "top_p_decay cannot be set without top_p", 400
            if not (0 <= top_p_decay <= 1.0):
                return "top_p_decay must be less than or equal to 1.0"
        
        top_p_bound = 0.0
        if "top_p_bound" in request.get_json():
            top_p_bound = request.get_json()["top_p_bound"]
            if not isinstance(top_p_bound, float):
                return "top_p_bound must be a positive float less than or equal to top_p", 400
            if top_p == 0.0:
                return "top_p_bound cannot be set without top_p", 400
            if not (0.0 < top_p_bound <= top_p):
                return "top_p_bound must be greater than 0 and less than top_p"
        
        add_BOS = False
        if "add_BOS" in request.get_json():
            add_BOS = request.get_json()["add_BOS"]
            if not isinstance(add_BOS, bool):
                return "add_BOS must be a boolean value"
        
        if any(len(prompt) == 0 for prompt in prompts) and not add_BOS:
            return "Empty prompts require add_BOS=true", 400

        stop_on_double_eol = False
        if "stop_on_double_eol" in request.get_json():
            stop_on_double_eol = request.get_json()["stop_on_double_eol"]
            if not isinstance(stop_on_double_eol, bool):
                return "stop_on_double_eol must be a boolean value"
        
        stop_on_eol = False
        if "stop_on_eol" in request.get_json():
            stop_on_eol = request.get_json()["stop_on_eol"]
            if not isinstance(stop_on_eol, bool):
                return "stop_on_eol must be a boolean value"

        random_seed = -1
        if "random_seed" in request.get_json():
            random_seed = request.get_json()["random_seed"]
            if not isinstance(random_seed, int):
                return "random_seed must be integer", 400
            if random_seed < 0:
                return "random_seed must be a non-negative integer", 400

        no_log = False
        if "no_log" in request.get_json():
            no_log = request.get_json()["no_log"]
            if not isinstance(no_log, bool):
                return "no_log must be a boolean value"
        
        beam_width = None
        if "beam_width" in request.get_json():
            beam_width = request.get_json()["beam_width"]
            if not isinstance(beam_width, int):
                return "beam_width must be integer", 400
            if beam_width < 1:
                return "beam_width must be an integer >= 1", 400
            if len(prompts) > 1:
                return "When doing beam_search, batch size must be 1", 400

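        # Default assumes the GPT-2 vocabulary, where 50256 is <|endoftext|>.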
        stop_token = 50256
        if "stop_token" in request.get_json():
            stop_token = request.get_json()["stop_token"]
            if not isinstance(stop_token, int):
                return "stop_token must be an integer"
        
        length_penalty = 1.0
        if "length_penalty" in request.get_json():
            length_penalty = request.get_json()["length_penalty"]
            if not isinstance(length_penalty, float):
                return "length_penalty must be a float"
        
        with lock:  # serialize requests; the model runs one generation at a time
            
            if not no_log:
                print("request IP: " + str(request.remote_addr))
                print(json.dumps(request.get_json()), flush=True)
                print("start time: ", datetime.datetime.now())
            
            try:
                if beam_width is not None:
                    MegatronGenerate.send_do_beam_search()  # Tell other ranks we're doing beam_search
                    response, response_seg, response_scores = \
                        beam_search_and_post_process(
                        self.model,
                        prompts=prompts,
                        tokens_to_generate=tokens_to_generate,
                        beam_size=beam_width,
                        add_BOS=add_BOS,
                        stop_token=stop_token,
                        num_return_gen=beam_width,  # Returning whole beam
                        length_penalty=length_penalty
                        )
                    
                    return jsonify({"text": response,
                        "segments": response_seg,
                        "scores": response_scores})
                else:
                    MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
                    response, response_seg, response_logprobs, _ = \
                        generate_and_post_process(
                        self.model,
                        prompts=prompts,
                        tokens_to_generate=tokens_to_generate,
                        return_output_log_probs=logprobs,
                        top_k_sampling=top_k,
                        top_p_sampling=top_p,
                        top_p_decay=top_p_decay,
                        top_p_bound=top_p_bound,
                        temperature=temperature,
                        add_BOS=add_BOS,
                        use_eod_token_for_early_termination=True,
                        stop_on_double_eol=stop_on_double_eol,
                        stop_on_eol=stop_on_eol,
                        random_seed=random_seed)

                    return jsonify({"text": response,
                        "segments": response_seg,
                        "logprobs": response_logprobs})

            except ValueError:
                return "Length of prompt + tokens_to_generate longer than allowed", 400
            finally:
                if not no_log:
                    print("end time: ", datetime.datetime.now())

class MegatronServer(object):
    def __init__(self, model):
        self.app = Flask(__name__, static_url_path='')
        api = Api(self.app)
        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
        
    def run(self, url): 
        self.app.run(url, threaded=True, debug=False)
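

# Example client call (a sketch, not part of the original module). Assumes the
# server was launched via tools/run_text_generation_server.py, is listening on
# localhost:5000 (Flask's default port), and that the `requests` package is
# installed:
#
#   import requests
#   resp = requests.put("http://localhost:5000/api",
#                       json={"prompts": ["Hello, my name is"],
#                             "tokens_to_generate": 32})
#   print(resp.json()["text"])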