# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import json
import threading

import torch
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api

from megatron import get_args
from megatron.text_generation import generate_and_post_process
GENERATE_NUM = 0
rprenger's avatar
rprenger committed
26
lock = threading.Lock()
rprenger's avatar
rprenger committed
27
28
29
30

class MegatronGenerate(Resource):
    """Flask-RESTful endpoint (PUT /api) that runs Megatron text generation.

    Rank 0 serves HTTP; before generating it broadcasts GENERATE_NUM so the
    other model-parallel ranks enter the generation code path as well.
    """

    def __init__(self, model):
        # Model handle passed through to generate_and_post_process.
        self.model = model

    @staticmethod
    def send_do_generate():
        # Broadcast the "do generate" command id from rank 0 so the other
        # ranks (blocked waiting on this broadcast) start generating too.
        choice = torch.cuda.LongTensor([GENERATE_NUM])
        torch.distributed.broadcast(choice, 0)

    def put(self):
        """Validate the JSON request body and run generation.

        Returns the generated text, token segments and (optionally) log
        probabilities as JSON, or an (error-string, 400) pair on bad input.
        """
        args = get_args()
        # Parse the body once; the original re-parsed it for every check.
        req = request.get_json()

        print("request IP: " + str(request.remote_addr))
        print(json.dumps(req), flush=True)
        print("current time: ", datetime.datetime.now())

        if not "prompts" in req:
            return "prompts argument required", 400

        if "max_len" in req:
            return "max_len is no longer used.  Replace with tokens_to_generate", 400

        if "sentences" in req:
            return "sentences is no longer used.  Replace with prompts", 400

        prompts = req["prompts"]
        if len(prompts) > 128:
            return "Maximum number of prompts is 128", 400

        tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
        if "tokens_to_generate" in req:
            tokens_to_generate = req["tokens_to_generate"]
            # BUGFIX: validation failures below now return HTTP 400; they
            # previously returned the error string with a 200 status.
            if not isinstance(tokens_to_generate, int):
                return "tokens_to_generate must be an integer greater than 0", 400
            if tokens_to_generate < 1:
                return "tokens_to_generate must be an integer greater than 0", 400

        logprobs = False
        if "logprobs" in req:
            logprobs = req["logprobs"]
            if not isinstance(logprobs, bool):
                return "logprobs must be a boolean value", 400

        temperature = 1.0
        if "temperature" in req:
            temperature = req["temperature"]
            # type() rather than isinstance() so bools are rejected
            # (isinstance(True, int) is True).
            if not (type(temperature) == int or type(temperature) == float):
                return "temperature must be a positive number less than or equal to 100.0", 400
            if not (0.0 < temperature <= 100.0):
                return "temperature must be a positive number less than or equal to 100.0", 400

        # BUGFIX: default was the float 0.0 although top_k is validated as an
        # int below; 0 means top-k sampling disabled.
        top_k = 0
        if "top_k" in req:
            top_k = req["top_k"]
            if not (type(top_k) == int):
                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000", 400
            # BUGFIX: allow 0 (disabled) as the error message promises; the
            # old check `0 < top_k` rejected an explicit top_k=0.
            if not (0 <= top_k <= 1000):
                return "top_k must be equal to or greater than 0 and less than or equal to 1000", 400

        top_p = 0.0
        if "top_p" in req:
            top_p = req["top_p"]
            if not (type(top_p) == float):
                return "top_p must be a positive float less than or equal to 1.0", 400
            if top_p > 0.0 and top_k > 0:
                return "cannot set both top-k and top-p samplings.", 400
            if not (0 < top_p <= 1.0):
                return "top_p must be less than or equal to 1.0", 400

        add_BOS = False
        if "add_BOS" in req:
            add_BOS = req["add_BOS"]
            if not isinstance(add_BOS, bool):
                return "add_BOS must be a boolean value", 400

        with lock:  # Need to get lock to keep multiple threads from hitting code
            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
            response, response_seg, response_logprobs, _ = \
                generate_and_post_process(
                    self.model,
                    prompts=prompts,
                    tokens_to_generate=tokens_to_generate,
                    return_output_log_probs=logprobs,
                    greedy_sampling=args.greedy,
                    top_k_sampling=top_k,
                    top_p_sampling=top_p,
                    temperature=temperature,
                    add_BOS=add_BOS,
                    use_eod_token_for_early_termination=True)

        return jsonify({"text": response,
                        "segments": response_seg,
                        "logprobs": response_logprobs})

class MegatronServer(object):
    """Thin wrapper that wires the MegatronGenerate resource into Flask."""

    def __init__(self, model):
        # Build the app and mount the generation endpoint at /api; the model
        # is forwarded to MegatronGenerate.__init__ via resource_class_args.
        flask_app = Flask(__name__, static_url_path='')
        rest_api = Api(flask_app)
        rest_api.add_resource(MegatronGenerate, '/api',
                              resource_class_args=[model])
        self.app = flask_app

    def run(self, url):
        # threaded=True lets Flask accept concurrent requests; the generate
        # path itself is serialized by the module-level lock.
        self.app.run(url, threaded=True, debug=False)