api_server.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.text_generation_utils import tokenize_batch, get_token_stream

# Value broadcast to the other ranks to signal that a generate request
# has arrived (see send_do_generate below).
GENERATE_NUM = 0

class MegatronGenerate(Resource):
    def __init__(self, model):
        self.model = model
    
    @staticmethod
    def send_do_generate():
        choice = torch.cuda.LongTensor([GENERATE_NUM])
        torch.distributed.broadcast(choice,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())
    
    @staticmethod
    def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
        """
        Needs to be synced up with receive_generate_info
        """
        # Send the sizes of the tensors
        input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len]
        input_info_tensor = torch.cuda.LongTensor(input_info)
        torch.distributed.broadcast(input_info_tensor, 0)

        # Send variables to all ranks 
        torch.distributed.broadcast(context_length_tensor, 0)
        torch.distributed.broadcast(context_tokens_tensor, 0)

    @staticmethod
    def receive_generate_info():
        """
        Needs to be synced up with send_generate_info
        """
        input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.device("cuda"))
        torch.distributed.broadcast(input_info_tensor, 0)
        batch_size = input_info_tensor[0].item()
        seq_len = input_info_tensor[1].item()
        max_len = input_info_tensor[2].item()
        
        context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.device("cuda"))
        context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.device("cuda"))
        
        # Receive the variables broadcast from rank 0
        torch.distributed.broadcast(context_length_tensor, 0)
        torch.distributed.broadcast(context_tokens_tensor, 0)
        
        return context_length_tensor, context_tokens_tensor, max_len
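    # A minimal sketch (not part of the original file) of the loop the
    # non-rank-0 processes are assumed to run so the broadcasts above have a
    # receiving side: block on the same `choice` broadcast that
    # send_do_generate() issues, then mirror receive_generate_info() and
    # do_generate(). The helper name `listen_loop` is hypothetical.
    #
    #   def listen_loop(model):
    #       while True:
    #           choice = torch.cuda.LongTensor(1)
    #           torch.distributed.broadcast(
    #               choice,
    #               mpu.get_tensor_model_parallel_src_rank(),
    #               group=mpu.get_tensor_model_parallel_group())
    #           if choice[0].item() == GENERATE_NUM:
    #               context_length, context_tokens, max_len = \
    #                   MegatronGenerate.receive_generate_info()
    #               MegatronGenerate.do_generate(
    #                   model, context_length, context_tokens, max_len)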
    
    @staticmethod
    def do_generate(model, context_length_tensor, context_tokens_tensor, max_len):
        token_stream = get_token_stream(model, context_tokens_tensor, context_length_tensor)
        # Consume the stream; each iteration extends the generated sequence
        # by one token, so stop after max_len steps.
        for i, decode_tokens in enumerate(token_stream):
            if i == max_len - 1:
                break
        return decode_tokens
    
    def put(self):
        args = get_args()
        request_json = request.get_json()
        if "sentences" not in request_json:
            return "sentences argument required", 400
        sentences = request_json["sentences"]
        if len(sentences) > 128:
            return "Maximum number of sentences is 128", 400

        max_len = 64  # Hopefully sane default; generating the full sequence is slow
        if "max_len" in request_json:
            input_max_len = request_json["max_len"]
            if input_max_len < args.seq_length:
                max_len = input_max_len

        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
        MegatronGenerate.send_do_generate()  # Tell other ranks a generate request arrived
        MegatronGenerate.send_generate_info(context_tokens_tensor, context_length_tensor, max_len)  # Broadcast the inputs
        decode_tokens = MegatronGenerate.do_generate(self.model, context_length_tensor, context_tokens_tensor, max_len)  # Run generation
        tokenizer = get_tokenizer()
        decode_tokens, _ = decode_tokens  # the stream yields (tokens, lengths); keep the tokens
        resp_sentences = []
        for i in range(decode_tokens.size(0)):
            decode_token = decode_tokens[i,:].cpu().numpy().tolist()
            resp_sentences.append(tokenizer.detokenize(decode_token))
        return jsonify({"sentences": resp_sentences})
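    # Example request against this endpoint (assuming the server listens on
    # localhost:5000; host and port depend on how MegatronServer.run() is
    # called). The response mirrors the input shape: {"sentences": [...]}.
    #
    #   curl 'http://localhost:5000/generate' -X 'PUT' \
    #        -H 'Content-Type: application/json' \
    #        -d '{"sentences": ["Hello, my name is"], "max_len": 32}'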

def index():
    return current_app.send_static_file('index.html')

class MegatronServer(object):
    def __init__(self, model):
        self.app = Flask(__name__)
        self.app.add_url_rule('/', 'index', index)
        api = Api(self.app)
        api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])

    def run(self, url):
        # `url` is passed to Flask as the host to bind to.
        self.app.run(url, debug=False)
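
# A minimal sketch (not part of this file) of how these pieces are assumed
# to be wired together at startup: rank 0 serves HTTP while the remaining
# ranks wait in a broadcast-listening loop such as the hypothetical
# listen_loop() sketched above.
#
#   if torch.distributed.get_rank() == 0:
#       server = MegatronServer(model)
#       server.run("0.0.0.0")
#   else:
#       listen_loop(model)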