Commit 2cdebf4b authored by rprenger's avatar rprenger
Browse files

Working single sentence version

parent 83c4d95a
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flask import Flask, request, jsonify
from flask_restful import Resource, Api
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.text_generation_utils import pad_batch
from megatron.text_generation_utils import get_token_stream2
GENERATE_NUM = 0
def tokenize_batch(sentences):
args = get_args()
tokenizer = get_tokenizer()
context_tokens = [tokenizer.tokenize(s) for s in sentences]
context_tokens, context_lengths = pad_batch(context_tokens,
tokenizer.eod, args)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
return context_tokens_tensor, context_length_tensor
class MegatronGenerate(Resource):
def __init__(self, model):
self.model = model
@staticmethod
def send_do_generate():
choice = torch.cuda.LongTensor([GENERATE_NUM])
torch.distributed.broadcast(choice,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
@staticmethod
def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
"""
Needs to be synced up with receive_generate_info
"""
# Send the sizes of the tensors
input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len]
input_info_tensor = torch.cuda.LongTensor(input_info)
torch.distributed.broadcast(input_info_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
# Now send tensors
torch.distributed.broadcast(context_length_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
@staticmethod
def receive_generate_info():
"""
Needs to be synced up with send_generate_info
"""
input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.device("cuda"))
torch.distributed.broadcast(input_info_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
batch_size = input_info_tensor[0].item()
seq_len = input_info_tensor[1].item()
max_len = input_info_tensor[2].item()
context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.device("cuda"))
context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.device("cuda"))
torch.distributed.broadcast(context_length_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
return context_length_tensor, context_tokens_tensor, max_len
@staticmethod
def do_generate(model, context_length_tensor, context_tokens_tensor, max_len):
token_stream = get_token_stream2(model, context_tokens_tensor, context_length_tensor)
for i, decode_tokens in enumerate(token_stream):
if i == max_len-1:
break
pass
return decode_tokens
def put(self):
sentences = request.get_json()["sentences"]
max_len = 1024 # TODO (rprenger) this should not be hardcoded
if "max_len" in request.get_json():
max_len = request.get_json()["max_len"]
context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
MegatronGenerate.send_generate_info(context_tokens_tensor, context_length_tensor, max_len) # Send them info
decode_tokens = MegatronGenerate.do_generate(self.model, context_length_tensor, context_tokens_tensor, max_len) # Do stuff
args = get_args()
tokenizer = get_tokenizer()
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(decode_tokens)
return jsonify({"sentences": [trim_decode_tokens]})
class MegatronServer(object):
def __init__(self, model):
self.app = Flask(__name__)
api = Api(self.app)
api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
def run(self, url):
self.app.run(url, debug=False)
...@@ -387,6 +387,19 @@ def pad_batch(batch, pad_id, args): ...@@ -387,6 +387,19 @@ def pad_batch(batch, pad_id, args):
context_lengths.append(context_length) context_lengths.append(context_length)
return batch, context_lengths return batch, context_lengths
def get_token_stream2(model, context_tokens_tensor, context_length_tensor):
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
context_length_tensor,
attention_mask, position_ids)
for tokens, lengths in batch_token_iterator:
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None
def get_token_stream(model, context_tokens): def get_token_stream(model, context_tokens):
...@@ -406,18 +419,7 @@ def get_token_stream(model, context_tokens): ...@@ -406,18 +419,7 @@ def get_token_stream(model, context_tokens):
mpu.get_tensor_model_parallel_src_rank(), mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group()) group=mpu.get_tensor_model_parallel_group())
context_length = context_length_tensor.min().item() return get_token_stream2(model, context_tokens_tensor, context_length_tensor)
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
context_length_tensor,
attention_mask, position_ids)
for tokens, lengths in batch_token_iterator:
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None
def switch(val1, val2, boolean): def switch(val1, val2, boolean):
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT"""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import socket
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import generate_samples_interactive
from megatron.api_server import MegatronServer
from megatron.api_server import MegatronGenerate
import torch
def do_generate(model):
context_length_tensor, context_tokens_tensor, max_len = MegatronGenerate.receive_generate_info()
MegatronGenerate.do_generate(model, context_length_tensor, context_tokens_tensor, max_len)
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process)
return model
def add_text_generate_args(parser):
"""Text generation arguments."""
group = parser.add_argument_group(title='text generation')
group.add_argument("--temperature", type=float, default=1.0,
help='Sampling temperature.')
group.add_argument("--greedy", action='store_true', default=False,
help='Use greedy sampling.')
group.add_argument("--top_p", type=float, default=0.0,
help='Top p sampling.')
group.add_argument("--top_k", type=int, default=0,
help='Top k sampling.')
group.add_argument("--out-seq-length", type=int, default=1024,
help='Size of the output generated text.')
group.add_argument("--sample-input-file", type=str, default=None,
help='Get input from file instead of interactive mode, '
'each line is an input.')
group.add_argument("--sample-output-file", type=str, default=None,
help='Output file got from --sample-input-file')
group.add_argument("--num-samples", type=int, default=0,
help='Number of samples to generate unconditionally, '
'defaults to 0 and interactive conditional sampling')
group.add_argument("--genfile", type=str,
help='Output file when generating unconditionally')
group.add_argument("--recompute", action='store_true',
help='During generation recompute all attention '
'instead of using previously computed keys/values.')
return parser
if __name__ == "__main__":
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
'no_load_rng': True,
'no_load_optim': True})
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
# Set up model and load checkpoint
model = get_model(model_provider)
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
server = MegatronServer(model)
server.run("0.0.0.0")
while True:
choice = torch.cuda.LongTensor(1)
torch.distributed.broadcast(choice,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
print("got: "+str(choice[0].item()))
if choice[0].item() == 0:
do_generate(model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment