"vscode:/vscode.git/clone" did not exist on "62f99e08b364aee76b146109228ba05a0bed9bc1"
run_gpt2.py 5.1 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
#!/usr/bin/env python3
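# Interactive sampling script for a pretrained GPT-2 language model.
# Typical invocations (using the argparse flags defined below):
#   python run_gpt2.py --top_k 40 --temperature 0.9
#   python run_gpt2.py --unconditional --nsamples 2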

import argparse
import logging
from tqdm import trange

import torch
import torch.nn.functional as F
import numpy as np

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

def top_k_logits(logits, k):
    """
    Masks everything but the k top entries as -infinity (1e10).
    Used to mask logits such that e^-infinity -> 0 won't contribute to the
    sum of the denominator.
    """
    if k == 0:
        return logits
    else:
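        # values[:, -1] is the k-th largest logit in each row; every logit
        # below that threshold is pushed to -1e10 so softmax assigns it
        # effectively zero probability.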
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits)

def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
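    # `past` caches the transformer's key/value states, so each step only has
    # to feed the most recently generated token rather than the full sequence.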
    past = None
    with torch.no_grad():
        for i in trange(length):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            probs = F.softmax(logits, dim=-1)
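            # sample from the filtered distribution, or take the argmax for
            # greedy decoding when sample=False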
            if sample:
                prev = torch.multinomial(probs, num_samples=1)
            else:
                _, prev = torch.topk(probs, k=1, dim=-1)
            output = torch.cat((output, prev), dim=1)
    return output

def run_model():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
    args = parser.parse_args()
    print(args)

    if args.batch_size == -1:
        args.batch_size = 1
    assert args.nsamples % args.batch_size == 0

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
    model.to(device)
    model.eval()

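    # n_ctx is the model's maximum context window (1024 tokens for GPT-2);
    # by default, generate up to half of it.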
    if args.length == -1:
        args.length = model.config.n_ctx // 2
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

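    # Generation loop: either condition on an interactive prompt, or produce
    # unconditional samples seeded with the <|endoftext|> token.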
    while True:
        context_tokens = []
        if not args.unconditional:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            generated = 0
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=context_tokens,
                    start_token=None,
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
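                # keep only the newly generated tokens, dropping the prompt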
                out = out[:, len(context_tokens):].tolist()
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)
        else:
            generated = 0
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=None,
                    start_token=enc.encoder['<|endoftext|>'],
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
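                # drop the <|endoftext|> start token from the output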
                out = out[:, 1:].tolist()
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)

if __name__ == '__main__':
    run_model()