Commit b9b6fe0d authored by Raul Puri

force output gathering

parent f223ff67
@@ -362,6 +362,12 @@ def switch(val1, val2, boolean):
     return (1-boolean)*val1 + boolean*val2
 
 def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
+    if isinstance(model, DDP):
+        model = model.module
+    if isinstance(model, FP16_Module):
+        model = model.module
+    original_output_parallel = model.parallel_output
+    model.parallel_output = False
     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
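
The hunk above unwraps the DDP and FP16_Module containers and temporarily forces `parallel_output = False`, so the vocab-parallel logits are gathered across model-parallel ranks before sampling; the original flag is restored at the end of the generator (third hunk below). A minimal sketch of that unwrap-and-restore pattern, using illustrative helper names (`unwrap_model`, `forced_attr`) that are not part of this repository:

```python
from contextlib import contextmanager

def unwrap_model(model, wrapper_types):
    # Peel off wrapper modules (e.g. DDP, FP16_Module) until the bare model remains.
    while isinstance(model, wrapper_types):
        model = model.module
    return model

@contextmanager
def forced_attr(obj, name, value):
    # Temporarily set obj.<name> to `value`, restoring the original on exit.
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        setattr(obj, name, original)

# Hypothetical usage: gather logits only for the duration of sampling.
# model = unwrap_model(model, (DDP, FP16_Module))
# with forced_attr(model, 'parallel_output', False):
#     ...  # run sample_sequence_batch
```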
@@ -404,6 +410,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
             if args.greedy:
                 prev = torch.argmax(logits, dim=-1).view(-1)
             else:
+                logits = logits.float()
                 logits /= args.temperature
                 logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
                 log_probs = F.softmax(logits, dim=-1)
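
The added `logits = logits.float()` promotes half-precision logits to float32 before the temperature division, top-k/top-p filtering, and multinomial draw, avoiding fp16 overflow and precision loss in the softmax. A rough, self-contained sketch of that sampling step (top-k only; the repository's `top_k_logits` also handles top-p):

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, top_k=0):
    # Work in float32: fp16 logits can overflow or lose precision once divided
    # by a small temperature and pushed through softmax.
    logits = logits.float() / temperature
    if top_k > 0:
        # Keep only the k largest logits per row; mask the rest to -inf.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float('-inf'))
    probs = F.softmax(logits, dim=-1)
    # Draw one token id per batch row from the filtered distribution.
    return torch.multinomial(probs, num_samples=1).view(-1)

# Example: logits of shape [batch, vocab] -> tensor of [batch] sampled token ids.
```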
@@ -427,6 +434,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
             yield tokens, lengths
             if done:
                 break
+    model.parallel_output = original_output_parallel
 
 def prepare_tokenizer(args):
...
@@ -275,6 +275,7 @@ def vocab_size_with_padding(num_tokens, args):
     after = num_tokens
     multiple = args.make_vocab_size_divisible_by * \
         mpu.get_model_parallel_world_size()
-    while (after % multiple) != 0:
-        after += 1
+    if multiple > 0:
+        while (after % multiple) != 0:
+            after += 1
     print_rank_0('> padded vocab (size: {}) with {} dummy '
...
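
The `if multiple > 0:` guard above skips the padding loop when `make_vocab_size_divisible_by` (or the model-parallel world size) is zero, which would otherwise raise a ZeroDivisionError in the `after % multiple` check. A standalone version of the padding arithmetic for reference; `pad_vocab_size` is an illustrative name, not the repository's function:

```python
def pad_vocab_size(num_tokens, divisible_by, model_parallel_size):
    # Round the vocabulary size up to the next multiple of
    # divisible_by * model_parallel_size so the embedding table
    # splits evenly across model-parallel ranks.
    multiple = divisible_by * model_parallel_size
    after = num_tokens
    if multiple > 0:
        while after % multiple != 0:
            after += 1
    return after

# Example: GPT-2's 50257 tokens with divisible_by=128 and 2-way model
# parallelism pad up to 50432, a multiple of 256.
print(pad_vocab_size(50257, 128, 2))
```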