Commit 29dd0a35 authored by rprenger

Refactoring code so the server code is more independent of sampling, and adding a CLI. The CLI still has the server URL hard-coded.
parent 61184a8f
@@ -17,9 +17,8 @@ from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.text_generation_utils import tokenize_batch, get_token_stream
from megatron.text_generation_utils import generate
GENERATE_NUM = 0
@@ -33,50 +32,7 @@ class MegatronGenerate(Resource):
torch.distributed.broadcast(choice,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
@staticmethod
def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
"""
Needs to be synced up with receive_generate_info
"""
# Send the sizes of the tensors
input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len]
input_info_tensor = torch.cuda.LongTensor(input_info)
torch.distributed.broadcast(input_info_tensor, 0)
# Send variables to all ranks
torch.distributed.broadcast(context_length_tensor, 0)
torch.distributed.broadcast(context_tokens_tensor, 0)
@staticmethod
def receive_generate_info():
"""
Needs to be synced up with send_generate_info
"""
input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.device("cuda"))
torch.distributed.broadcast(input_info_tensor, 0)
batch_size = input_info_tensor[0].item()
seq_len = input_info_tensor[1].item()
max_len = input_info_tensor[2].item()
context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.device("cuda"))
context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.device("cuda"))
# Send variables to all ranks
torch.distributed.broadcast(context_length_tensor, 0)
torch.distributed.broadcast(context_tokens_tensor, 0)
return context_length_tensor, context_tokens_tensor, max_len
@staticmethod
def do_generate(model, context_length_tensor, context_tokens_tensor, max_len):
token_stream = get_token_stream(model, context_tokens_tensor, context_length_tensor)
for i, decode_tokens in enumerate(token_stream):
if i == max_len-1:
break
pass
return decode_tokens
def put(self):
args = get_args()
sentences = request.get_json()["sentences"]
@@ -89,19 +45,10 @@ class MegatronGenerate(Resource):
if input_max_len < args.seq_length:
max_len = input_max_len
context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
MegatronGenerate.send_generate_info(context_tokens_tensor, context_length_tensor, max_len) # Send them info
decode_tokens = MegatronGenerate.do_generate(self.model, context_length_tensor, context_tokens_tensor, max_len) # Do stuff
args = get_args()
tokenizer = get_tokenizer()
decode_tokens, _ = decode_tokens
resp_sentences = []
for i in range(decode_tokens.size(0)):
decode_token = decode_tokens[i,:].cpu().numpy().tolist()
resp_sentences.append(tokenizer.detokenize(decode_token))
resp_sentences = generate(self.model, sentences, max_len)
return jsonify({"sentences": resp_sentences})
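For reference, the endpoint's JSON contract after this change (the prompt text and continuation below are illustrative placeholders, not real model output):
    # PUT /generate request body: a list of prompts plus a token budget
    request_body = {"sentences": ["Deep learning is"], "max_len": 32}
    # Response body: one detokenized continuation per prompt
    response_body = {"sentences": ["Deep learning is <generated text>"]}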
def index():
return current_app.send_static_file('index.html')
@@ -118,8 +118,67 @@ def get_token_stream(model, context_tokens_tensor, context_length_tensor):
else:
yield None, None
def switch(val1, val2, boolean):
def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
"""
Needs to be synced up with receive_generate_info
"""
# Send the sizes of the tensors
input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len]
input_info_tensor = torch.cuda.LongTensor(input_info)
torch.distributed.broadcast(input_info_tensor, 0)
# Send variables to all ranks
torch.distributed.broadcast(context_length_tensor, 0)
torch.distributed.broadcast(context_tokens_tensor, 0)
def receive_generate_info():
"""
Needs to be synced up with send_generate_info
"""
input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.device("cuda"))
torch.distributed.broadcast(input_info_tensor, 0)
batch_size = input_info_tensor[0].item()
seq_len = input_info_tensor[1].item()
max_len = input_info_tensor[2].item()
context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.device("cuda"))
context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.device("cuda"))
# Send variables to all ranks
torch.distributed.broadcast(context_length_tensor, 0)
torch.distributed.broadcast(context_tokens_tensor, 0)
return context_length_tensor, context_tokens_tensor, max_len
def synced_generate(model, context_length_tensor, context_tokens_tensor, max_len):
token_stream = get_token_stream(model, context_tokens_tensor, context_length_tensor)
for i, decode_tokens in enumerate(token_stream):
if i == max_len-1:
break
pass
return decode_tokens
def generate(model, sentences=None, max_len=0):
if torch.distributed.get_rank() == 0:
context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
else:
context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
decode_tokens = synced_generate(model, context_length_tensor, context_tokens_tensor, max_len)
if torch.distributed.get_rank() == 0:
args = get_args()
tokenizer = get_tokenizer()
decode_tokens, _ = decode_tokens
resp_sentences = []
for i in range(decode_tokens.size(0)):
decode_token = decode_tokens[i,:].cpu().numpy().tolist()
resp_sentences.append(tokenizer.detokenize(decode_token))
return resp_sentences
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
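With generate moved into text_generation_utils, any caller that has the model and an initialized torch.distributed group can drive generation without going through Flask. A minimal sketch of the calling convention (run_generation is a hypothetical helper, not something added by this commit):
    import torch
    from megatron.text_generation_utils import generate

    def run_generation(model, prompts=None, max_len=64):
        # Rank 0 supplies the prompts and receives the detokenized sentences;
        # all other ranks call generate(model) and pick up the tokenized batch
        # through the broadcasts in send_generate_info / receive_generate_info.
        if torch.distributed.get_rank() == 0:
            return generate(model, prompts, max_len)
        return generate(model)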
@@ -21,19 +21,15 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
import socket
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.api_server import MegatronServer, MegatronGenerate
from megatron.api_server import MegatronServer
from megatron.text_generation_utils import generate
import torch
def do_generate(model):
context_length_tensor, context_tokens_tensor, max_len = MegatronGenerate.receive_generate_info()
MegatronGenerate.do_generate(model, context_length_tensor, context_tokens_tensor, max_len)
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
@@ -86,4 +82,4 @@ if __name__ == "__main__":
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
if choice[0].item() == 0:
do_generate(model)
generate(model)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import urllib2
class PutRequest(urllib2.Request):
'''Request subclass that issues a PUT with urllib2'''
def get_method(self, *args, **kwargs):
return 'PUT'
if __name__ == "__main__":
while True:
sentence = raw_input("Enter prompt: ")
max_len = int(input("Enter number tokens output: "))
data = json.dumps({"sentences": [sentence], "max_len":max_len})
req = PutRequest("http://sc-sdgx2-484:5000/generate", data, {'Content-Type': 'application/json'})
response = urllib2.urlopen(req)
resp_sentences = json.load(response)
print("Megatron Response: ")
print(resp_sentences["sentences"][0])
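The client above targets Python 2 (urllib2, raw_input). A minimal Python 3 sketch of the same loop, using only the standard library and keeping the same hard-coded server URL, might look like this:
    import json
    import urllib.request

    if __name__ == "__main__":
        while True:
            sentence = input("Enter prompt: ")
            max_len = int(input("Enter number tokens output: "))
            data = json.dumps({"sentences": [sentence], "max_len": max_len}).encode("utf-8")
            # urllib.request.Request accepts an explicit method, so no PUT subclass is needed
            req = urllib.request.Request("http://sc-sdgx2-484:5000/generate",
                                         data=data,
                                         headers={"Content-Type": "application/json"},
                                         method="PUT")
            with urllib.request.urlopen(req) as response:
                resp_sentences = json.loads(response.read().decode("utf-8"))
            print("Megatron Response: ")
            print(resp_sentences["sentences"][0])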