#! /usr/bin/env python3
# coding=utf-8

# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import csv
import json
import math
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from nltk.tokenize.treebank import TreebankWordDetokenizer
from torchtext import data as torchtext_data
from torchtext import datasets
from tqdm import tqdm, trange

from pplm_classification_head import ClassificationHead
from transformers import GPT2LMHeadModel, GPT2Tokenizer


torch.manual_seed(0)
np.random.seed(0)
EPSILON = 1e-10
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
max_length_seq = 100


class Discriminator(torch.nn.Module):
    """Transformer encoder followed by a Classification Head"""

    def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
        super().__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
        self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
        self.embed_size = self.encoder.transformer.config.hidden_size
        self.classifier_head = ClassificationHead(class_size=class_size, embed_size=self.embed_size)
        self.cached_mode = cached_mode
        self.device = device

    def get_classifier(self):
        return self.classifier_head

    def train_custom(self):
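        # Freeze the GPT-2 encoder so only the classification head is trained.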
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier_head.train()

    def avg_representation(self, x):
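        # Mean-pool the transformer hidden states over non-padding positions
        # (padding id is 0) to get one fixed-size vector per sequence.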
        mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach()
        hidden, _ = self.encoder.transformer(x)
        masked_hidden = hidden * mask
        avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON)
        return avg_hidden

    def forward(self, x):
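        # In cached mode, x is already an averaged representation;
        # otherwise compute it from raw token ids.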
        if self.cached_mode:
            avg_hidden = x.to(self.device)
        else:
            avg_hidden = self.avg_representation(x.to(self.device))

        logits = self.classifier_head(avg_hidden)
        probs = F.log_softmax(logits, dim=-1)

        return probs


class Dataset(data.Dataset):
    def __init__(self, X, y):
        """Reads source and target sequences from txt files."""
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        data = {}
        data["X"] = self.X[index]
        data["y"] = self.y[index]
        return data


def collate_fn(data):
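    # Right-pad variable-length token sequences with 0 so they stack into a batch tensor.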
    def pad_sequences(sequences):
        lengths = [len(seq) for seq in sequences]

        padded_sequences = torch.zeros(len(sequences), max(lengths)).long()  # padding value = 0

        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_sequences[i, :end] = seq[:end]

        return padded_sequences, lengths

    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch, _ = pad_sequences(item_info["X"])
    y_batch = torch.tensor(item_info["y"], dtype=torch.long)

    return x_batch, y_batch


def cached_collate_fn(data):
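    # Collate precomputed average representations; no padding is needed here.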
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch = torch.cat(item_info["X"], 0)
    y_batch = torch.tensor(item_info["y"], dtype=torch.long)

    return x_batch, y_batch


def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10, device="cpu"):
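    # One pass over the training data, minimizing NLL loss on the head's log-probabilities.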
    samples_so_far = 0
    discriminator.train_custom()
    for batch_idx, (input_t, target_t) in enumerate(data_loader):
        input_t, target_t = input_t.to(device), target_t.to(device)

        optimizer.zero_grad()

        output_t = discriminator(input_t)
        loss = F.nll_loss(output_t, target_t)
        loss.backward(retain_graph=True)
        optimizer.step()

        samples_so_far += len(input_t)

        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch + 1,
                    samples_so_far,
                    len(data_loader.dataset),
                    100 * samples_so_far / len(data_loader.dataset),
                    loss.item(),
                )
            )


def evaluate_performance(data_loader, discriminator, device="cpu"):
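    # Report average NLL loss and accuracy over the whole data loader.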
    discriminator.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for input_t, target_t in data_loader:
            input_t, target_t = input_t.to(device), target_t.to(device)
            output_t = discriminator(input_t)
            # sum up batch loss
            test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
            # get the index of the max log-probability
            pred_t = output_t.argmax(dim=1, keepdim=True)
            correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()

    test_loss /= len(data_loader.dataset)

    print(
        "Performance on test set: "
        "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
            test_loss, correct, len(data_loader.dataset), 100.0 * correct / len(data_loader.dataset)
        )
    )


def predict(input_sentence, model, classes, cached=False, device="cpu"):
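    # Classify a single sentence and print per-class probabilities
    # (the model returns log-probabilities, so they are exponentiated below).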
    input_t = model.tokenizer.encode(input_sentence)
    input_t = torch.tensor([input_t], dtype=torch.long, device=device)
    if cached:
        input_t = model.avg_representation(input_t)

    log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
    print("Input sentence:", input_sentence)
    print(
        "Predictions:",
        ", ".join("{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in zip(classes, log_probs)),
    )


def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False, device="cpu"):
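    # Run the frozen encoder once over the dataset and cache the averaged
    # representations, so subsequent epochs only exercise the classification head.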
    data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn)

    xs = []
    ys = []
    for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
        with torch.no_grad():
            x = x.to(device)
            avg_rep = discriminator.avg_representation(x).cpu().detach()
            avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
            xs += avg_rep_list
            ys += y.cpu().numpy().tolist()

    data_loader = torch.utils.data.DataLoader(
        dataset=Dataset(xs, ys), batch_size=batch_size, shuffle=shuffle, collate_fn=cached_collate_fn
    )

    return data_loader


def train_discriminator(
    dataset,
    dataset_fp=None,
    pretrained_model="gpt2-medium",
    epochs=10,
    batch_size=64,
    log_interval=10,
    save_model=False,
    cached=False,
    no_cuda=False,
):
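    # Preprocess the selected dataset, optionally cache encoder representations,
    # then train and evaluate the classification head for the given number of epochs.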
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"

    print("Preprocessing {} dataset...".format(dataset))
    start = time.time()

    if dataset == "SST":
        idx2class = ["positive", "negative", "very positive", "very negative", "neutral"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
        ).to(device)

        text = torchtext_data.Field()
        label = torchtext_data.Field(sequential=False)
        train_data, val_data, test_data = datasets.SST.splits(
            text,
            label,
            fine_grained=True,
            train_subtrees=True,
        )

        x = []
        y = []
        for i in trange(len(train_data), ascii=True):
            seq = TreebankWordDetokenizer().detokenize(vars(train_data[i])["text"])
            seq = discriminator.tokenizer.encode(seq)
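            # Prepend GPT-2's <|endoftext|> token (id 50256) as a start-of-sequence marker.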
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            x.append(seq)
            y.append(class2idx[vars(train_data[i])["label"]])
        train_dataset = Dataset(x, y)

        test_x = []
        test_y = []
        for i in trange(len(test_data), ascii=True):
            seq = TreebankWordDetokenizer().detokenize(vars(test_data[i])["text"])
            seq = discriminator.tokenizer.encode(seq)
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            test_x.append(seq)
            test_y.append(class2idx[vars(test_data[i])["label"]])
        test_dataset = Dataset(test_x, test_y)

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 2,
        }

    elif dataset == "clickbait":
        idx2class = ["non_clickbait", "clickbait"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
        ).to(device)

        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
            data = []
            for i, line in enumerate(f):
                try:
                    data.append(eval(line))
                except Exception:
                    print("Error evaluating line {}: {}".format(i, line))
                    continue
        x = []
        y = []
        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
            for i, line in enumerate(tqdm(f, ascii=True)):
                try:
                    d = eval(line)
                    seq = discriminator.tokenizer.encode(d["text"])

                    if len(seq) < max_length_seq:
                        seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
                    else:
                        print("Line {} is longer than maximum length {}".format(i, max_length_seq))
                        continue
                    x.append(seq)
                    y.append(d["label"])
                except Exception:
                    print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
                    pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 1,
        }

    elif dataset == "toxic":
        idx2class = ["non_toxic", "toxic"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
        ).to(device)

        x = []
        y = []
        with open("datasets/toxic/toxic_train.txt") as f:
            for i, line in enumerate(tqdm(f, ascii=True)):
                try:
                    d = eval(line)
                    seq = discriminator.tokenizer.encode(d["text"])

                    if len(seq) < max_length_seq:
                        seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
                    else:
                        print("Line {} is longer than maximum length {}".format(i, max_length_seq))
                        continue
                    x.append(seq)
                    y.append(int(np.sum(d["label"]) > 0))
                except Exception:
                    print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
                    pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 0,
        }

    else:  # if dataset == "generic":
        # This assumes the input dataset is a TSV with the following structure:
        # class \t text

        if dataset_fp is None:
            raise ValueError("When generic dataset is selected, " "dataset_fp needs to be specified aswell.")

        classes = set()
        with open(dataset_fp) as f:
            csv_reader = csv.reader(f, delimiter="\t")
            for row in tqdm(csv_reader, ascii=True):
                if row:
                    classes.add(row[0])

        idx2class = sorted(classes)
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
        ).to(device)

        x = []
        y = []
        with open(dataset_fp) as f:
            csv_reader = csv.reader(f, delimiter="\t")
            for i, row in enumerate(tqdm(csv_reader, ascii=True)):
                if row:
                    label = row[0]
                    text = row[1]

                    try:
                        seq = discriminator.tokenizer.encode(text)
                        if len(seq) < max_length_seq:
                            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)

                        else:
                            print("Line {} is longer than maximum length {}".format(i, max_length_seq))
                            continue

                        x.append(seq)
                        y.append(class2idx[label])

                    except Exception:
                        print("Error tokenizing line {}, skipping it".format(i))
                        pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 0,
        }

    end = time.time()
    print("Preprocessed {} data points".format(len(train_dataset) + len(test_dataset)))
    print("Data preprocessing took: {:.3f}s".format(end - start))

    if cached:
        print("Building representation cache...")

        start = time.time()

        train_loader = get_cached_data_loader(train_dataset, batch_size, discriminator, shuffle=True, device=device)

        test_loader = get_cached_data_loader(test_dataset, batch_size, discriminator, device=device)

        end = time.time()
        print("Building representation cache took: {:.3f}s".format(end - start))

    else:
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, collate_fn=collate_fn)

    if save_model:
        with open("{}_classifier_head_meta.json".format(dataset), "w") as meta_file:
            json.dump(discriminator_meta, meta_file)

    optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)

    for epoch in range(epochs):
        start = time.time()
        print("\nEpoch", epoch + 1)

        train_epoch(
            discriminator=discriminator,
            data_loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            log_interval=log_interval,
            device=device,
        )
        evaluate_performance(data_loader=test_loader, discriminator=discriminator, device=device)

        end = time.time()
        print("Epoch took: {:.3f}s".format(end - start))

        print("\nExample prediction")
        predict(example_sentence, discriminator, idx2class, cached=cached, device=device)

        if save_model:
            # torch.save(discriminator.state_dict(),
            #           "{}_discriminator_{}.pt".format(
            #               args.dataset, epoch + 1
            #               ))
            torch.save(
                discriminator.get_classifier().state_dict(),
                "{}_classifier_head_epoch_{}.pt".format(dataset, epoch + 1),
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a discriminator on top of GPT-2 representations")
    parser.add_argument(
        "--dataset",
        type=str,
        default="SST",
        choices=("SST", "clickbait", "toxic", "generic"),
        help="dataset to train the discriminator on."
        "In case of generic, the dataset is expected"
        "to be a TSBV file with structure: class \\t text",
    )
    parser.add_argument(
        "--dataset_fp",
        type=str,
        default="",
        help="File path of the dataset to use. " "Needed only in case of generic datadset",
    )
    parser.add_argument(
        "--pretrained_model", type=str, default="gpt2-medium", help="Pretrained model to use as encoder"
    )
    parser.add_argument("--epochs", type=int, default=10, metavar="N", help="Number of training epochs")
    parser.add_argument(
        "--batch_size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument("--save_model", action="store_true", help="whether to save the model")
    parser.add_argument("--cached", action="store_true", help="whether to cache the input representations")
    parser.add_argument("--no_cuda", action="store_true", help="use to turn off cuda")
520
521
522
    args = parser.parse_args()

    train_discriminator(**(vars(args)))