Commit 5ab64637 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'add_BOS' into 'main'

Add a Beginning-of-Sentence token option, and add a lock around multi-threaded generation to prevent crashes and hangs due to connection keep-alives

See merge request ADLR/megatron-lm!328
parents 14f2c684 7b293d9b
@@ -15,6 +15,7 @@
 import datetime
 import torch
 import json
+import threading
 from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
@@ -22,6 +23,7 @@ from megatron import mpu
 from megatron.text_generation_utils import generate

 GENERATE_NUM = 0
+lock = threading.Lock()

 class MegatronGenerate(Resource):
     def __init__(self, model):
@@ -37,6 +39,7 @@ class MegatronGenerate(Resource):
         print("request IP: " + str(request.remote_addr))
         print(json.dumps(request.get_json()),flush=True)
         print("current time: ", datetime.datetime.now())
+
         sentences = request.get_json()["sentences"]
         if len(sentences) > 128:
             return "Maximum number of sentences is 128", 400
@@ -62,8 +65,16 @@ class MegatronGenerate(Resource):
            0.0 < temperature <= 100.0:
             return "temperature must be a positive float less than or equal to 100.0"

+        add_BOS = False
+        if "add_BOS" in request.get_json():
+            add_BOS = request.get_json()["add_BOS"]
+            if not isinstance(add_BOS, bool):
+                return "add_BOS must be a boolean value"
+
-        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature)
-        if all_probs:
-            return jsonify({"sentences": resp_sentences,
-                            "segments": resp_sentences_seg,
+        with lock:  # Need to get lock to keep multiple threads from hitting code
+            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
+            resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS)
+            if all_probs:
+                return jsonify({"sentences": resp_sentences,
+                                "segments": resp_sentences_seg,
@@ -77,8 +88,7 @@ class MegatronGenerate(Resource):

 class MegatronServer(object):
     def __init__(self, model):
-        self.app = Flask(__name__, static_folder='static', static_url_path='')
-        self.app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0
+        self.app = Flask(__name__, static_url_path='')
         api = Api(self.app)
         api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
......
@@ -95,9 +95,12 @@ def pad_batch(batch, pad_id, max_len):
         context_lengths.append(context_length)
     return batch, context_lengths

-def tokenize_batch(sentences, max_len):
+def tokenize_batch(sentences, max_len, add_BOS):
     args = get_args()
     tokenizer = get_tokenizer()
+    if add_BOS:
+        context_tokens = [[tokenizer.eod] + tokenizer.tokenize(s) for s in sentences]
+    else:
-    context_tokens = [tokenizer.tokenize(s) for s in sentences]
+        context_tokens = [tokenizer.tokenize(s) for s in sentences]
     context_tokens, context_lengths = pad_batch(context_tokens,
                                                 tokenizer.eod, max_len)
@@ -175,10 +178,10 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits

-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
-        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate)
+        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
         send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment