Commit 69757f9a authored by rprenger

Adding the option for beginning of sentence token (and fixing hangs)

parent b46482e8
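
In practice the new flag is just one more JSON field on the existing /generate endpoint. A minimal client sketch (hypothetical: the host, port, and HTTP verb are assumptions, since the Resource method that reads this JSON sits outside the hunks shown below):

import requests

# Hypothetical client call; host, port, and HTTP verb are assumptions --
# the handler method itself is not part of the diff below.
resp = requests.put(
    "http://localhost:5000/generate",
    json={
        "sentences": ["Megatron-LM is"],
        "tokens_to_generate": 32,
        "add_BOS": True,  # new optional flag; the server defaults it to False
    },
)
print(resp.json()["sentences"])
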
@@ -15,6 +15,7 @@
 import datetime
 import torch
 import json
+import threading
 from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
@@ -22,11 +23,12 @@ from megatron import mpu
 from megatron.text_generation_utils import generate

 GENERATE_NUM = 0
+sem = threading.Semaphore()

 class MegatronGenerate(Resource):
     def __init__(self, model):
         self.model = model

     @staticmethod
     def send_do_generate():
         choice = torch.cuda.LongTensor([GENERATE_NUM])
@@ -37,6 +39,7 @@ class MegatronGenerate(Resource):
         print("request IP: " + str(request.remote_addr))
         print(json.dumps(request.get_json()), flush=True)
         print("current time: ", datetime.datetime.now())
+
         sentences = request.get_json()["sentences"]
         if len(sentences) > 128:
             return "Maximum number of sentences is 128", 400
@@ -54,9 +57,18 @@ class MegatronGenerate(Resource):
             all_probs = request.get_json()["all_probs"]
             if not isinstance(all_probs, bool):
                 return "all_probs must be a boolean value", 400

+        add_BOS = False
+        if "add_BOS" in request.get_json():
+            add_BOS = request.get_json()["add_BOS"]
+            if not isinstance(add_BOS, bool):
+                return "add_BOS must be a boolean value", 400
+
+        sem.acquire()  # Serialize requests: only one thread may enter the generate path at a time
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, add_BOS)
+        sem.release()

         if all_probs:
             return jsonify({"sentences": resp_sentences,
                             "segments": resp_sentences_seg,
@@ -70,10 +82,9 @@ class MegatronGenerate(Resource):

 class MegatronServer(object):
     def __init__(self, model):
-        self.app = Flask(__name__, static_folder='static', static_url_path='')
-        self.app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0
+        self.app = Flask(__name__)
         api = Api(self.app)
         api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])

     def run(self, url):
         self.app.run(url, threaded=True, debug=False)
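
The "fixing hangs" half of the commit is the module-level threading.Semaphore() above. Because app.run(url, threaded=True, debug=False) serves each HTTP request on its own thread, two concurrent requests could previously both call send_do_generate() and generate(), interleaving work that assumes a single in-flight request. The sketch below reproduces the pattern in isolation (names and the sleep stand-in are illustrative, not from Megatron). Note that the committed code releases the semaphore without a try/finally, so an exception inside generate() would leave the lock held; the sketch guards against that:

import threading
import time

sem = threading.Semaphore()  # binary semaphore: at most one thread in the critical section

def handle_request(req_id):
    sem.acquire()  # serialize: only one thread may drive the generate pipeline
    try:
        # Stand-in for send_do_generate() + generate(): work that must not
        # be interleaved across threads (e.g. distributed collectives).
        time.sleep(0.1)
        print(f"request {req_id} finished generate")
    finally:
        sem.release()

threads = [threading.Thread(target=handle_request, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
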
@@ -95,10 +95,13 @@ def pad_batch(batch, pad_id, max_len):
         context_lengths.append(context_length)
     return batch, context_lengths

-def tokenize_batch(sentences, max_len):
+def tokenize_batch(sentences, max_len, add_BOS):
     args = get_args()
     tokenizer = get_tokenizer()
-    context_tokens = [tokenizer.tokenize(s) for s in sentences]
+    if add_BOS:
+        context_tokens = [[tokenizer.eod] + tokenizer.tokenize(s) for s in sentences]
+    else:
+        context_tokens = [tokenizer.tokenize(s) for s in sentences]
     context_tokens, context_lengths = pad_batch(context_tokens,
                                                 tokenizer.eod, max_len)
     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
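
Megatron's GPT-2-style tokenizer exposes no dedicated BOS id, which is why the new branch reuses tokenizer.eod, the same end-of-document id that pad_batch uses for padding, as the beginning-of-sentence marker. A toy illustration of the prepend logic (the stub tokenizer is hypothetical; only the if add_BOS branch mirrors the diff):

class StubTokenizer:
    """Hypothetical stand-in for the Megatron tokenizer interface."""
    eod = 0

    def tokenize(self, s):
        return [ord(c) % 100 + 1 for c in s]  # fake token ids, never 0

def tokenize_batch_sketch(sentences, add_BOS):
    tokenizer = StubTokenizer()
    if add_BOS:
        # Same pattern as the diff: prepend the EOD id as a BOS marker.
        return [[tokenizer.eod] + tokenizer.tokenize(s) for s in sentences]
    return [tokenizer.tokenize(s) for s in sentences]

print(tokenize_batch_sketch(["hi"], add_BOS=False))  # [[5, 6]]
print(tokenize_batch_sketch(["hi"], add_BOS=True))   # [[0, 5, 6]]
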
@@ -174,10 +177,10 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits

-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, add_BOS=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
-        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate)
+        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
         send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
......
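
generate() also shows why add_BOS itself never has to cross ranks: only rank 0 runs the Flask server and tokenizes, so the BOS id is already baked into context_tokens_tensor before send_generate_info() ships it out, while every other rank blocks in receive_generate_info(). A minimal sketch of what such a send/receive pair could look like with torch.distributed.broadcast (the helper names follow the diff, but these bodies are assumptions, not the actual Megatron implementation):

import torch

def send_generate_info(context_tokens_tensor, context_length_tensor,
                       tokens_to_generate, all_probs):
    """Sketch: rank 0 broadcasts a size/option header, then the payload tensors."""
    values = torch.cuda.LongTensor([tokens_to_generate, int(all_probs),
                                    context_tokens_tensor.size(0),
                                    context_tokens_tensor.size(1)])
    torch.distributed.broadcast(values, 0)
    torch.distributed.broadcast(context_length_tensor, 0)
    torch.distributed.broadcast(context_tokens_tensor, 0)

def receive_generate_info():
    """Sketch: non-zero ranks allocate buffers from the header, then receive."""
    values = torch.cuda.LongTensor(4)
    torch.distributed.broadcast(values, 0)
    tokens_to_generate, all_probs = int(values[0]), bool(values[1].item())
    batch, seq_len = int(values[2]), int(values[3])
    context_length_tensor = torch.cuda.LongTensor(batch)
    context_tokens_tensor = torch.cuda.LongTensor(batch, seq_len)
    torch.distributed.broadcast(context_length_tensor, 0)
    torch.distributed.broadcast(context_tokens_tensor, 0)
    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs

This also explains why the diff leaves send_generate_info's argument list unchanged: the BOS token travels inside the already-tokenized batch, so no extra flag needs to be broadcast.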