#! /usr/bin/env python3
# coding=utf-8

# This code is licensed under a non-commercial license.
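#
# Trains a discriminator for PPLM: a small classification head on top of the
# averaged hidden states of a (frozen) pretrained GPT-2 encoder.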

import argparse
import csv
import json
import math
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from nltk.tokenize.treebank import TreebankWordDetokenizer
from torchtext import data as torchtext_data
from torchtext import datasets
from tqdm import tqdm, trange

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from pplm_classification_head import ClassificationHead

torch.manual_seed(0)
np.random.seed(0)
EPSILON = 1e-10
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
max_length_seq = 100




class Discriminator(torch.nn.Module):
    """Transformer encoder followed by a Classification Head"""

    def __init__(
            self,
            class_size,
            pretrained_model="gpt2-medium",
            cached_mode=False,
            device='cpu'
    ):
        super(Discriminator, self).__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
        self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
        self.embed_size = self.encoder.transformer.config.hidden_size
        self.classifier_head = ClassificationHead(
            class_size=class_size,
            embed_size=self.embed_size
        )
        self.cached_mode = cached_mode
        self.device = device

    def get_classifier(self):
        return self.classifier_head

    def train_custom(self):
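        """Freeze the GPT-2 encoder and train only the classification head."""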
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier_head.train()

    def avg_representation(self, x):
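        """Mean-pool hidden states over non-padding (non-zero id) tokens."""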
        mask = x.ne(0).unsqueeze(2).repeat(
            1, 1, self.embed_size
        ).float().to(self.device).detach()
        hidden, _ = self.encoder.transformer(x)
        masked_hidden = hidden * mask
        avg_hidden = torch.sum(masked_hidden, dim=1) / (
                torch.sum(mask, dim=1).detach() + EPSILON
        )
        return avg_hidden

    def forward(self, x):
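        """Compute class log-probabilities; in cached mode x is already the
        averaged hidden representation rather than a batch of token ids."""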
        if self.cached_mode:
            avg_hidden = x.to(self.device)
        else:
            avg_hidden = self.avg_representation(x.to(self.device))

        logits = self.classifier_head(avg_hidden)
        probs = F.log_softmax(logits, dim=-1)

        return probs


class Dataset(data.Dataset):
    def __init__(self, X, y):
        """Reads source and target sequences from txt files."""
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        data = {}
        data["X"] = self.X[index]
        data["y"] = self.y[index]
        return data


def collate_fn(data):
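    """Zero-pad variable-length sequences and stack them into one batch tensor."""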
    def pad_sequences(sequences):
        lengths = [len(seq) for seq in sequences]

        padded_sequences = torch.zeros(
            len(sequences),
            max(lengths)
        ).long()  # padding value = 0

        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_sequences[i, :end] = seq[:end]

        return padded_sequences, lengths

    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch, _ = pad_sequences(item_info["X"])
    y_batch = torch.tensor(item_info["y"], dtype=torch.long)

    return x_batch, y_batch


def cached_collate_fn(data):
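    """Batch pre-computed, fixed-size representations by concatenation."""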
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch = torch.cat(item_info["X"], 0)
    y_batch = torch.tensor(item_info["y"], dtype=torch.long)

    return x_batch, y_batch


def train_epoch(data_loader, discriminator, optimizer,
                epoch=0, log_interval=10, device='cpu'):
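    """Train the classification head for one epoch using NLL loss."""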
    samples_so_far = 0
    discriminator.train_custom()
    for batch_idx, (input_t, target_t) in enumerate(data_loader):
        input_t, target_t = input_t.to(device), target_t.to(device)

        optimizer.zero_grad()

        output_t = discriminator(input_t)
        loss = F.nll_loss(output_t, target_t)
        loss.backward(retain_graph=True)
        optimizer.step()

        samples_so_far += len(input_t)

        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch + 1,
                    samples_so_far, len(data_loader.dataset),
                    100 * samples_so_far / len(data_loader.dataset), loss.item()
                )
            )


def evaluate_performance(data_loader, discriminator, device='cpu'):
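    """Report average NLL loss and accuracy of the discriminator on a test set."""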
    discriminator.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for input_t, target_t in data_loader:
            input_t, target_t = input_t.to(device), target_t.to(device)
            output_t = discriminator(input_t)
            # sum up batch loss
            test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
            # get the index of the max log-probability
            pred_t = output_t.argmax(dim=1, keepdim=True)
            correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()

    test_loss /= len(data_loader.dataset)

    print(
        "Performance on test set: "
        "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
            test_loss, correct, len(data_loader.dataset),
            100. * correct / len(data_loader.dataset)
        )
    )


def predict(input_sentence, model, classes, cached=False, device='cpu'):
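    """Print the predicted class probabilities for a single input sentence."""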
    input_t = model.tokenizer.encode(input_sentence)
    input_t = torch.tensor([input_t], dtype=torch.long, device=device)
    if cached:
        input_t = model.avg_representation(input_t)

    log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
    print("Input sentence:", input_sentence)
    print("Predictions:", ", ".join(
        "{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
        zip(classes, log_probs)
    ))


def get_cached_data_loader(dataset, batch_size, discriminator,
                           shuffle=False, device='cpu'):
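    """Precompute averaged GPT-2 representations for the whole dataset so that
    training only has to run the classification head on cached vectors."""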
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              collate_fn=collate_fn)

    xs = []
    ys = []
    for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
        with torch.no_grad():
            x = x.to(device)
            avg_rep = discriminator.avg_representation(x).cpu().detach()
            avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
            xs += avg_rep_list
            ys += y.cpu().numpy().tolist()

    data_loader = torch.utils.data.DataLoader(
        dataset=Dataset(xs, ys),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=cached_collate_fn)

    return data_loader


def train_discriminator(
        dataset, dataset_fp=None, pretrained_model="gpt2-medium",
        epochs=10, batch_size=64, log_interval=10,
        save_model=False, cached=False, no_cuda=False):
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"

    print("Preprocessing {} dataset...".format(dataset))
    start = time.time()

    if dataset == "SST":
        idx2class = ["positive", "negative", "very positive", "very negative",
                     "neutral"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class),
            pretrained_model=pretrained_model,
            cached_mode=cached,
            device=device
        ).to(device)

        text = torchtext_data.Field()
        label = torchtext_data.Field(sequential=False)
        train_data, val_data, test_data = datasets.SST.splits(
            text,
            label,
            fine_grained=True,
            train_subtrees=True,
        )

        x = []
        y = []
        for i in trange(len(train_data), ascii=True):
            seq = TreebankWordDetokenizer().detokenize(
                vars(train_data[i])["text"]
            )
            seq = discriminator.tokenizer.encode(seq)
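            # 50256 is GPT-2's <|endoftext|> token id, used here as a BOS marker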
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            x.append(seq)
            y.append(class2idx[vars(train_data[i])["label"]])
        train_dataset = Dataset(x, y)

        test_x = []
        test_y = []
        for i in trange(len(test_data), ascii=True):
            seq = TreebankWordDetokenizer().detokenize(
                vars(test_data[i])["text"]
            )
            seq = discriminator.tokenizer.encode(seq)
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            test_x.append(seq)
            test_y.append(class2idx[vars(test_data[i])["label"]])
        test_dataset = Dataset(test_x, test_y)

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 2,
        }

    elif dataset == "clickbait":
        idx2class = ["non_clickbait", "clickbait"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class),
            pretrained_model=pretrained_model,
            cached_mode=cached,
            device=device
        ).to(device)

        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
            data = []
            for i, line in enumerate(f):
                try:
                    data.append(eval(line))
                except Exception:
                    print("Error evaluating line {}: {}".format(
                        i, line
                    ))
                    continue
        x = []
        y = []
        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
            for i, line in enumerate(tqdm(f, ascii=True)):
                try:
                    d = eval(line)
                    seq = discriminator.tokenizer.encode(d["text"])

                    if len(seq) < max_length_seq:
                        seq = torch.tensor(
                            [50256] + seq, device=device, dtype=torch.long
                        )
                    else:
                        print("Line {} is longer than maximum length {}".format(
                            i, max_length_seq
                        ))
                        continue
                    x.append(seq)
                    y.append(d["label"])
                except Exception:
                    print("Error evaluating / tokenizing"
                          " line {}, skipping it".format(i))
                    pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, test_size]
        )

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 1,
        }

    elif dataset == "toxic":
        idx2class = ["non_toxic", "toxic"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class),
            pretrained_model=pretrained_model,
            cached_mode=cached,
            device=device
        ).to(device)

        x = []
        y = []
        with open("datasets/toxic/toxic_train.txt") as f:
            for i, line in enumerate(tqdm(f, ascii=True)):
                try:
                    d = eval(line)
                    seq = discriminator.tokenizer.encode(d["text"])

                    if len(seq) < max_length_seq:
                        seq = torch.tensor(
                            [50256] + seq, device=device, dtype=torch.long
                        )
                    else:
                        print("Line {} is longer than maximum length {}".format(
                            i, max_length_seq
                        ))
                        continue
                    x.append(seq)
                    y.append(int(np.sum(d["label"]) > 0))
                except Exception:
                    print("Error evaluating / tokenizing"
                          " line {}, skipping it".format(i))
                    pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, test_size]
        )

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 0,
        }

    else:  # if dataset == "generic":
        # This assumes the input dataset is a TSV with the following structure:
        # class \t text

        if dataset_fp is None:
            raise ValueError("When generic dataset is selected, "
                             "dataset_fp needs to be specified aswell.")

        classes = set()
        with open(dataset_fp) as f:
            csv_reader = csv.reader(f, delimiter="\t")
            for row in tqdm(csv_reader, ascii=True):
                if row:
                    classes.add(row[0])

        idx2class = sorted(classes)
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class),
            pretrained_model=pretrained_model,
            cached_mode=cached,
            device=device
        ).to(device)

        x = []
        y = []
        with open(dataset_fp) as f:
            csv_reader = csv.reader(f, delimiter="\t")
            for i, row in enumerate(tqdm(csv_reader, ascii=True)):
                if row:
                    label = row[0]
                    text = row[1]

                    try:
                        seq = discriminator.tokenizer.encode(text)
                        if len(seq) < max_length_seq:
                            seq = torch.tensor(
                                [50256] + seq,
                                device=device,
                                dtype=torch.long
                            )

                        else:
                            print(
                                "Line {} is longer than maximum length {}".format(
                                    i, max_length_seq
                                ))
                            continue

                        x.append(seq)
                        y.append(class2idx[label])

                    except Exception:
                        print("Error tokenizing line {}, skipping it".format(i))
                        pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(
            full_dataset,
            [train_size, test_size]
        )

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 0,
        }

    end = time.time()
    print("Preprocessed {} data points".format(
        len(train_dataset) + len(test_dataset))
    )
    print("Data preprocessing took: {:.3f}s".format(end - start))

    if cached:
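        # Pre-compute the frozen GPT-2 representations once; each epoch then
        # only feeds these fixed vectors through the classification head.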
        print("Building representation cache...")

        start = time.time()

        train_loader = get_cached_data_loader(
            train_dataset, batch_size, discriminator,
            shuffle=True, device=device
        )

        test_loader = get_cached_data_loader(
            test_dataset, batch_size, discriminator, device=device
        )

        end = time.time()
        print("Building representation cache took: {:.3f}s".format(end - start))

    else:
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   collate_fn=collate_fn)
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                  batch_size=batch_size,
                                                  collate_fn=collate_fn)

    if save_model:
        with open("{}_classifier_head_meta.json".format(dataset),
                  "w") as meta_file:
            json.dump(discriminator_meta, meta_file)

    optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)

    for epoch in range(epochs):
        start = time.time()
        print("\nEpoch", epoch + 1)

        train_epoch(
            discriminator=discriminator,
            data_loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            log_interval=log_interval,
            device=device
        )
        evaluate_performance(
            data_loader=test_loader,
            discriminator=discriminator,
            device=device
        )

        end = time.time()
        print("Epoch took: {:.3f}s".format(end - start))

        print("\nExample prediction")
        predict(example_sentence, discriminator, idx2class,
                cached=cached, device=device)

        if save_model:
            # torch.save(discriminator.state_dict(),
            #           "{}_discriminator_{}.pt".format(
            #               args.dataset, epoch + 1
            #               ))
            torch.save(discriminator.get_classifier().state_dict(),
                       "{}_classifier_head_epoch_{}.pt".format(dataset,
                                                               epoch + 1))


if __name__ == "__main__":
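    # Example invocations (illustrative; file paths depend on your setup):
    #   python run_pplm_discrim_train.py --dataset SST --cached --save_model
    #   python run_pplm_discrim_train.py --dataset generic --dataset_fp my_data.tsv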
    parser = argparse.ArgumentParser(
        description="Train a discriminator on top of GPT-2 representations")
    parser.add_argument("--dataset", type=str, default="SST",
                        choices=("SST", "clickbait", "toxic", "generic"),
                        help="dataset to train the discriminator on."
                             "In case of generic, the dataset is expected"
                             "to be a TSBV file with structure: class \\t text")
    parser.add_argument("--dataset_fp", type=str, default="",
                        help="File path of the dataset to use. "
                             "Needed only in case of generic datadset")
    parser.add_argument("--pretrained_model", type=str, default="gpt2-medium",
                        help="Pretrained model to use as encoder")
    parser.add_argument("--epochs", type=int, default=10, metavar="N",
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=64, metavar="N",
                        help="input batch size for training (default: 64)")
    parser.add_argument("--log_interval", type=int, default=10, metavar="N",
                        help="how many batches to wait before logging training status")
    parser.add_argument("--save_model", action="store_true",
                        help="whether to save the model")
    parser.add_argument("--cached", action="store_true",
                        help="whether to cache the input representations")
    parser.add_argument("--no_cuda", action="store_true",
                        help="use to turn off cuda")
    args = parser.parse_args()

    train_discriminator(**(vars(args)))