# generation_xlnet.py — sample a text continuation token-by-token from a pretrained XLNet LM head.
import logging

import torch
from torch.nn import functional as F

from pytorch_transformers import XLNetModel, XLNetLMHeadModel, XLNetTokenizer

logging.basicConfig(level=logging.INFO)

# Stop once the sequence (prompt + sampled tokens) reaches this many tokens.
MAX_LENGTH = 20

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
# attn_type='uni' requests unidirectional (Transformer-XL style) attention
# instead of XLNet's bidirectional 'bi' mode (per the XLNetConfig docs).
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased', attn_type='uni')
# Make inference mode explicit (dropout off) rather than relying on from_pretrained.
model.eval()

tokens = tokenizer.encode('I am very happy')

# Inference only: no gradients are needed, so don't build an autograd graph
# on every sampling step.
with torch.no_grad():
    for i in range(len(tokens), MAX_LENGTH):
        # One token is appended per iteration, so i == len(tokens) here. The
        # model input below therefore has i + 1 positions: the current tokens
        # followed by a placeholder id 0 whose value is to be predicted.
        # mask has shape (1, 1, i + 1) and is 1 only at that last position,
        # i.e. a single prediction target located at the placeholder.
        mask = torch.tensor([[[0.0] * i + [1.0]]])
        # NOTE: an earlier experiment also passed
        #   perm_mask=mask.expand(-1, i + 1, -1)
        # which, under the library's perm_mask convention, would stop every
        # position from attending to the placeholder. It is left disabled.
        outputs = model(torch.tensor([tokens + [0]]),
                        target_mapping=mask,
                        # (1, i + 1) flag marking the query-stream position.
                        # NOTE(review): inp_q belongs to the pytorch_transformers
                        # revision this script was written against — confirm the
                        # installed version still accepts it.
                        inp_q=mask.squeeze(1))
        # The model returns a tuple whose length depends on configuration;
        # the logits come first, so index instead of unpacking a fixed arity.
        logits = outputs[0]
        # logits[0, 0] holds vocabulary scores for batch item 0's only
        # prediction target; normalize over that last axis and sample one id.
        probs = F.softmax(logits[0, 0, :], dim=-1)
        output = torch.multinomial(probs, 1)
        tokens.append(output.item())
        print(tokenizer.decode(tokens))