"awq/git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "4d49ae1f75e666aa930f35e81f16d33c92543696"
Commit 6be75e2a authored by rprenger

Fixing beam search in distributed mode

parent fd176a90
@@ -177,7 +177,7 @@ def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=
               stop_token,
               num_return_gen,
               length_penalty]
-    values_float_tensor = broadcast_float_list(3, float_list=values)
+    values_float_tensor = broadcast_float_list(6, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     beam_size = int(values_float_tensor[1].item())
     add_BOS = bool(values_float_tensor[2].item())
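The values list built just above this call has six entries (tokens_to_generate, beam_size, add_BOS, stop_token, num_return_gen, length_penalty), so the broadcast size has to be 6: with 3, ranks other than the source rank never received the stop token, the number of returned generations, or the length penalty. A minimal sketch of what a helper like broadcast_float_list does, assuming torch.distributed is already initialized; the src argument and the zero padding here are assumptions of this sketch, not necessarily Megatron's exact implementation:

import torch
import torch.distributed as dist

def broadcast_float_list_sketch(size, float_list=None, src=0):
    """Pack the parameters into a fixed-size float32 tensor on the source
    rank and broadcast it so every rank decodes identical values."""
    buf = torch.zeros(size, dtype=torch.float32,
                      device=torch.cuda.current_device())
    if dist.get_rank() == src and float_list is not None:
        buf[:len(float_list)] = torch.tensor(float_list, dtype=torch.float32,
                                             device=buf.device)
    dist.broadcast(buf, src=src)
    return buf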
......
@@ -347,10 +347,9 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
     beam_hyp = BeamHypotheses(beam_size, length_penalty)
     done = False
-    if mpu.is_pipeline_last_stage():
-        scores = torch.zeros(beam_size,
-                             dtype=torch.float32,
-                             device=torch.cuda.current_device()).unsqueeze(1)
+    scores = torch.zeros(beam_size,
+                         dtype=torch.float32,
+                         device=torch.cuda.current_device()).unsqueeze(1)
     # =============
     # Run infernece
     # =============
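Dropping the is_pipeline_last_stage() guard means every pipeline stage now allocates the scores buffer, not only the stage that computes it. In distributed mode a rank can only receive a broadcast or copied tensor into a buffer it actually holds, and with the old guard the non-last stages had no scores variable at all. A minimal sketch of that receive-side pattern, assuming an initialized torch.distributed process group (the function name and src_rank argument are illustrative):

import torch
import torch.distributed as dist

def receive_beam_scores(beam_size, src_rank):
    """Pre-allocate the buffer on every rank, then fill it with the values
    broadcast from the pipeline stage that actually computed them."""
    scores = torch.zeros(beam_size, 1, dtype=torch.float32,
                         device=torch.cuda.current_device())
    dist.broadcast(scores, src=src_rank)  # requires the buffer on every rank
    return scores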
@@ -368,9 +367,9 @@
         # logits will be meanigful only in the last pipeline stage.
         logits = forward_step(tokens2use, positions2use, attention_mask2use)
-        vocab_size = logits.size(2)
         if mpu.is_pipeline_last_stage():
+            vocab_size = logits.size(2)
             log_probs = F.log_softmax(logits, dim=2)
             new_scores = log_probs[:, -1, :] + scores
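Moving vocab_size inside the last-stage branch keeps the call on the only rank where logits is a real tensor; on other pipeline stages the forward pass returns nothing meaningful here, so indexing its output unconditionally can fail. A toy illustration of the guarded pattern with stand-in arguments (is_last_stage stands in for mpu.is_pipeline_last_stage()):

import torch.nn.functional as F

def last_stage_log_probs(logits, is_last_stage):
    """Only the stage that owns real [batch, seq, vocab] logits inspects
    and normalizes them; every other stage simply skips this work."""
    if not is_last_stage:              # logits may be None or garbage here
        return None, None
    vocab_size = logits.size(2)        # safe: a real tensor on this stage
    return vocab_size, F.log_softmax(logits, dim=2)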
......
@@ -24,7 +24,7 @@ from megatron.text_generation import beam_search_and_post_process
 GENERATE_NUM = 0
-BEAM_NUM = 0
+BEAM_NUM = 1
 lock = threading.Lock()
 class MegatronGenerate(Resource):
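GENERATE_NUM and BEAM_NUM are the request ids the server rank uses to tell the other ranks which generation routine to run; with both set to 0, a beam-search request was indistinguishable from an ordinary generate request, so the worker ranks could never be steered into the beam-search path. A minimal sketch of that dispatch, assuming an initialized torch.distributed process group (the helper name send_request_id is illustrative, not Megatron's API):

import torch
import torch.distributed as dist

GENERATE_NUM = 0  # request id for sampling / greedy generation
BEAM_NUM = 1      # request id for beam search; must differ from GENERATE_NUM

def send_request_id(request_id):
    """Rank 0 broadcasts a single integer so every other rank knows which
    text-generation routine the incoming HTTP request maps to."""
    choice = torch.tensor([request_id], dtype=torch.long,
                          device=torch.cuda.current_device())
    dist.broadcast(choice, src=0)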
......