Commit 2e65bee5 authored by peng xu's avatar peng xu
Browse files

remove debug lines for printing

parent b8428a7f
......@@ -323,10 +323,6 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
# logits will be meanigful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
# if mpu.is_pipeline_first_stage():
# print('-' * 40)
# print(tokens[:, context_length-5:context_length+5])
# print(context_length)
if mpu.is_pipeline_last_stage():
vocab_size = logits.size(2)
......@@ -341,10 +337,6 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
best_words = indices[:2 * beam_size] % vocab_size
best_scores = sorted_scores[: 2 * beam_size]
# print('*' * 40)
# print(best_beam_ids)
# print(best_words)
# print(context_length)
next_beams = []
for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
......@@ -369,7 +361,6 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())
print("find all hyp exiting")
best_batches = tokens.new([item[2] for item in next_beams])
tokens = tokens[best_batches,:]
......@@ -379,7 +370,6 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
# torch.distributed.barrier()
done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
if done:
print("break for loop")
break
# Update the tokens on the first stage so the next input to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment