Commit fa75238b authored by rprenger

Almost working beam search

parent f00d0a3f
@@ -22,7 +22,8 @@ from megatron import mpu
from .communication import broadcast_float_list
from .generation import (
    generate_tokens_probs_and_return_on_first_stage,
    score_and_return_on_first_stage,
    beam_search_and_return_on_first_stage)
from .tokenization import (
    tokenize_prompts,
    detokenize_generations)
@@ -138,3 +139,40 @@ def generate(model,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
        stop_on_double_eol=stop_on_double_eol,
        stop_on_eol=stop_on_eol)

def beam_search_and_post_process(model,
                                 prompts=None,
                                 tokens_to_generate=0,
                                 beam_size=0,
                                 add_BOS=False):
    """Run beam search and post-process outputs, i.e., detokenize,
    move to cpu and convert to list."""

    # Main inference.
    tokens, scores = beam_search(model,
                                 prompts=prompts,
                                 tokens_to_generate=tokens_to_generate,
                                 beam_size=beam_size,
                                 add_BOS=add_BOS)
    # Only post-process on first stage.
    if mpu.is_pipeline_first_stage():
        lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())
        tokens, prompts_plus_generations, prompts_plus_generations_segments = detokenize_generations(tokens, lengths, True)
        return prompts_plus_generations, prompts_plus_generations_segments, tokens

    return None


def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False):
    # Make sure input params are available to all ranks.
    values = [tokens_to_generate,
              beam_size,
              add_BOS]
    values_float_tensor = broadcast_float_list(3, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    beam_size = int(values_float_tensor[1].item())
    add_BOS = bool(values_float_tensor[2].item())

    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

    return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor, beam_size)
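
For reference, a minimal sketch of how a caller might drive the new API. This is an illustration only: the model setup is assumed to follow the existing text-generation tooling, and the prompt, beam size, and generation length below are made up. As with generate_and_post_process, every rank enters the call; only the first pipeline stage gets a non-None result.

    # Hypothetical usage sketch (not part of the commit).
    prompts = ["Deep learning is"]
    result = beam_search_and_post_process(model,
                                          prompts=prompts,
                                          tokens_to_generate=32,
                                          beam_size=4,
                                          add_BOS=False)
    if result is not None:  # only the first pipeline stage returns outputs
        generations, segments, tokens = result
        print(generations[0])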
@@ -200,6 +200,7 @@ def generate_tokens_probs_and_return_on_first_stage(
                                    top_p=top_p,
                                    temperature=temperature,
                                    vocab_size=tokenizer.vocab_size)

                # If a prompt length is smaller than or equal to the current context
                # length, it means we have started generating tokens
                started = lengths <= context_length

@@ -281,6 +282,74 @@ def generate_tokens_probs_and_return_on_first_stage(
    return tokens, generated_sequence_lengths, output_log_probs

def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size):
    args = get_args()
    tokenizer = get_tokenizer()

    batch_size = tokens.size(0)
    assert(batch_size == 1)
    prompt_length = lengths.item()
    final_sequence_length = tokens.size(1)
    final_sequence_length = min(final_sequence_length, args.max_position_embeddings)

    # If the prompt already fills the maximum sequence length, there is no
    # room left to generate any tokens.
    if prompt_length >= final_sequence_length:
        raise ValueError("context length + tokens_to_generate too large")

    # forward step.
    forward_step = ForwardStep(model, beam_size, final_sequence_length)

    if mpu.is_pipeline_last_stage():
        scores = torch.zeros(beam_size,
                             dtype=torch.float32,
                             device=torch.cuda.current_device()).unsqueeze(1)

    # =============
    # Run inference
    # =============
    with torch.no_grad():
        tokens = tokens.repeat(beam_size, 1)
        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
        prev_context_length = 0
        for context_length in range(prompt_length, final_sequence_length):

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = forward_step(tokens2use, positions2use, attention_mask2use)

            vocab_size = logits.size(2)
            if mpu.is_pipeline_last_stage():
                log_probs = F.log_softmax(logits, dim=2)
                new_scores = log_probs[:, -1, :] + scores

                if context_length == prompt_length:  # if this is the first one
                    sorted_scores, indices = torch.sort(new_scores[0, :], descending=True)
                else:
                    sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)

                best_batches = torch.div(indices[:beam_size], vocab_size, rounding_mode='floor')
                best_words = indices[:beam_size] % vocab_size

                tokens = tokens[best_batches, :]
                tokens[:, context_length] = best_words
                scores = sorted_scores[:beam_size].unsqueeze(1)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(beam_size, torch.int64,
                                                   tokens[:, context_length])

            # Update the context length for the next token generation.
            prev_context_length = context_length

        copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
                                               scores[:, 0])

    return tokens, scores


def _build_attention_mask_and_position_ids(tokens):
    """Build the attention mask and position ids for the input tokens."""
...
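
The beam bookkeeping in beam_search_and_return_on_first_stage can be exercised in isolation. Below is a small, self-contained sketch of one re-ranking step with toy sizes and random logits (the first-step special case, where all beams are still identical, is omitted); it mirrors the arithmetic in the loop above but is not tied to the model or pipeline code.

    # Toy illustration of the per-step top-k beam selection.
    import torch
    import torch.nn.functional as F

    beam_size, vocab_size = 4, 10
    scores = torch.zeros(beam_size, 1)               # running log-prob of each beam
    logits = torch.randn(beam_size, vocab_size)      # last-position logits per beam
    new_scores = F.log_softmax(logits, dim=-1) + scores

    # Flatten the (beam, token) grid, sort by joint score, keep the best beam_size.
    sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
    best_beams = torch.div(indices[:beam_size], vocab_size, rounding_mode='floor')
    best_words = indices[:beam_size] % vocab_size
    scores = sorted_scores[:beam_size].unsqueeze(1)  # carried into the next step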
@@ -128,6 +128,12 @@ class MegatronGenerate(Resource):
            if not isinstance(no_log, bool):
                return "no_log must be a boolean value"

        beam_search = False
        if "beam_search" in request.get_json():
            beam_search = request.get_json()["beam_search"]
            if not isinstance(beam_search, bool):
                return "beam_search must be a boolean value"

        with lock:  # Need to get lock to keep multiple threads from hitting code
            if not no_log:
                print("request IP: " + str(request.remote_addr))
...
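
On the client side, enabling the new path is just another boolean field in the JSON request body. A hedged example follows; the host, port, route, and HTTP method are assumptions about how the server is typically mounted and are not shown in this diff, so adjust them to your deployment.

    # Illustrative client call (endpoint details are assumed, not from this commit).
    import json
    import requests

    payload = {"prompts": ["Deep learning is"],
               "tokens_to_generate": 32,
               "beam_search": True}
    response = requests.put("http://localhost:5000/api",
                            data=json.dumps(payload),
                            headers={"Content-Type": "application/json"})
    print(response.json())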
@@ -28,6 +28,7 @@ from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_server import MegatronServer
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
import torch


def model_provider(pre_process=True, post_process=True):
@@ -82,3 +83,8 @@ if __name__ == "__main__":
                generate_and_post_process(model)
            except ValueError as ve:
                pass
        elif choice[0].item() == 1:
            try:
                beam_search_and_post_process(model)
            except ValueError as ve:
                pass