"src/vscode:/vscode.git/clone" did not exist on "8c249d1401f12d55a59a7bdb2329b29921ae864e"
Commit de84b2af authored by rprenger's avatar rprenger
Browse files

Possibly working beam search

parent fa75238b
......@@ -16,4 +16,5 @@
from .api import (
generate,
generate_and_post_process)
generate_and_post_process,
beam_search_and_post_process)
......@@ -158,11 +158,12 @@ def beam_search_and_post_process(model,
if mpu.is_pipeline_first_stage():
lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())
tokens, prompts_plus_generations, prompts_plus_generations_segments = detokenize_generations(tokens, lengths, True)
return prompts_plus_generations, prompts_plus_generations_segments, tokens
scores = scores.cpu().numpy().tolist()
return prompts_plus_generations, prompts_plus_generations_segments, scores
return None
def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False)
def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False):
# Make sure input params are available to all ranks.
values = [tokens_to_generate,
beam_size,
......
......@@ -20,9 +20,11 @@ from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
GENERATE_NUM = 0
BEAM_NUM = 0
lock = threading.Lock()
class MegatronGenerate(Resource):
......@@ -34,6 +36,11 @@ class MegatronGenerate(Resource):
choice = torch.cuda.LongTensor([GENERATE_NUM])
torch.distributed.broadcast(choice, 0)
@staticmethod
def send_do_beam_search():
    """Signal all other ranks (via rank-0 broadcast) that a beam-search
    request is about to be served, so they enter the beam-search code path.

    Mirrors ``send_do_generate``: the integer ``BEAM_NUM`` is placed in a
    CUDA tensor and broadcast from rank 0 to the whole default process group.
    """
    signal = torch.cuda.LongTensor([BEAM_NUM])
    torch.distributed.broadcast(signal, 0)
def put(self):
args = get_args()
......@@ -134,15 +141,39 @@ class MegatronGenerate(Resource):
if not isinstance(no_log, bool):
return "beam_search must be a boolean value"
beam_size = 4
if "beam_size" in request.get_json():
beam_size = request.get_json()["beam_size"]
if not isinstance(beam_size, int):
return "beam_size must be integer"
if beam_size < 1:
return "beam_size must be an integer > 1"
with lock: # Need to get lock to keep multiple threads from hitting code
if not no_log:
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("start time: ", datetime.datetime.now())
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
try:
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
if beam_search:
MegatronGenerate.send_do_beam_search() # Tell other ranks we're doing beam_search
response, response_seg, response_scores = \
beam_search_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
beam_size = beam_size,
add_BOS=add_BOS)
return jsonify({"text": response,
"segments": response_seg,
"scores": response_scores})
else:
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
......@@ -155,13 +186,15 @@ class MegatronGenerate(Resource):
stop_on_double_eol=stop_on_double_eol,
stop_on_eol=stop_on_eol,
random_seed=random_seed)
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
except ValueError as ve:
return "Length of prompt + tokens_to_generate longer than allowed"
print("end time: ", datetime.datetime.now())
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
class MegatronServer(object):
def __init__(self, model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment