Commit 323e75c4 authored by Raul Puri

Update generate_samples.py

parent 29d10a36
@@ -366,12 +366,13 @@ def switch(val1, val2, boolean):
     return (1-boolean)*val1 + boolean*val2
 
 def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
-    if isinstance(model, DDP):
-        model = model.module
-    if isinstance(model, FP16_Module):
-        model = model.module
-    original_output_parallel = model.parallel_output
-    model.parallel_output = False
+    actual_model = model
+    if isinstance(actual_model, DDP):
+        actual_model = actual_model.module
+    if isinstance(actual_model, FP16_Module):
+        actual_model = actual_model.module
+    original_output_parallel = actual_model.parallel_output
+    actual_model.parallel_output = False
     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
@@ -438,7 +439,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
             yield tokens, lengths
             if done:
                 break
-    model.parallel_output = original_output_parallel
+    actual_model.parallel_output = original_output_parallel
 
 def prepare_tokenizer(args):
...
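
The change fixes wrapper handling in sample_sequence_batch: the old code rebound the name model to the unwrapped inner module, so the later model.eval() (and any forward calls through model) bypassed the DDP and FP16_Module wrappers. The new code keeps the wrapped model for evaluation and generation, and uses a separate actual_model reference only to read, flip, and restore parallel_output on the bare module. Below is a minimal, self-contained sketch of that same pattern; it is not code from this commit, and ToyLM, unwrap, and the use of nn.DataParallel as a stand-in for the DDP / FP16_Module wrappers are assumptions made for the example.

    import torch
    import torch.nn as nn

    class ToyLM(nn.Module):
        """Stand-in for the transformer; parallel_output mirrors the Megatron flag."""
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)
            self.parallel_output = True

        def forward(self, x):
            return self.linear(x)

    def unwrap(model, wrapper_types):
        """Walk .module down through known wrapper classes, leaving the outer handle intact."""
        actual_model = model
        while isinstance(actual_model, wrapper_types):
            actual_model = actual_model.module
        return actual_model

    # nn.DataParallel stands in here for the DDP / FP16_Module wrappers in the real script.
    model = nn.DataParallel(ToyLM())
    actual_model = unwrap(model, (nn.DataParallel,))

    original_output_parallel = actual_model.parallel_output
    actual_model.parallel_output = False      # flip the flag on the bare module...
    model.eval()                              # ...but keep calling the wrapped model
    with torch.no_grad():
        _ = model(torch.randn(2, 4))
    actual_model.parallel_output = original_output_parallel  # restore on exit

The point of the pattern is that the attribute state lives on the bare module while execution still goes through the wrapper; on a CPU-only machine nn.DataParallel simply forwards to the inner module, so the sketch runs as-is.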