"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "b085e06b0159c08f74d348fad3ea0a27e6a45a5e"
Unverified Commit 8407429d authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #494 from SudoSharma/patch-1

Fix indentation for unconditional generation
parents 2e153930 07154dad
...@@ -107,25 +107,25 @@ def run_model(): ...@@ -107,25 +107,25 @@ def run_model():
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text) print(text)
print("=" * 80) print("=" * 80)
if args.unconditional: if args.unconditional:
generated = 0 generated = 0
for _ in range(args.nsamples // args.batch_size): for _ in range(args.nsamples // args.batch_size):
out = sample_sequence( out = sample_sequence(
model=model, length=args.length, model=model, length=args.length,
context=None, context=None,
start_token=enc.encoder['<|endoftext|>'], start_token=enc.encoder['<|endoftext|>'],
batch_size=args.batch_size, batch_size=args.batch_size,
temperature=args.temperature, top_k=args.top_k, device=device temperature=args.temperature, top_k=args.top_k, device=device
) )
out = out[:,1:].tolist() out = out[:,1:].tolist()
for i in range(args.batch_size): for i in range(args.batch_size):
generated += 1 generated += 1
text = enc.decode(out[i]) text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
print(text) print(text)
print("=" * 80) print("=" * 80)
if args.unconditional: if args.unconditional:
break break
if __name__ == '__main__': if __name__ == '__main__':
run_model() run_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment