"docs/vscode:/vscode.git/clone" did not exist on "655ebdbcd71251ff6bbac89c4183f537db9aae2d"
Commit dee8707e authored by mshoeybi

Merge branch 'main' into inference

parents 24684cbb 5ab64637
@@ -15,6 +15,7 @@
 import datetime
 import torch
 import json
+import threading
 from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
@@ -22,11 +23,12 @@ from megatron import mpu
 from megatron.text_generation_utils import generate

 GENERATE_NUM = 0
+lock = threading.Lock()

 class MegatronGenerate(Resource):
     def __init__(self, model):
         self.model = model

     @staticmethod
     def send_do_generate():
         choice = torch.cuda.LongTensor([GENERATE_NUM])
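Note on the new lock: MegatronServer.run starts Flask with threaded=True, so each incoming request is handled on its own thread, while generate drives a single model whose ranks must all observe the same sequence of calls. The module-level threading.Lock added here serializes that path. A self-contained sketch of the pattern (the stubs below are invented for illustration and are not Megatron code):

import threading
import time

lock = threading.Lock()

def generate_stub(sentences):
    # Invented stand-in for generate(): pretend the model is busy briefly.
    time.sleep(0.1)
    return [s + " ..." for s in sentences]

def handle_request(sentences):
    # Request parsing and validation may run on many threads at once,
    # but only one thread at a time may drive the model.
    with lock:
        return generate_stub(sentences)

threads = [threading.Thread(target=handle_request, args=(["hello"],))
           for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()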
@@ -37,6 +39,7 @@ class MegatronGenerate(Resource):
         print("request IP: " + str(request.remote_addr))
         print(json.dumps(request.get_json()),flush=True)
         print("current time: ", datetime.datetime.now())
+
         sentences = request.get_json()["sentences"]
         if len(sentences) > 128:
             return "Maximum number of sentences is 128", 400
@@ -54,16 +57,24 @@ class MegatronGenerate(Resource):
             all_probs = request.get_json()["all_probs"]
             if not isinstance(all_probs, bool):
                 return "all_probs must be a boolean value"

         temperature = args.temperature
         if "temperature" in request.get_json():
             temperature = request.get_json()["temperature"]
             if not isinstance(temperature, float) or not \
                 0.0 < temperature <= 100.0:
                 return "temperature must be a positive float less than or equal to 100.0"

+        add_BOS = False
+        if "add_BOS" in request.get_json():
+            add_BOS = request.get_json()["add_BOS"]
+            if not isinstance(add_BOS, bool):
+                return "add_BOS must be a boolean value"
+
-        MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature)
+        with lock: # Need to get lock to keep multiple threads from hitting code
+            MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
+            resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS)

         if all_probs:
             return jsonify({"sentences": resp_sentences,
                             "segments": resp_sentences_seg,
@@ -77,10 +88,9 @@ class MegatronGenerate(Resource):

 class MegatronServer(object):
     def __init__(self, model):
-        self.app = Flask(__name__, static_folder='static', static_url_path='')
-        self.app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 0
+        self.app = Flask(__name__, static_url_path='')
         api = Api(self.app)
         api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])

     def run(self, url):
         self.app.run(url, threaded=True, debug=False)
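Wiring the server up is unchanged by this hunk; only the static_folder argument and the SEND_FILE_MAX_AGE_DEFAULT cache override go away. A launch sketch, assuming model is an already-initialized Megatron model and the bind address is a placeholder:

# Sketch only: `model` must be fully built and checkpointed on this rank.
server = MegatronServer(model)
server.run("0.0.0.0")  # Flask then serves /generate with threaded=True, debug=False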
@@ -95,10 +95,13 @@ def pad_batch(batch, pad_id, max_len):
         context_lengths.append(context_length)
     return batch, context_lengths

-def tokenize_batch(sentences, max_len):
+def tokenize_batch(sentences, max_len, add_BOS):
     args = get_args()
     tokenizer = get_tokenizer()
-    context_tokens = [tokenizer.tokenize(s) for s in sentences]
+    if add_BOS:
+        context_tokens = [[tokenizer.eod] + tokenizer.tokenize(s) for s in sentences]
+    else:
+        context_tokens = [tokenizer.tokenize(s) for s in sentences]
     context_tokens, context_lengths = pad_batch(context_tokens,
                                                 tokenizer.eod, max_len)
     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
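The BOS branch reuses tokenizer.eod as the beginning-of-sequence id, following the GPT-2 convention in which a single end-of-text token plays both roles. A self-contained toy showing the prepend-then-pad behavior (the tiny tokenizer is invented for the example; the real one comes from get_tokenizer()):

class ToyTokenizer:
    # Invented stand-in; as in Megatron's GPT tokenizer, eod doubles as BOS/pad.
    eod = 0
    def tokenize(self, s):
        return [ord(c) for c in s]  # toy scheme: one id per character

def tokenize_batch_toy(sentences, max_len, add_BOS):
    tok = ToyTokenizer()
    if add_BOS:
        context_tokens = [[tok.eod] + tok.tokenize(s) for s in sentences]
    else:
        context_tokens = [tok.tokenize(s) for s in sentences]
    # Mirror pad_batch: right-pad with eod out to longest context + max_len.
    target = max(len(t) for t in context_tokens) + max_len
    lengths = [len(t) for t in context_tokens]
    padded = [t + [tok.eod] * (target - len(t)) for t in context_tokens]
    return padded, lengths

batch, lengths = tokenize_batch_toy(["ab", "abcd"], max_len=2, add_BOS=True)
print(lengths)  # [3, 5]: each context length grew by one for the BOS token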
@@ -184,10 +187,10 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits

-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
-        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate)
+        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
         send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
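Only rank 0 sees the raw HTTP request, so it tokenizes (now passing add_BOS through) and fans the resulting tensors out; the other ranks block in receive_generate_info until a request arrives. A conceptual sketch of that fan-out with a plain broadcast (the real send_generate_info/receive_generate_info helpers are Megatron internals whose wire format is not shown in this diff):

import torch
import torch.distributed as dist

def broadcast_context(context_tokens=None, src=0):
    # Sketch only: assumes dist.init_process_group() has run and that
    # tensors live on the current CUDA device, as in Megatron.
    if dist.get_rank() == src:
        sizes = torch.cuda.LongTensor(list(context_tokens.size()))
    else:
        sizes = torch.cuda.LongTensor([0, 0])
    dist.broadcast(sizes, src)  # step 1: agree on the batch shape
    if dist.get_rank() != src:
        context_tokens = torch.empty(sizes[0].item(), sizes[1].item(),
                                     dtype=torch.int64, device="cuda")
    dist.broadcast(context_tokens, src)  # step 2: ship the token ids
    return context_tokens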
...
@@ -205,7 +205,7 @@ def main():
                                 args.task))

     # Set up model and load checkpoint.
-    model = get_model(get_model_provider(eval_metric))
+    model = get_model(get_model_provider(eval_metric), wrap_with_ddp=False)

     if args.load is not None:
         _ = load_checkpoint(model, None, None)
...