Unverified commit 6b35cfd2 authored by Thomas Wolf, committed by GitHub

Merge pull request #423 from dhanajitb/master

making unconditional generation work
parents 7e7e4753 aff44f0c
@@ -91,16 +91,33 @@ def run_model():
                 print('Prompt should not be empty!')
                 raw_text = input("Model prompt >>> ")
             context_tokens = enc.encode(raw_text)
+            generated = 0
+            for _ in range(args.nsamples // args.batch_size):
+                out = sample_sequence(
+                    model=model, length=args.length,
+                    context=context_tokens,
+                    start_token=None,
+                    batch_size=args.batch_size,
+                    temperature=args.temperature, top_k=args.top_k, device=device
+                )
+                out = out[:, len(context_tokens):].tolist()
+                for i in range(args.batch_size):
+                    generated += 1
+                    text = enc.decode(out[i])
+                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
+                    print(text)
+            print("=" * 80)
+        if args.unconditional:
             generated = 0
             for _ in range(args.nsamples // args.batch_size):
                 out = sample_sequence(
                     model=model, length=args.length,
-                    context=context_tokens if not args.unconditional else None,
-                    start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None,
+                    context=None,
+                    start_token=enc.encoder['<|endoftext|>'],
                     batch_size=args.batch_size,
                     temperature=args.temperature, top_k=args.top_k, device=device
                 )
-                out = out[:, len(context_tokens):].tolist()
+                out = out[:,1:].tolist()
                 for i in range(args.batch_size):
                     generated += 1
                     text = enc.decode(out[i])
@@ -113,3 +130,4 @@ def run_model():
 if __name__ == '__main__':
     run_model()
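
Net effect of the change: generation now has two explicit paths. The conditional loop keeps context=context_tokens with no start token and strips the echoed prompt via out[:, len(context_tokens):], while the new if args.unconditional: branch passes context=None, seeds sampling with the <|endoftext|> start token, and strips only that single seed token via out[:,1:]. Below is a minimal sketch (not part of the commit) of why the slice changes; the tensor stands in for sample_sequence's output, and token id 50256 is assumed to be <|endoftext|> in the standard GPT-2 vocabulary.

    import torch

    batch_size, length = 2, 5
    start_token = 50256  # assumed <|endoftext|> id in the standard GPT-2 vocab

    # Stand-in for sample_sequence's unconditional output: each row begins
    # with the seed <|endoftext|> token, followed by `length` sampled ids.
    out = torch.cat([
        torch.full((batch_size, 1), start_token, dtype=torch.long),
        torch.randint(0, 50256, (batch_size, length)),
    ], dim=1)

    # Unconditional path: drop only the seed token at position 0.
    # (The conditional path instead drops the prompt tokens:
    # out[:, len(context_tokens):].)
    print(out[:, 1:].tolist())

With the branch in place, passing the --unconditional flag (e.g. python run_gpt2.py --unconditional --nsamples 2 --top_k 40, assuming the example script is named run_gpt2.py) samples from <|endoftext|> without prompting for input.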