Commit 87c8fca9 authored by patrickvonplaten's avatar patrickvonplaten
Browse files

add example for ctrl text generation in docs

parent 88def24c
...@@ -624,6 +624,14 @@ class PreTrainedModel(nn.Module): ...@@ -624,6 +624,14 @@ class PreTrainedModel(nn.Module):
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3) # generate sequences using beam search decoding (3 beams) outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3) # generate sequences using beam search decoding (3 beams)
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences using using greedy search
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
""" """
# We cannot generate if the model does not have a LM head # We cannot generate if the model does not have a LM head
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment