#! /usr/bin/env python3
# coding=utf-8

# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example command with bag of words:
python run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95

Example command with discriminator:
python run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
"""

import argparse
import json
from operator import add
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import trange

from pplm_classification_head import ClassificationHead
from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel


PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
BIG_CONST = 1e10

BAG_OF_WORDS_ARCHIVE_MAP = {
    "legal": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
    "military": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
    "politics": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
    "religion": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
    "science": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
    "space": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
    "technology": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
}

DISCRIMINATOR_MODELS_PARAMS = {
    "clickbait": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
        "class_size": 2,
        "embed_size": 1024,
        "class_vocab": {"non_clickbait": 0, "clickbait": 1},
        "default_class": 1,
        "pretrained_model": "gpt2-medium",
    },
    "sentiment": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
        "class_size": 5,
        "embed_size": 1024,
        "class_vocab": {"very_positive": 2, "very_negative": 3},
        "default_class": 3,
        "pretrained_model": "gpt2-medium",
    },
}


def top_k_filter(logits, k, probs=False):
    """
    Masks everything but the k top entries as -infinity (-1e10).
    Used to mask logits so that e^-infinity -> 0 and the masked entries
    do not contribute to the softmax denominator.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits)
        return torch.where(logits < batch_mins, torch.ones_like(logits) * -BIG_CONST, logits)


def perturb_past(
    past,
    model,
    last,
    unpert_past=None,
    unpert_logits=None,
    accumulated_hidden=None,
    grad_norms=None,
    stepsize=0.01,
    one_hot_bows_vectors=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    num_iterations=3,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    kl_scale=0.01,
    device="cuda",
):
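    """
    Perturb the cached key/value `past` by taking gradient steps on the PPLM
    attribute loss (bag-of-words and/or discriminator) plus a KL penalty against
    the unperturbed distribution, so the next forward pass is steered toward the
    desired attribute.
    """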
    # Generate initial perturbed past
    grad_accumulator = [(np.zeros(p.shape).astype("float32")) for p in past]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    if decay:
        decay_mask = torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
    else:
        decay_mask = 1.0

    # Generate a mask so that the perturbation gradient is only applied to the most recent window_length entries of the past
    _, _, _, curr_length, _ = past[0].shape

    if curr_length > window_length and window_length > 0:
        ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])

        zeros_key_val_shape = (
            tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
        )

        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2).to(device)
    else:
        window_mask = torch.ones_like(past[0]).to(device)

    # accumulate perturbations for num_iterations
    loss_per_iter = []
    new_accumulated_hidden = None
    for i in range(num_iterations):
        print("Iteration ", i + 1)
        curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]

        # Compute hidden using perturbed past
        perturbed_past = list(map(add, past, curr_perturbation))
        _, _, _, curr_length, _ = curr_perturbation[0].shape
        all_logits, _, all_hidden = model(last, past=perturbed_past)
        hidden = all_hidden[-1]
        new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
        # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
        logits = all_logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        loss = 0.0
        loss_list = []
        if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
            for one_hot_bow in one_hot_bows_vectors:
                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
                bow_loss = -torch.log(torch.sum(bow_logits))
                loss += bow_loss
                loss_list.append(bow_loss)
            print(" pplm_bow_loss:", loss.data.cpu().numpy())

        if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
            ce_loss = torch.nn.CrossEntropyLoss()
            # TODO: why do we need this assignment instead of just using unpert_past? (Sumanth)
            curr_unpert_past = unpert_past
            curr_probs = torch.unsqueeze(probs, dim=1)
            wte = model.resize_token_embeddings()
            for _ in range(horizon_length):
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                _, curr_unpert_past, curr_all_hidden = model(past=curr_unpert_past, inputs_embeds=inputs_embeds)
                curr_hidden = curr_all_hidden[-1]
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)

            prediction = classifier(new_accumulated_hidden / (curr_length + 1 + horizon_length))

            label = torch.tensor(prediction.shape[0] * [class_label], device=device, dtype=torch.long)
            discrim_loss = ce_loss(prediction, label)
            print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
            loss += discrim_loss
            loss_list.append(discrim_loss)

        kl_loss = 0.0
        if kl_scale > 0.0:
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
            unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
            corrected_probs = probs + correction.detach()
            kl_loss = kl_scale * ((corrected_probs * (corrected_probs / unpert_probs).log()).sum())
            print(" kl_loss", kl_loss.data.cpu().numpy())
            loss += kl_loss

        loss_per_iter.append(loss.data.cpu().numpy())
        print(" pplm_loss", (loss - kl_loss).data.cpu().numpy())

        # compute gradients
        loss.backward()

        # calculate gradient norms
        if grad_norms is not None and loss_type == PPLM_BOW:
            grad_norms = [
                torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                for index, p_ in enumerate(curr_perturbation)
            ]
        else:
            grad_norms = [
                (torch.norm(p_.grad * window_mask) + SMALL_CONST) for index, p_ in enumerate(curr_perturbation)
            ]

        # normalize gradients
        grad = [
            -stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
            for index, p_ in enumerate(curr_perturbation)
        ]

        # accumulate gradient
        grad_accumulator = list(map(add, grad, grad_accumulator))

        # reset gradients, just to make sure
        for p_ in curr_perturbation:
            p_.grad.data.zero_()

        # removing past from the graph
        new_past = []
        for p_ in past:
            new_past.append(p_.detach())
        past = new_past

    # apply the accumulated perturbations to the past
    grad_accumulator = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
    pert_past = list(map(add, past, grad_accumulator))

    return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter


def get_classifier(
    name: Optional[str], class_label: Union[str, int], device: str
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
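    """
    Load the classification head for the requested discriminator and resolve
    `class_label` (string or int) to an integer class id, falling back to the
    discriminator's default class when the label is not recognized.
    """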
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    classifier = ClassificationHead(class_size=params["class_size"], embed_size=params["embed_size"]).to(device)
    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError("Either url or path has to be specified in the discriminator model parameters")
    classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    if isinstance(class_label, str):
        if class_label in params["class_vocab"]:
            label_id = params["class_vocab"][class_label]
        else:
            label_id = params["default_class"]
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(params["class_vocab"]))
            print("using default class {}".format(label_id))

    elif isinstance(class_label, int):
        if class_label in set(params["class_vocab"].values()):
            label_id = class_label
        else:
            label_id = params["default_class"]
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(params["class_vocab"]))
            print("using default class {}".format(label_id))

    else:
        label_id = params["default_class"]

    return classifier, label_id


def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> List[List[List[int]]]:
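    """
    Read each bag of words, given either as an id from BAG_OF_WORDS_ARCHIVE_MAP
    or as a local file path, and return the token ids of every word
    (one list of lists per bag).
    """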
    bow_indices = []
    for id_or_path in bag_of_words_ids_or_paths:
        if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
            filepath = cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path])
        else:
            filepath = id_or_path
        with open(filepath, "r") as f:
            words = f.read().strip().split("\n")
        bow_indices.append([tokenizer.encode(word.strip(), add_prefix_space=True) for word in words])
    return bow_indices


def build_bows_one_hot_vectors(bow_indices, tokenizer, device="cuda"):
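    """
    Turn bag-of-words token ids into one-hot matrices over the vocabulary,
    keeping only the words that tokenize to a single token.
    """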
    if bow_indices is None:
        return None

    one_hot_bows_vectors = []
    for single_bow in bow_indices:
        single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
        single_bow = torch.tensor(single_bow).to(device)
        num_words = single_bow.shape[0]
        one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
        one_hot_bow.scatter_(1, single_bow, 1)
        one_hot_bows_vectors.append(one_hot_bow)
    return one_hot_bows_vectors


def full_text_generation(
    model,
    tokenizer,
    context=None,
    num_samples=1,
    device="cuda",
    bag_of_words=None,
    discrim=None,
    class_label=None,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    repetition_penalty=1.0,
    **kwargs
):
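    """
    Generate one unperturbed sample and `num_samples` perturbed samples from the
    same context; return the token sequences together with the discriminator
    losses and the per-step losses of the perturbed runs.
    """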
    classifier, class_id = get_classifier(discrim, class_label, device)

    bow_indices = []
    if bag_of_words:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)

    if bag_of_words and classifier:
        print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
        loss_type = PPLM_BOW_DISCRIM

    elif bag_of_words:
        loss_type = PPLM_BOW
        print("Using PPLM-BoW")

    elif classifier is not None:
        loss_type = PPLM_DISCRIM
        print("Using PPLM-Discrim")

    else:
        raise Exception("Specify either a bag of words or a discriminator")

    unpert_gen_tok_text, _, _ = generate_text_pplm(
        model=model,
        tokenizer=tokenizer,
        context=context,
        device=device,
        length=length,
        sample=sample,
        perturb=False,
        repetition_penalty=repetition_penalty,
    )
    if device == "cuda":
        torch.cuda.empty_cache()

    pert_gen_tok_texts = []
    discrim_losses = []
    losses_in_time = []

    for i in range(num_samples):
        pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
            model=model,
            tokenizer=tokenizer,
            context=context,
            device=device,
            perturb=True,
            bow_indices=bow_indices,
            classifier=classifier,
            class_label=class_id,
            loss_type=loss_type,
            length=length,
            stepsize=stepsize,
            temperature=temperature,
            top_k=top_k,
            sample=sample,
            num_iterations=num_iterations,
            grad_length=grad_length,
            horizon_length=horizon_length,
            window_length=window_length,
            decay=decay,
            gamma=gamma,
            gm_scale=gm_scale,
            kl_scale=kl_scale,
            repetition_penalty=repetition_penalty,
        )
        pert_gen_tok_texts.append(pert_gen_tok_text)
        if classifier is not None:
            discrim_losses.append(discrim_loss.data.cpu().numpy())
        losses_in_time.append(loss_in_time)

    if device == "cuda":
        torch.cuda.empty_cache()

    return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time


def generate_text_pplm(
    model,
    tokenizer,
    context=None,
    past=None,
    device="cuda",
    perturb=True,
    bow_indices=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    repetition_penalty=1.0,
):
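    """
    Generate text token by token. When `perturb` is True, the cached past is
    perturbed at every step via perturb_past, and the perturbed and unperturbed
    next-token distributions are fused using `gm_scale`.
    """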
    output_so_far = None
    if context:
        context_t = torch.tensor(context, device=device, dtype=torch.long)
        while len(context_t.shape) < 2:
            context_t = context_t.unsqueeze(0)
        output_so_far = context_t

    # collect one hot vectors for bags of words
    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)

    grad_norms = None
    last = None
    unpert_discrim_loss = 0
    loss_in_time = []
    for i in trange(length, ascii=True):

        # Get past/probs for current output, except for last word
        # Note that GPT takes 2 inputs: past + current_token

        # run model forward to obtain the unperturbed past and logits
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]
            if output_so_far.shape[1] > 1:
                _, past, _ = model(output_so_far[:, :-1])

        unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
        unpert_last_hidden = unpert_all_hidden[-1]

        # check if we are above grad max length
        if i >= grad_length:
            current_stepsize = stepsize * 0
        else:
            current_stepsize = stepsize

        # modify the past if necessary
        if not perturb or num_iterations == 0:
            pert_past = past

        else:
            accumulated_hidden = unpert_last_hidden[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            if past is not None:
                pert_past, _, grad_norms, loss_this_iter = perturb_past(
                    past,
                    model,
                    last,
                    unpert_past=unpert_past,
                    unpert_logits=unpert_logits,
                    accumulated_hidden=accumulated_hidden,
                    grad_norms=grad_norms,
                    stepsize=current_stepsize,
                    one_hot_bows_vectors=one_hot_bows_vectors,
                    classifier=classifier,
                    class_label=class_label,
                    loss_type=loss_type,
                    num_iterations=num_iterations,
                    horizon_length=horizon_length,
                    window_length=window_length,
                    decay=decay,
                    gamma=gamma,
                    kl_scale=kl_scale,
                    device=device,
                )
                loss_in_time.append(loss_this_iter)
            else:
                pert_past = past

        pert_logits, past, pert_all_hidden = model(last, past=pert_past)
        pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST
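
        # penalize tokens that already appear in the output so they are less likely to be repeated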
        for token_idx in set(output_so_far[0].tolist()):
            if pert_logits[0, token_idx] < 0:
                pert_logits[0, token_idx] *= repetition_penalty
            else:
                pert_logits[0, token_idx] /= repetition_penalty

        pert_probs = F.softmax(pert_logits, dim=-1)

        if classifier is not None:
            ce_loss = torch.nn.CrossEntropyLoss()
            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
            label = torch.tensor([class_label], device=device, dtype=torch.long)
            unpert_discrim_loss = ce_loss(prediction, label)
            print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
        else:
            unpert_discrim_loss = 0

        # Fuse the modified model and original model
        if perturb:

            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

            pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))  # + SMALL_CONST
            pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)  # + SMALL_CONST

            # rescale
            if torch.sum(pert_probs) <= 1:
                pert_probs = pert_probs / torch.sum(pert_probs)

        else:
            pert_logits = top_k_filter(pert_logits, k=top_k)  # + SMALL_CONST
            pert_probs = F.softmax(pert_logits, dim=-1)

        # sample or greedy
        if sample:
            last = torch.multinomial(pert_probs, num_samples=1)

        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)

        # update context/output_so_far appending the new token
        output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1)

        print(tokenizer.decode(output_so_far.tolist()[0]))

    return output_so_far, unpert_discrim_loss, loss_in_time


def set_generic_model_params(discrim_weights, discrim_meta):
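    """
    Register a user-provided ("generic") discriminator by loading its metadata
    file and pointing DISCRIMINATOR_MODELS_PARAMS["generic"] at the local weights.
    """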
    if discrim_weights is None:
        raise ValueError("When using a generic discriminator, discrim_weights needs to be specified")
    if discrim_meta is None:
        raise ValueError("When using a generic discriminator, discrim_meta needs to be specified")

    with open(discrim_meta, "r") as discrim_meta_file:
        meta = json.load(discrim_meta_file)
    meta["path"] = discrim_weights
    DISCRIMINATOR_MODELS_PARAMS["generic"] = meta


def run_pplm_example(
    pretrained_model="gpt2-medium",
    cond_text="",
    uncond=False,
    num_samples=1,
    bag_of_words=None,
    discrim=None,
    discrim_weights=None,
    discrim_meta=None,
    class_label=-1,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    seed=0,
    no_cuda=False,
    colorama=False,
    repetition_penalty=1.0,
):
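    """
    End-to-end example: load the (frozen) GPT-2 model and tokenizer, build the
    conditioning text, run full_text_generation, and print the unperturbed and
    perturbed generations.
    """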
    # set the random seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set the device
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"

    if discrim == "generic":
        set_generic_model_params(discrim_weights, discrim_meta)

    if discrim is not None:
        pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
        print("discrim = {}, pretrained_model set to discriminator's = {}".format(discrim, pretrained_model))

    # load pretrained model
    model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
    model.to(device)
    model.eval()

    # load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)

    # Freeze GPT-2 weights
    for param in model.parameters():
        param.requires_grad = False

    # figure out conditioning text
    if uncond:
        tokenized_cond_text = tokenizer.encode([tokenizer.bos_token])
    else:
        raw_text = cond_text
        while not raw_text:
            print("Did you forget to add `--cond_text`? ")
            raw_text = input("Model prompt >>> ")
        tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)

    print("= Prefix of sentence =")
    print(tokenizer.decode(tokenized_cond_text))
    print()

    # generate unperturbed and perturbed texts

    # full_text_generation returns:
    # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
    unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
        model=model,
        tokenizer=tokenizer,
        context=tokenized_cond_text,
        device=device,
        num_samples=num_samples,
        bag_of_words=bag_of_words,
        discrim=discrim,
        class_label=class_label,
        length=length,
        stepsize=stepsize,
        temperature=temperature,
        top_k=top_k,
        sample=sample,
        num_iterations=num_iterations,
        grad_length=grad_length,
        horizon_length=horizon_length,
        window_length=window_length,
        decay=decay,
        gamma=gamma,
        gm_scale=gm_scale,
        kl_scale=kl_scale,
        repetition_penalty=repetition_penalty,
    )

    # untokenize unperturbed text
    unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])

    print("=" * 80)
    print("= Unperturbed generated text =")
    print(unpert_gen_text)
    print()

    generated_texts = []

    bow_word_ids = set()
    if bag_of_words and colorama:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
        for single_bow_list in bow_indices:
            # filter out words that tokenize into more than one token
            filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
            # w[0] is safe because the previous filter guarantees each w is a single token id
            bow_word_ids.update(w[0] for w in filtered)

    # iterate through the perturbed texts
    for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
        try:
            # untokenize perturbed text
            if colorama:
                import colorama

                pert_gen_text = ""
                for word_id in pert_gen_tok_text.tolist()[0]:
                    if word_id in bow_word_ids:
                        pert_gen_text += "{}{}{}".format(
                            colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL,
                        )
                    else:
                        pert_gen_text += tokenizer.decode([word_id])
            else:
                pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])

            print("= Perturbed generated text {} =".format(i + 1))
            print(pert_gen_text)
            print()
        except Exception as exc:
            print("Ignoring error while generating perturbed text:", exc)

        # keep the prefix, perturbed seq, original seq for each index
        generated_texts.append((tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text))

    return


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model",
        "-M",
        type=str,
        default="gpt2-medium",
        help="pretrained model name or path to local checkpoint",
    )
    parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix texts to condition on")
    parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
    parser.add_argument(
        "--num_samples", type=int, default=1, help="Number of samples to generate from the modified latents",
    )
    parser.add_argument(
        "--bag_of_words",
        "-B",
        type=str,
        default=None,
        help=(
            "Bags of words used for PPLM-BoW. "
            "Either a BOW id (see list in code) or a filepath. "
            "Multiple BoWs separated by ;"
        ),
    )
    parser.add_argument(
        "--discrim",
        "-D",
        type=str,
        default=None,
        choices=("clickbait", "sentiment", "toxicity", "generic"),
        help="Discriminator to use",
    )
    parser.add_argument(
        "--discrim_weights", type=str, default=None, help="Weights for the generic discriminator",
    )
    parser.add_argument(
        "--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator",
    )
    parser.add_argument(
        "--class_label", type=int, default=-1, help="Class label used for the discriminator",
    )
    parser.add_argument("--length", type=int, default=100)
    parser.add_argument("--stepsize", type=float, default=0.02)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument("--sample", action="store_true", help="Sample from the distribution instead of greedy decoding")
    parser.add_argument("--num_iterations", type=int, default=3)
    parser.add_argument("--grad_length", type=int, default=10000)
    parser.add_argument(
        "--window_length",
        type=int,
        default=0,
        help="Length of past which is being optimized; 0 corresponds to infinite window length",
    )
    parser.add_argument(
        "--horizon_length", type=int, default=1, help="Length of future to optimize over",
    )
    parser.add_argument("--decay", action="store_true", help="whether to decay or not")
    parser.add_argument("--gamma", type=float, default=1.5)
    parser.add_argument("--gm_scale", type=float, default=0.9)
    parser.add_argument("--kl_scale", type=float, default=0.01)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
    parser.add_argument("--colorama", action="store_true", help="colors keywords")
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="Penalize repetition. More than 1.0 -> less repetition",
    )

    args = parser.parse_args()
    run_pplm_example(**vars(args))