Commit 2e65bee5 authored by peng xu

remove debug lines for printing

parent b8428a7f
@@ -323,10 +323,6 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
             # logits will be meanigful only in the last pipeline stage.
             logits = forward_step(tokens2use, positions2use, attention_mask2use)
-            # if mpu.is_pipeline_first_stage():
-            #     print('-' * 40)
-            #     print(tokens[:, context_length-5:context_length+5])
-            #     print(context_length)
             if mpu.is_pipeline_last_stage():
                 vocab_size = logits.size(2)
@@ -341,10 +337,6 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
                 best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
                 best_words = indices[:2 * beam_size] % vocab_size
                 best_scores = sorted_scores[: 2 * beam_size]
-                # print('*' * 40)
-                # print(best_beam_ids)
-                # print(best_words)
-                # print(context_length)
                 next_beams = []
                 for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
@@ -369,7 +361,6 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
                 if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
                     done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())
-                    print("find all hyp exiting")
                 best_batches = tokens.new([item[2] for item in next_beams])
                 tokens = tokens[best_batches,:]
@@ -379,7 +370,6 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
             # torch.distributed.barrier()
             done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
             if done:
-                print("break for loop")
                 break
             # Update the tokens on the first stage so the next input to
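The index arithmetic in the second hunk is the heart of the beam step: the top 2 * beam_size scores are taken over a tensor flattened across (beam, vocab), so integer division by vocab_size recovers the parent beam and the modulo recovers the word. A minimal runnable sketch with toy shapes (the scores and sizes below are assumptions, not the model's actual tensors):

import torch

beam_size, vocab_size = 2, 5
# One row of scores per live beam, one column per vocabulary token.
scores = torch.tensor([[0.1, 0.9, 0.2, 0.05, 0.3],
                       [0.4, 0.1, 0.8, 0.6, 0.2]])
sorted_scores, indices = torch.sort(scores.view(-1), descending=True)

# Same arithmetic as the diff: divide to find the beam, modulo to find the word.
best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
best_words = indices[: 2 * beam_size] % vocab_size
best_scores = sorted_scores[: 2 * beam_size]

print(best_beam_ids)  # tensor([0, 1, 1, 1])
print(best_words)     # tensor([1, 2, 3, 0])

Keeping 2 * beam_size candidates rather than beam_size leaves headroom for candidates that finish a hypothesis and drop out of the live beam set.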
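After ranking, the third and fourth hunks re-gather the token buffer by parent beam id (tokens = tokens[best_batches, :]) and broadcast the done flag from the last pipeline stage so every stage leaves the generation loop together. A toy sketch of the re-gather step, with hypothetical values (the next_beams triples follow the (token_id, beam_score, beam_id) layout visible in the diff):

import torch

# Two beams, three generated tokens each; position 3 is about to be filled.
tokens = torch.tensor([[11, 12, 13, 0],
                       [21, 22, 23, 0]])
context_length = 3
# Both survivors happen to descend from beam 1.
next_beams = [(31, -0.5, 1), (32, -0.7, 1)]

best_batches = tokens.new([item[2] for item in next_beams])
tokens = tokens[best_batches, :]  # copy each surviving beam's parent prefix
tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
print(tokens)
# tensor([[21, 22, 23, 31],
#         [21, 22, 23, 32]])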