run_gpt2_interactive_conditional_samples.py

#!/usr/bin/env python3
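"""Interactively sample conditional completions from a pretrained GPT-2 model.

Example invocation (all flags are defined in interact_model below):

    python run_gpt2_interactive_conditional_samples.py --model_name_or_path gpt2 --top_k 40
"""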

import argparse
import logging
from tqdm import trange

import torch
import torch.nn.functional as F
import numpy as np

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

def top_k_logits(logits, k):
    """Mask all but the k highest logits so sampling only draws from the top-k tokens."""
    if k == 0:
        # k == 0 disables top-k filtering
        return logits
    values, _ = torch.topk(logits, k)
    # Smallest logit among the top k, per row; keep a trailing dim so it broadcasts over the vocab
    min_values = values[:, -1].unsqueeze(1)
    return torch.where(logits < min_values, torch.full_like(logits, -1e10), logits)

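# Autoregressive sampling: each step feeds only the newest token back into the
# model along with the cached attention keys/values ("past"), instead of
# re-encoding the whole prefix at every step.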
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for _ in trange(length):
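            # GPT2LMHeadModel in pytorch_pretrained_bert returns (logits, presents);
            # `presents` is the updated key/value cache that we feed back in as `past`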
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            # softmax gives probabilities (not log-probs), which is what torch.multinomial expects
            probs = F.softmax(logits, dim=-1)
            prev = torch.multinomial(probs, num_samples=1)
            output = torch.cat((output, prev), dim=1)
    return output
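
# Minimal usage sketch for sample_sequence (hypothetical standalone call; the
# tokenizer `enc` and `model` are loaded the same way interact_model does below):
#
#   context = enc.encode("Once upon a time")
#   out = sample_sequence(model, length=50, context=context, batch_size=1,
#                         temperature=1.0, top_k=40, device='cpu')
#   print(enc.decode(out[0, len(context):].tolist()))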

def interact_model():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=int, default=1)
    parser.add_argument("--top_k", type=int, default=0)
    args = parser.parse_args()
    print(args)

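    # batch_size == -1 means "unset": generate one sample per forward pass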
    if args.batch_size == -1:
        args.batch_size = 1
    assert args.nsamples % args.batch_size == 0, 'nsamples must be a multiple of batch_size'

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

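    # Download (or load from cache) the tokenizer and model weights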
    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
    model.to(device)
    model.eval()

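    # n_ctx is the model's maximum context window (1024 for GPT-2); default to half of it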
    if args.length == -1:
        args.length = model.config.n_ctx // 2
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

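    # Interactive loop: read a prompt, then print nsamples completions in
    # batches of batch_size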
    while True:
        raw_text = input("Model prompt >>> ")
        while not raw_text:
            print('Prompt should not be empty!')
            raw_text = input("Model prompt >>> ")
        context_tokens = enc.encode(raw_text)
        generated = 0
        for _ in range(args.nsamples // args.batch_size):
            out = sample_sequence(
                model=model, length=args.length,
                context=context_tokens,
                batch_size=args.batch_size,
                temperature=args.temperature, top_k=args.top_k, device=device
            )
            out = out[:, len(context_tokens):].tolist()
            for i in range(args.batch_size):
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
        print("=" * 80)

if __name__ == '__main__':
    interact_model()
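
# Example session (the generated text itself will vary from run to run):
#
#   $ python run_gpt2_interactive_conditional_samples.py --top_k 40
#   Model prompt >>> Once upon a time
#   ======================================== SAMPLE 1 ========================================
#   ...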