#! /usr/bin/env python3
# coding=utf-8

# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
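
"""Train a discriminator (a classification head on top of frozen GPT-2
representations) for use as an attribute model with PPLM."""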

import argparse
import csv
import json
import math
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from nltk.tokenize.treebank import TreebankWordDetokenizer
from torchtext import data as torchtext_data
from torchtext import datasets
from tqdm import tqdm, trange

from pplm_classification_head import ClassificationHead
from transformers import GPT2LMHeadModel, GPT2Tokenizer


torch.manual_seed(0)
np.random.seed(0)
EPSILON = 1e-10
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
max_length_seq = 100


class Discriminator(torch.nn.Module):
    """Transformer encoder followed by a Classification Head"""

    def __init__(self, class_size, pretrained_model="gpt2-medium", cached_mode=False, device="cpu"):
        super(Discriminator, self).__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
        self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
        self.embed_size = self.encoder.transformer.config.hidden_size
        self.classifier_head = ClassificationHead(class_size=class_size, embed_size=self.embed_size)
        self.cached_mode = cached_mode
        self.device = device

    def get_classifier(self):
        return self.classifier_head

    def train_custom(self):
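        """Freeze the GPT-2 encoder so that only the classification head is trained;
        an optimizer over self.parameters() then effectively updates the head alone."""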
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier_head.train()

    def avg_representation(self, x):
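        """Mean-pool the GPT-2 hidden states over non-padding positions (pad id 0)."""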
        mask = x.ne(0).unsqueeze(2).repeat(1, 1, self.embed_size).float().to(self.device).detach()
        hidden, _ = self.encoder.transformer(x)
        masked_hidden = hidden * mask
        avg_hidden = torch.sum(masked_hidden, dim=1) / (torch.sum(mask, dim=1).detach() + EPSILON)
        return avg_hidden

    def forward(self, x):
        if self.cached_mode:
            avg_hidden = x.to(self.device)
        else:
            avg_hidden = self.avg_representation(x.to(self.device))

        logits = self.classifier_head(avg_hidden)
        probs = F.log_softmax(logits, dim=-1)

        return probs


class Dataset(data.Dataset):
    def __init__(self, X, y):
        """Reads source and target sequences from txt files."""
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        data = {}
        data["X"] = self.X[index]
        data["y"] = self.y[index]
        return data


def collate_fn(data):
    def pad_sequences(sequences):
        lengths = [len(seq) for seq in sequences]

        padded_sequences = torch.zeros(len(sequences), max(lengths)).long()  # padding value = 0

        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_sequences[i, :end] = seq[:end]

        return padded_sequences, lengths

    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch, _ = pad_sequences(item_info["X"])
    y_batch = torch.tensor(item_info["y"], dtype=torch.long)

    return x_batch, y_batch


def cached_collate_fn(data):
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch = torch.cat(item_info["X"], 0)
    y_batch = torch.tensor(item_info["y"], dtype=torch.long)

    return x_batch, y_batch


def train_epoch(data_loader, discriminator, optimizer, epoch=0, log_interval=10, device="cpu"):
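    """Train the classification head for one epoch, logging NLL loss every `log_interval` batches."""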
    samples_so_far = 0
    discriminator.train_custom()
    for batch_idx, (input_t, target_t) in enumerate(data_loader):
        input_t, target_t = input_t.to(device), target_t.to(device)

        optimizer.zero_grad()

        output_t = discriminator(input_t)
        loss = F.nll_loss(output_t, target_t)
        loss.backward(retain_graph=True)
        optimizer.step()

        samples_so_far += len(input_t)

        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch + 1,
                    samples_so_far,
                    len(data_loader.dataset),
                    100 * samples_so_far / len(data_loader.dataset),
                    loss.item(),
                )
            )


def evaluate_performance(data_loader, discriminator, device="cpu"):
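    """Report the average NLL loss and accuracy of the discriminator on a held-out set."""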
    discriminator.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for input_t, target_t in data_loader:
            input_t, target_t = input_t.to(device), target_t.to(device)
            output_t = discriminator(input_t)
            # sum up batch loss
            test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
            # get the index of the max log-probability
            pred_t = output_t.argmax(dim=1, keepdim=True)
            correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()

    test_loss /= len(data_loader.dataset)

    print(
        "Performance on test set: "
        "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
            test_loss, correct, len(data_loader.dataset), 100.0 * correct / len(data_loader.dataset)
        )
    )


def predict(input_sentence, model, classes, cached=False, device="cpu"):
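    """Print the class probabilities that the discriminator assigns to a single sentence."""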
    input_t = model.tokenizer.encode(input_sentence)
    input_t = torch.tensor([input_t], dtype=torch.long, device=device)
    if cached:
        input_t = model.avg_representation(input_t)

    log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
    print("Input sentence:", input_sentence)
    print(
        "Predictions:",
        ", ".join("{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in zip(classes, log_probs)),
    )


def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False, device="cpu"):
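    """Run the frozen encoder once over `dataset` and return a DataLoader over the cached
    average representations, so that training epochs skip the GPT-2 forward pass."""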
    data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=collate_fn)

    xs = []
    ys = []
    for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
        with torch.no_grad():
            x = x.to(device)
            avg_rep = discriminator.avg_representation(x).cpu().detach()
            avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
            xs += avg_rep_list
            ys += y.cpu().numpy().tolist()

    data_loader = torch.utils.data.DataLoader(
        dataset=Dataset(xs, ys), batch_size=batch_size, shuffle=shuffle, collate_fn=cached_collate_fn
    )

    return data_loader


def train_discriminator(
    dataset,
    dataset_fp=None,
    pretrained_model="gpt2-medium",
    epochs=10,
    batch_size=64,
    log_interval=10,
    save_model=False,
    cached=False,
    no_cuda=False,
):
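    """Preprocess the chosen dataset, train a classification head on top of GPT-2
    representations, and optionally save the head weights (per epoch) and a metadata JSON."""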
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"

    print("Preprocessing {} dataset...".format(dataset))
    start = time.time()

    if dataset == "SST":
        idx2class = ["positive", "negative", "very positive", "very negative", "neutral"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
        ).to(device)

        text = torchtext_data.Field()
        label = torchtext_data.Field(sequential=False)
        train_data, val_data, test_data = datasets.SST.splits(text, label, fine_grained=True, train_subtrees=True)

        x = []
        y = []
        for i in trange(len(train_data), ascii=True):
            seq = TreebankWordDetokenizer().detokenize(vars(train_data[i])["text"])
            seq = discriminator.tokenizer.encode(seq)
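            # Prepend GPT-2's <|endoftext|> token (id 50256) as a start-of-sequence marker.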
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            x.append(seq)
            y.append(class2idx[vars(train_data[i])["label"]])
        train_dataset = Dataset(x, y)

        test_x = []
        test_y = []
        for i in trange(len(test_data), ascii=True):
            seq = TreebankWordDetokenizer().detokenize(vars(test_data[i])["text"])
            seq = discriminator.tokenizer.encode(seq)
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            test_x.append(seq)
            test_y.append(class2idx[vars(test_data[i])["label"]])
        test_dataset = Dataset(test_x, test_y)

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 2,
        }

    elif dataset == "clickbait":
        idx2class = ["non_clickbait", "clickbait"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
        ).to(device)

        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
            data = []
            for i, line in enumerate(f):
                try:
                    data.append(eval(line))
288
                except Exception:
289
                    print("Error evaluating line {}: {}".format(i, line))
290
291
292
                    continue
        x = []
        y = []
        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
            for i, line in enumerate(tqdm(f, ascii=True)):
                try:
                    d = eval(line)
                    seq = discriminator.tokenizer.encode(d["text"])

                    if len(seq) < max_length_seq:
                        seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
                    else:
                        print("Line {} is longer than maximum length {}".format(i, max_length_seq))
                        continue
                    x.append(seq)
                    y.append(d["label"])
                except Exception:
                    print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
                    pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 1,
        }

    elif dataset == "toxic":
        idx2class = ["non_toxic", "toxic"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
        ).to(device)

        x = []
        y = []
        with open("datasets/toxic/toxic_train.txt") as f:
            for i, line in enumerate(tqdm(f, ascii=True)):
                try:
                    d = eval(line)
                    seq = discriminator.tokenizer.encode(d["text"])

                    if len(seq) < max_length_seq:
                        seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
                    else:
                        print("Line {} is longer than maximum length {}".format(i, max_length_seq))
                        continue
                    x.append(seq)
                    y.append(int(np.sum(d["label"]) > 0))
                except Exception:
                    print("Error evaluating / tokenizing" " line {}, skipping it".format(i))
                    pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 0,
        }

    else:  # if dataset == "generic":
        # This assumes the input dataset is a TSV with the following structure:
        # class \t text
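        # For example (hypothetical rows):
        #   positive\tGreat acting and a clever plot.
        #   negative\tThe pacing dragged and the ending fell flat.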

        if dataset_fp is None:
            raise ValueError("When generic dataset is selected, dataset_fp needs to be specified as well.")

        classes = set()
        with open(dataset_fp) as f:
            csv_reader = csv.reader(f, delimiter="\t")
            for row in tqdm(csv_reader, ascii=True):
                if row:
                    classes.add(row[0])

        idx2class = sorted(classes)
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class), pretrained_model=pretrained_model, cached_mode=cached, device=device
        ).to(device)

        x = []
        y = []
        with open(dataset_fp) as f:
            csv_reader = csv.reader(f, delimiter="\t")
            for i, row in enumerate(tqdm(csv_reader, ascii=True)):
                if row:
                    label = row[0]
                    text = row[1]

                    try:
                        seq = discriminator.tokenizer.encode(text)
                        if len(seq) < max_length_seq:
                            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)

                        else:
                            print("Line {} is longer than maximum length {}".format(i, max_length_seq))
                            continue

                        x.append(seq)
                        y.append(class2idx[label])

                    except Exception:
                        print("Error tokenizing line {}, skipping it".format(i))
                        pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 0,
        }

    end = time.time()
    print("Preprocessed {} data points".format(len(train_dataset) + len(test_dataset)))
    print("Data preprocessing took: {:.3f}s".format(end - start))

    if cached:
        print("Building representation cache...")

        start = time.time()

        train_loader = get_cached_data_loader(train_dataset, batch_size, discriminator, shuffle=True, device=device)

        test_loader = get_cached_data_loader(test_dataset, batch_size, discriminator, device=device)

        end = time.time()
        print("Building representation cache took: {:.3f}s".format(end - start))

    else:
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, collate_fn=collate_fn)

    if save_model:
        with open("{}_classifier_head_meta.json".format(dataset), "w") as meta_file:
            json.dump(discriminator_meta, meta_file)

    optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)

    for epoch in range(epochs):
        start = time.time()
        print("\nEpoch", epoch + 1)

        train_epoch(
            discriminator=discriminator,
            data_loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            log_interval=log_interval,
            device=device,
        )
        evaluate_performance(data_loader=test_loader, discriminator=discriminator, device=device)

        end = time.time()
        print("Epoch took: {:.3f}s".format(end - start))

        print("\nExample prediction")
        predict(example_sentence, discriminator, idx2class, cached=cached, device=device)

        if save_model:
            # torch.save(discriminator.state_dict(),
            #           "{}_discriminator_{}.pt".format(
            #               args.dataset, epoch + 1
            #               ))
            torch.save(
                discriminator.get_classifier().state_dict(),
                "{}_classifier_head_epoch_{}.pt".format(dataset, epoch + 1),
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a discriminator on top of GPT-2 representations")
    parser.add_argument(
        "--dataset",
        type=str,
        default="SST",
        choices=("SST", "clickbait", "toxic", "generic"),
        help="dataset to train the discriminator on."
        "In case of generic, the dataset is expected"
        "to be a TSBV file with structure: class \\t text",
    )
    parser.add_argument(
        "--dataset_fp",
        type=str,
        default="",
        help="File path of the dataset to use. " "Needed only in case of generic datadset",
    )
    parser.add_argument(
        "--pretrained_model", type=str, default="gpt2-medium", help="Pretrained model to use as encoder"
    )
    parser.add_argument("--epochs", type=int, default=10, metavar="N", help="Number of training epochs")
    parser.add_argument(
        "--batch_size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument("--save_model", action="store_true", help="whether to save the model")
    parser.add_argument("--cached", action="store_true", help="whether to cache the input representations")
    parser.add_argument("--no_cuda", action="store_true", help="use to turn off cuda")
    args = parser.parse_args()

    train_discriminator(**vars(args))