Commit e18f786c authored by Lysandre's avatar Lysandre
Browse files

Quickstart example showcasing past

parent 155c782a
......@@ -188,3 +188,35 @@ assert predicted_text == 'Who was Jim Henson? Jim Henson was a man'
```
Examples for each model class of each model architecture (Bert, GPT, GPT-2, Transformer-XL, XLNet and XLM) can be found in the [documentation](#documentation).
#### Using the past
GPT-2 as well as some other models (GPT, XLNet, Transfo-XL, CTRL) make use of a `past` or `mems` attribute which can be used to prevent re-computing the key/value pairs when using sequential decoding. It is useful when generating sequences as a big part of the attention mechanism benefits from previous computations.
Here is a fully-working example using the `past` with `GPT2LMHeadModel` and argmax decoding (which should only be used as an example, as argmax decoding introduces a lot of repetition):
```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2')
generated = tokenizer.encode("The Manhattan bridge")
context = torch.tensor([generated])
past = None
for i in range(100):
print(i)
output, past = model(context, past=past)
token = torch.argmax(output[0, :])
generated += [token.tolist()]
context = token.unsqueeze(0)
sequence = tokenizer.decode(generated)
print(sequence)
```
The model only requires a single token as input as all the previous tokens' key/value pairs are contained in the `past`.
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment