"vscode:/vscode.git/clone" did not exist on "227f1a74bb0fbdd39b737a6e6ab75d0c61f3d6fa"
run_pplm.py 27.6 KB
Newer Older
Piero Molino's avatar
Piero Molino committed
1
#! /usr/bin/env python3
Julien Chaumond's avatar
Julien Chaumond committed
2
# coding=utf-8
Rosanne Liu's avatar
Rosanne Liu committed
3

4
# Copyright (c) 2019 Uber Technologies, Inc.
Julien Chaumond's avatar
Julien Chaumond committed
5
#
6
7
8
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Julien Chaumond's avatar
Julien Chaumond committed
9
#
10
# http://www.apache.org/licenses/LICENSE-2.0
Julien Chaumond's avatar
Julien Chaumond committed
11
#
12
13
14
15
16
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Julien Chaumond's avatar
Julien Chaumond committed
17
18
19

"""
Example command with bag of words:
python run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95

Example command with discriminator:
python run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
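
The same pipeline can also be driven from Python. A minimal sketch (it simply calls the
run_pplm_example function defined below with a subset of its keyword arguments):

    from run_pplm import run_pplm_example

    run_pplm_example(
        cond_text="The president",
        bag_of_words="space",
        length=50,
        stepsize=0.01,
        num_iterations=3,
        num_samples=1,
        sample=True,
    )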
"""

import argparse
import json
from operator import add
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import trange

from pplm_classification_head import ClassificationHead
from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel


PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
BIG_CONST = 1e10

BAG_OF_WORDS_ARCHIVE_MAP = {
    "legal": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
    "military": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
    "politics": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
    "religion": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
    "science": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
    "space": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
    "technology": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
}

DISCRIMINATOR_MODELS_PARAMS = {
    "clickbait": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
        "class_size": 2,
        "embed_size": 1024,
        "class_vocab": {"non_clickbait": 0, "clickbait": 1},
        "default_class": 1,
        "pretrained_model": "gpt2-medium",
    },
    "sentiment": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
        "class_size": 5,
        "embed_size": 1024,
        "class_vocab": {"very_positive": 2, "very_negative": 3},
        "default_class": 3,
        "pretrained_model": "gpt2-medium",
    },
}


def top_k_filter(logits, k, probs=False):
    """
    Masks everything but the k top entries as -infinity (-1e10).
    Used to mask logits such that e^-infinity -> 0 won't contribute to the
    sum of the denominator.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits)
        return torch.where(logits < batch_mins, torch.ones_like(logits) * -BIG_CONST, logits)


def perturb_past(
    past,
    model,
    last,
    unpert_past=None,
    unpert_logits=None,
    accumulated_hidden=None,
    grad_norms=None,
    stepsize=0.01,
    one_hot_bows_vectors=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    num_iterations=3,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    kl_scale=0.01,
    device="cuda",
):
    # Generate initial perturbed past
    grad_accumulator = [(np.zeros(p.shape).astype("float32")) for p in past]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    if decay:
        decay_mask = torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
    else:
        decay_mask = 1.0

    # Generate a mask so that the gradient perturbation is only applied to the
    # most recent `window_length` positions of the past
    _, _, _, curr_length, _ = past[0].shape

    if curr_length > window_length and window_length > 0:
        ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])

        zeros_key_val_shape = (
            tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
        )

        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2).to(device)
    else:
        window_mask = torch.ones_like(past[0]).to(device)

    # accumulate perturbations for num_iterations
    loss_per_iter = []
    new_accumulated_hidden = None
    for i in range(num_iterations):
        print("Iteration ", i + 1)
        curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
        # make sure p_.grad is not None
        for p_ in curr_perturbation:
            p_.retain_grad()

        # Compute hidden using perturbed past
        perturbed_past = list(map(add, past, curr_perturbation))
        _, _, _, curr_length, _ = curr_perturbation[0].shape
        all_logits, _, all_hidden = model(last, past=perturbed_past)
        hidden = all_hidden[-1]
        new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
        # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
        logits = all_logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        loss = 0.0
        loss_list = []
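        # BoW loss: negative log of the total probability mass the model assigns to the bag-of-words tokens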
        if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
            for one_hot_bow in one_hot_bows_vectors:
                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
                bow_loss = -torch.log(torch.sum(bow_logits))
                loss += bow_loss
                loss_list.append(bow_loss)
            print(" pplm_bow_loss:", loss.data.cpu().numpy())

        if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
            ce_loss = torch.nn.CrossEntropyLoss()
            # TODO: why do we need this assignment instead of just using unpert_past? (Sumanth)
            curr_unpert_past = unpert_past
            curr_probs = torch.unsqueeze(probs, dim=1)
            wte = model.resize_token_embeddings()
            for _ in range(horizon_length):
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                _, curr_unpert_past, curr_all_hidden = model(past=curr_unpert_past, inputs_embeds=inputs_embeds)
                curr_hidden = curr_all_hidden[-1]
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)

            prediction = classifier(new_accumulated_hidden / (curr_length + 1 + horizon_length))

            label = torch.tensor(prediction.shape[0] * [class_label], device=device, dtype=torch.long)
            discrim_loss = ce_loss(prediction, label)
            print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
            loss += discrim_loss
            loss_list.append(discrim_loss)

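        # KL penalty: keep the perturbed next-token distribution close to the unperturbed one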
        kl_loss = 0.0
        if kl_scale > 0.0:
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
            unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
            corrected_probs = probs + correction.detach()
            kl_loss = kl_scale * ((corrected_probs * (corrected_probs / unpert_probs).log()).sum())
            print(" kl_loss", kl_loss.data.cpu().numpy())
            loss += kl_loss

        loss_per_iter.append(loss.data.cpu().numpy())
        print(" pplm_loss", (loss - kl_loss).data.cpu().numpy())

        # compute gradients
        loss.backward()

        # calculate gradient norms
        if grad_norms is not None and loss_type == PPLM_BOW:
            grad_norms = [
                torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                for index, p_ in enumerate(curr_perturbation)
            ]
        else:
            grad_norms = [
                (torch.norm(p_.grad * window_mask) + SMALL_CONST) for index, p_ in enumerate(curr_perturbation)
            ]

        # normalize gradients
        grad = [
            -stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
            for index, p_ in enumerate(curr_perturbation)
        ]

        # accumulate gradient
        grad_accumulator = list(map(add, grad, grad_accumulator))

        # reset gradients, just to make sure
        for p_ in curr_perturbation:
            p_.grad.data.zero_()

        # removing past from the graph
        new_past = []
        for p_ in past:
            new_past.append(p_.detach())
        past = new_past

    # apply the accumulated perturbations to the past
    grad_accumulator = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
    pert_past = list(map(add, past, grad_accumulator))

    return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter


def get_classifier(
    name: Optional[str], class_label: Union[str, int], device: str
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    classifier = ClassificationHead(class_size=params["class_size"], embed_size=params["embed_size"]).to(device)
    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError("Either url or path has to be specified in the discriminator model parameters")
    classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    if isinstance(class_label, str):
        if class_label in params["class_vocab"]:
            label_id = params["class_vocab"][class_label]
        else:
            label_id = params["default_class"]
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(params["class_vocab"]))
            print("using default class {}".format(label_id))

    elif isinstance(class_label, int):
        if class_label in set(params["class_vocab"].values()):
            label_id = class_label
        else:
            label_id = params["default_class"]
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(params["class_vocab"]))
            print("using default class {}".format(label_id))

    else:
        label_id = params["default_class"]

    return classifier, label_id


def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> List[List[List[int]]]:
    bow_indices = []
    for id_or_path in bag_of_words_ids_or_paths:
        if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
            filepath = cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path])
        else:
            filepath = id_or_path
        with open(filepath, "r") as f:
            words = f.read().strip().split("\n")
        bow_indices.append([tokenizer.encode(word.strip(), add_prefix_space=True) for word in words])
    return bow_indices


def build_bows_one_hot_vectors(bow_indices, tokenizer, device="cuda"):
    if bow_indices is None:
        return None

    one_hot_bows_vectors = []
    for single_bow in bow_indices:
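        # keep only words that encode to a single token; multi-token words are dropped from the one-hot matrix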
        single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
        single_bow = torch.tensor(single_bow).to(device)
        num_words = single_bow.shape[0]
        one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
        one_hot_bow.scatter_(1, single_bow, 1)
        one_hot_bows_vectors.append(one_hot_bow)
    return one_hot_bows_vectors


def full_text_generation(
    model,
    tokenizer,
    context=None,
    num_samples=1,
    device="cuda",
    bag_of_words=None,
    discrim=None,
    class_label=None,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    repetition_penalty=1.0,
    **kwargs
):
    classifier, class_id = get_classifier(discrim, class_label, device)

    bow_indices = []
    if bag_of_words:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)

    if bag_of_words and classifier:
        print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
        loss_type = PPLM_BOW_DISCRIM

    elif bag_of_words:
        loss_type = PPLM_BOW
        print("Using PPLM-BoW")

    elif classifier is not None:
        loss_type = PPLM_DISCRIM
        print("Using PPLM-Discrim")

    else:
        raise Exception("Specify either a bag of words or a discriminator")

    unpert_gen_tok_text, _, _ = generate_text_pplm(
        model=model,
        tokenizer=tokenizer,
        context=context,
        device=device,
        length=length,
        sample=sample,
        perturb=False,
        repetition_penalty=repetition_penalty,
    )
    if device == "cuda":
        torch.cuda.empty_cache()

    pert_gen_tok_texts = []
    discrim_losses = []
    losses_in_time = []

    for i in range(num_samples):
        pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
            model=model,
            tokenizer=tokenizer,
            context=context,
            device=device,
            perturb=True,
            bow_indices=bow_indices,
            classifier=classifier,
            class_label=class_id,
            loss_type=loss_type,
            length=length,
            stepsize=stepsize,
            temperature=temperature,
            top_k=top_k,
            sample=sample,
            num_iterations=num_iterations,
            grad_length=grad_length,
            horizon_length=horizon_length,
            window_length=window_length,
            decay=decay,
            gamma=gamma,
            gm_scale=gm_scale,
            kl_scale=kl_scale,
            repetition_penalty=repetition_penalty,
        )
        pert_gen_tok_texts.append(pert_gen_tok_text)
        if classifier is not None:
            discrim_losses.append(discrim_loss.data.cpu().numpy())
        losses_in_time.append(loss_in_time)

    if device == "cuda":
        torch.cuda.empty_cache()

    return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time


def generate_text_pplm(
    model,
    tokenizer,
    context=None,
    past=None,
    device="cuda",
    perturb=True,
    bow_indices=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    repetition_penalty=1.0,
):
    output_so_far = None
    if context:
        context_t = torch.tensor(context, device=device, dtype=torch.long)
        while len(context_t.shape) < 2:
            context_t = context_t.unsqueeze(0)
        output_so_far = context_t

    # collect one hot vectors for bags of words
    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)

    grad_norms = None
    last = None
    unpert_discrim_loss = 0
    loss_in_time = []
    for i in trange(length, ascii=True):

        # Get past/probs for current output, except for last word
        # Note that GPT takes 2 inputs: past + current_token

        # run model forward to obtain the unperturbed past
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]
            if output_so_far.shape[1] > 1:
                _, past, _ = model(output_so_far[:, :-1])

        unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
        unpert_last_hidden = unpert_all_hidden[-1]

        # check if we are above grad max length
        if i >= grad_length:
            current_stepsize = stepsize * 0
        else:
            current_stepsize = stepsize

        # modify the past if necessary
        if not perturb or num_iterations == 0:
            pert_past = past

        else:
            accumulated_hidden = unpert_last_hidden[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            if past is not None:
                pert_past, _, grad_norms, loss_this_iter = perturb_past(
                    past,
                    model,
                    last,
                    unpert_past=unpert_past,
                    unpert_logits=unpert_logits,
                    accumulated_hidden=accumulated_hidden,
                    grad_norms=grad_norms,
                    stepsize=current_stepsize,
                    one_hot_bows_vectors=one_hot_bows_vectors,
                    classifier=classifier,
                    class_label=class_label,
                    loss_type=loss_type,
                    num_iterations=num_iterations,
                    horizon_length=horizon_length,
                    window_length=window_length,
                    decay=decay,
                    gamma=gamma,
                    kl_scale=kl_scale,
                    device=device,
                )
                loss_in_time.append(loss_this_iter)
            else:
                pert_past = past

        pert_logits, past, pert_all_hidden = model(last, past=pert_past)
        pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST

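        # apply the repetition penalty: scale down the logits of tokens that already appear in the output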
        for token_idx in set(output_so_far[0].tolist()):
            if pert_logits[0, token_idx] < 0:
                pert_logits[0, token_idx] *= repetition_penalty
            else:
                pert_logits[0, token_idx] /= repetition_penalty

        pert_probs = F.softmax(pert_logits, dim=-1)

        if classifier is not None:
            ce_loss = torch.nn.CrossEntropyLoss()
            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
            label = torch.tensor([class_label], device=device, dtype=torch.long)
            unpert_discrim_loss = ce_loss(prediction, label)
            print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
        else:
            unpert_discrim_loss = 0

        # Fuse the modified model and original model
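        # pert_probs**gm_scale * unpert_probs**(1 - gm_scale) is a weighted geometric mean of the two distributions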
        if perturb:

            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

            pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))  # + SMALL_CONST
            pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)  # + SMALL_CONST

            # rescale
            if torch.sum(pert_probs) <= 1:
                pert_probs = pert_probs / torch.sum(pert_probs)

        else:
            pert_logits = top_k_filter(pert_logits, k=top_k)  # + SMALL_CONST
            pert_probs = F.softmax(pert_logits, dim=-1)

        # sample or greedy
        if sample:
            last = torch.multinomial(pert_probs, num_samples=1)

        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)

        # update context/output_so_far appending the new token
        output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1)

        print(tokenizer.decode(output_so_far.tolist()[0]))

    return output_so_far, unpert_discrim_loss, loss_in_time


562
563
def set_generic_model_params(discrim_weights, discrim_meta):
    if discrim_weights is None:
564
        raise ValueError("When using a generic discriminator, discrim_weights need to be specified")
565
    if discrim_meta is None:
566
        raise ValueError("When using a generic discriminator, discrim_meta need to be specified")
567

568
    with open(discrim_meta, "r") as discrim_meta_file:
569
        meta = json.load(discrim_meta_file)
570
571
    meta["path"] = discrim_weights
    DISCRIMINATOR_MODELS_PARAMS["generic"] = meta
572
573


def run_pplm_example(
    pretrained_model="gpt2-medium",
    cond_text="",
    uncond=False,
    num_samples=1,
    bag_of_words=None,
    discrim=None,
    discrim_weights=None,
    discrim_meta=None,
    class_label=-1,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    seed=0,
    no_cuda=False,
    colorama=False,
    repetition_penalty=1.0,
):
    # set the random seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set the device
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"

    if discrim == "generic":
        set_generic_model_params(discrim_weights, discrim_meta)

    if discrim is not None:
        pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
        print("discrim = {}, pretrained_model set to discriminator's = {}".format(discrim, pretrained_model))

    # load pretrained model
    model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
    model.to(device)
    model.eval()

    # load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)

    # Freeze GPT-2 weights
    for param in model.parameters():
        param.requires_grad = False

    # figure out conditioning text
    if uncond:
        tokenized_cond_text = tokenizer.encode([tokenizer.bos_token])
    else:
        raw_text = cond_text
        while not raw_text:
            print("Did you forget to add `--cond_text`? ")
            raw_text = input("Model prompt >>> ")
        tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)

    print("= Prefix of sentence =")
    print(tokenizer.decode(tokenized_cond_text))
    print()

    # generate unperturbed and perturbed texts

    # full_text_generation returns:
    # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
    unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
        model=model,
        tokenizer=tokenizer,
        context=tokenized_cond_text,
        device=device,
        num_samples=num_samples,
        bag_of_words=bag_of_words,
        discrim=discrim,
        class_label=class_label,
        length=length,
        stepsize=stepsize,
        temperature=temperature,
        top_k=top_k,
        sample=sample,
        num_iterations=num_iterations,
        grad_length=grad_length,
        horizon_length=horizon_length,
        window_length=window_length,
        decay=decay,
        gamma=gamma,
        gm_scale=gm_scale,
        kl_scale=kl_scale,
        repetition_penalty=repetition_penalty,
    )

    # untokenize unperturbed text
    unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])

    print("=" * 80)
    print("= Unperturbed generated text =")
    print(unpert_gen_text)
    print()

    generated_texts = []

    bow_word_ids = set()
    if bag_of_words and colorama:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
        for single_bow_list in bow_indices:
            # filter out all words in the list composed of more than 1 token
            filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
            # w[0] is safe because the previous filter keeps only single-token words
            bow_word_ids.update(w[0] for w in filtered)

    # iterate through the perturbed texts
    for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
        try:
            # untokenize perturbed text
            if colorama:
                import colorama

                pert_gen_text = ""
                for word_id in pert_gen_tok_text.tolist()[0]:
                    if word_id in bow_word_ids:
                        pert_gen_text += "{}{}{}".format(
                            colorama.Fore.RED,
                            tokenizer.decode([word_id]),
                            colorama.Style.RESET_ALL,
                        )
                    else:
                        pert_gen_text += tokenizer.decode([word_id])
            else:
                pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])

            print("= Perturbed generated text {} =".format(i + 1))
            print(pert_gen_text)
            print()
        except Exception as exc:
            print("Ignoring error while generating perturbed text:", exc)

        # keep the prefix, perturbed seq, original seq for each index
        generated_texts.append((tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text))

    return


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model",
        "-M",
        type=str,
        default="gpt2-medium",
        help="pretrained model name or path to local checkpoint",
    )
    parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix texts to condition on")
    parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="Number of samples to generate from the modified latents",
    )
    parser.add_argument(
        "--bag_of_words",
        "-B",
        type=str,
        default=None,
        help=(
            "Bags of words used for PPLM-BoW. "
            "Either a BOW id (see list in code) or a filepath. "
            "Multiple BoWs separated by ;"
        ),
    )
    parser.add_argument(
        "--discrim",
        "-D",
        type=str,
        default=None,
        choices=("clickbait", "sentiment", "toxicity", "generic"),
        help="Discriminator to use",
    )
    parser.add_argument(
        "--discrim_weights",
        type=str,
        default=None,
        help="Weights for the generic discriminator",
    )
    parser.add_argument(
        "--discrim_meta",
        type=str,
        default=None,
        help="Meta information for the generic discriminator",
    )
    parser.add_argument(
        "--class_label",
        type=int,
        default=-1,
        help="Class label used for the discriminator",
    )
    parser.add_argument("--length", type=int, default=100)
    parser.add_argument("--stepsize", type=float, default=0.02)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument("--sample", action="store_true", help="Sample tokens instead of greedy decoding")
    parser.add_argument("--num_iterations", type=int, default=3)
    parser.add_argument("--grad_length", type=int, default=10000)
    parser.add_argument(
        "--window_length",
        type=int,
        default=0,
        help="Length of past which is being optimized; 0 corresponds to infinite window length",
    )
    parser.add_argument(
        "--horizon_length",
        type=int,
        default=1,
        help="Length of future to optimize over",
    )
    parser.add_argument("--decay", action="store_true", help="apply a decaying mask over the perturbation window")
    parser.add_argument("--gamma", type=float, default=1.5)
    parser.add_argument("--gm_scale", type=float, default=0.9)
    parser.add_argument("--kl_scale", type=float, default=0.01)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
    parser.add_argument("--colorama", action="store_true", help="colors keywords")
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="Penalize repetition. More than 1.0 -> less repetition",
    )

    args = parser.parse_args()
    run_pplm_example(**vars(args))