Commit 323e75c4 authored by Raul Puri's avatar Raul Puri
Browse files

Update generate_samples.py

parent 29d10a36
......@@ -366,12 +366,13 @@ def switch(val1, val2, boolean):
return (1-boolean)*val1 + boolean*val2
def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
if isinstance(model, DDP):
model = model.module
if isinstance(model, FP16_Module):
model = model.module
original_output_parallel = model.parallel_output
model.parallel_output = False
actual_model = model
if isinstance(actual_model, DDP):
actual_model = actual_model.module
if isinstance(actual_model, FP16_Module):
actual_model = actual_model.module
original_output_parallel = actual_model.parallel_output
actual_model.parallel_output = False
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
......@@ -438,7 +439,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
yield tokens, lengths
if done:
break
model.parallel_output = original_output_parallel
actual_model.parallel_output = original_output_parallel
def prepare_tokenizer(args):
......
Markdown is supported
0% — or attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment