Commit 6be75e2a authored by rprenger

Fixing beam search in distributed mode

parent fd176a90
@@ -177,7 +177,7 @@ def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=
               stop_token,
               num_return_gen,
               length_penalty]
-    values_float_tensor = broadcast_float_list(3, float_list=values)
+    values_float_tensor = broadcast_float_list(6, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     beam_size = int(values_float_tensor[1].item())
     add_BOS = bool(values_float_tensor[2].item())
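The `values` list above carries six floats, but only the first three were being broadcast, so ranks other than the source read stale or uninitialized values for `stop_token`, `num_return_gen`, and `length_penalty`. A minimal sketch of the contract such a broadcast helper has to satisfy, under the assumption (not confirmed by this diff) that it packs the list into a tensor on the source rank and receives into a same-sized buffer everywhere else:

```python
# Sketch only; the real broadcast_float_list lives in Megatron's
# communication utilities and may differ in detail.
# Launch with e.g. `torchrun --nproc_per_node=2` (gloo/CPU for brevity).
import torch
import torch.distributed as dist

def broadcast_float_list(size, float_list=None, src=0):
    if dist.get_rank() == src:
        # The source rank packs the actual parameter values.
        tensor = torch.tensor(float_list, dtype=torch.float32)
    else:
        # Every other rank allocates a receive buffer of the advertised
        # size. If `size` understates the list length (3 vs. 6 in this
        # commit), the buffers disagree across ranks and the trailing
        # parameters never arrive intact.
        tensor = torch.empty(size, dtype=torch.float32)
    dist.broadcast(tensor, src=src)
    return tensor

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    values = [10.0, 4.0, 0.0, 2.0, 4.0, 1.0]  # illustrative values only
    out = broadcast_float_list(6, float_list=values)
    print(dist.get_rank(), out.tolist())
```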
@@ -347,10 +347,9 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
     beam_hyp = BeamHypotheses(beam_size, length_penalty)
     done = False
-    if mpu.is_pipeline_last_stage():
-        scores = torch.zeros(beam_size,
-                             dtype=torch.float32,
-                             device=torch.cuda.current_device()).unsqueeze(1)
+    scores = torch.zeros(beam_size,
+                         dtype=torch.float32,
+                         device=torch.cuda.current_device()).unsqueeze(1)
     # =============
     # Run infernece
     # =============
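`scores` is now allocated on every pipeline stage instead of only the last one. `torch.distributed.broadcast` writes in place, so a rank can only receive into a tensor it already owns; presumably the non-last stages need `scores` as exactly such a receive buffer when the last stage shares the updated beam scores. A runnable toy illustration of that requirement (gloo/CPU and `torchrun --nproc_per_node=2` assumed here, unlike the CUDA tensors in the real code):

```python
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="gloo")
    last = dist.get_world_size() - 1
    beam_size = 4
    # Every rank allocates the buffer, exactly as the fixed code does.
    scores = torch.zeros(beam_size, dtype=torch.float32).unsqueeze(1)
    if dist.get_rank() == last:
        scores += 1.0  # stand-in for the last stage computing real scores
    # In-place broadcast: without the allocation above, non-last ranks
    # would have no `scores` tensor to receive into.
    dist.broadcast(scores, src=last)
    print(dist.get_rank(), scores.flatten().tolist())

if __name__ == "__main__":
    main()
```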
@@ -368,9 +367,9 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
         # logits will be meanigful only in the last pipeline stage.
         logits = forward_step(tokens2use, positions2use, attention_mask2use)
+        vocab_size = logits.size(2)
         if mpu.is_pipeline_last_stage():
-            vocab_size = logits.size(2)
             log_probs = F.log_softmax(logits, dim=2)
             new_scores = log_probs[:, -1, :] + scores
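Hoisting `vocab_size` out of the last-stage branch makes it available on every rank. A common reason in distributed beam search, and plausibly the one here, is that the winning candidates are picked over the flattened `(beam, vocab)` score axis and their flat indices are then shared across stages, so each rank needs `vocab_size` to split an index back into a beam id and a token id. An illustrative helper (names are mine, not Megatron's):

```python
import torch

def split_flat_indices(flat_indices: torch.Tensor, vocab_size: int):
    # A flat index into a (beam_size * vocab_size) score vector encodes
    # both which beam it extends and which token it appends.
    beam_ids = torch.div(flat_indices, vocab_size, rounding_mode="floor")
    token_ids = flat_indices % vocab_size
    return beam_ids, token_ids

# e.g. with vocab_size=50257, flat index 100515 -> beam 2, token 1
print(split_flat_indices(torch.tensor([100515]), 50257))
```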
@@ -24,7 +24,7 @@ from megatron.text_generation import beam_search_and_post_process
 GENERATE_NUM = 0
-BEAM_NUM = 0
+BEAM_NUM = 1
 lock = threading.Lock()
 class MegatronGenerate(Resource):
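`GENERATE_NUM` and `BEAM_NUM` read like op ids that the serving rank broadcasts so the other model-parallel ranks know which inference routine to enter; with both set to 0, a beam-search request was indistinguishable from an ordinary generate request on those ranks. A hedged sketch of that dispatch pattern (the loop below is illustrative, not Megatron's actual server loop):

```python
import torch
import torch.distributed as dist

GENERATE_NUM = 0
BEAM_NUM = 1  # must differ from GENERATE_NUM for the branch to matter

def worker_loop(run_generate, run_beam_search):
    # Non-server ranks block here; rank 0 announces each request's type.
    while True:
        choice = torch.zeros(1, dtype=torch.long)
        dist.broadcast(choice, src=0)
        if choice.item() == GENERATE_NUM:
            run_generate()
        elif choice.item() == BEAM_NUM:
            run_beam_search()
```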