Commit 4ad5717e authored by Raul Puri

Merge branch 'model_parallel_generation' into 'master'

force output gathering

See merge request ADLR/megatron-lm!17
parents 2d76d065 7a6d630e
@@ -362,6 +362,12 @@ def switch(val1, val2, boolean):
     return (1-boolean)*val1 + boolean*val2
 
 def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
+    if isinstance(model, DDP):
+        model = model.module
+    if isinstance(model, FP16_Module):
+        model = model.module
+    original_output_parallel = model.parallel_output
+    model.parallel_output = False
     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
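This first hunk unwraps the `DDP` and `FP16_Module` containers to reach the underlying model, then records the current `parallel_output` flag and switches it off so the language-model head returns logits over the full vocabulary rather than each rank's shard; sampling can then run identically on every model-parallel rank. The sketch below illustrates the kind of gather that flag stands in for, written against plain `torch.distributed` rather than Megatron's own parallel utilities, so the process-group argument and shard layout are assumptions.

```python
import torch
import torch.distributed as dist


def gather_vocab_parallel_logits(logits_shard, group=None):
    """Concatenate per-rank vocabulary shards into full-vocabulary logits.

    Illustrative only: when parallel_output is False, Megatron performs an
    equivalent gather through its own model-parallel utilities; the `group`
    handle and the [batch, seq, vocab/world_size] shard layout assumed here
    are not taken from this merge request.
    """
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return logits_shard
    shards = [torch.empty_like(logits_shard) for _ in range(world_size)]
    dist.all_gather(shards, logits_shard.contiguous(), group=group)
    # Concatenate along the vocabulary dimension so every rank can sample
    # from the complete distribution.
    return torch.cat(shards, dim=-1)
```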
@@ -404,6 +410,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
             if args.greedy:
                 prev = torch.argmax(logits, dim=-1).view(-1)
             else:
+                logits = logits.float()
                 logits /= args.temperature
                 logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
                 log_probs = F.softmax(logits, dim=-1)
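Casting the logits to float32 before temperature scaling keeps the subsequent `top_k_logits` filtering and the softmax numerically stable when the model itself runs in FP16. The merge request does not show `top_k_logits`; the sketch below is a generic top-k/top-p (nucleus) filter of the kind that name suggests, so its signature and `filter_value` default are assumptions rather than Megatron's actual implementation.

```python
import torch
import torch.nn.functional as F


def top_k_top_p_filter(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Mask logits outside the top-k set and/or the top-p nucleus.

    A generic reference implementation; Megatron's top_k_logits may differ
    in argument names and in whether it modifies the logits in place.
    """
    if top_k > 0:
        # Remove every token whose logit falls below the k-th largest one.
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, filter_value)
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once the cumulative probability exceeds top_p, always
        # keeping at least the single most likely token.
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
        logits = logits.masked_fill(mask, filter_value)
    return logits
```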
@@ -427,6 +434,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
                 yield tokens, lengths
                 if done:
                     break
+    model.parallel_output = original_output_parallel
 
 
 def prepare_tokenizer(args):
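The last hunk restores the caller's original `parallel_output` value once sampling completes. Because `sample_sequence_batch` is a generator, the flag would stay flipped if the consumer stopped iterating early or an exception escaped the loop; a context-manager wrapper such as the hypothetical sketch below restores it unconditionally. The helper name is an illustration and not part of this merge request.

```python
from contextlib import contextmanager


@contextmanager
def gathered_output(model):
    """Temporarily force gathered (full-vocabulary) logits, then restore.

    Hypothetical helper: the merge request saves and restores the flag
    inline instead of using a context manager.
    """
    original = model.parallel_output
    model.parallel_output = False
    try:
        yield model
    finally:
        # Runs even if the caller abandons the sampling generator early.
        model.parallel_output = original
```

Wrapping the sampling loop in `with gathered_output(model):` would then make the restore automatic regardless of how the generator exits.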