chenpangpang / transformers · Commits

Commit 34a83faa
Authored Nov 25, 2019 by Piero Molino
Committed by Julien Chaumond, Dec 03, 2019

Let's make PPLM great again

parent d5faa74c

Showing 1 changed file with 452 additions and 480 deletions

examples/run_pplm.py  +452  -480
#! /usr/bin/env python3
# coding=utf-8
# Copyright 2018 The Uber AI Team Authors.
#
...
...

@@ -37,10 +38,12 @@ from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel

PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3

SMALL_CONST = 1e-15
SmallConst = 1e-15

TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")

BAG_OF_WORDS_ARCHIVE_MAP = {
...
...
@@ -65,7 +68,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
        "default_class": 1,
    },
    "sentiment": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/sentiment_classifierhead.pt",
        "url": "http://s.yosinski.com/SST_classifier_head.pt",
        "class_size": 5,
        "embed_size": 1024,
        "class_vocab": {"very_positive": 2, "very_negative": 3},
...
...
@@ -81,24 +84,6 @@ DISCRIMINATOR_MODELS_PARAMS = {
}


class ClassificationHead(torch.nn.Module):
    """ Classification Head for the transformer """

    def __init__(self, class_size=5, embed_size=2048):
        super(ClassificationHead, self).__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        # hidden_state = F.relu(self.mlp1(hidden_state))
        # hidden_state = self.mlp2(hidden_state)
        logits = self.mlp(hidden_state)
        return logits


def to_var(x, requires_grad=False, volatile=False):
    if torch.cuda.is_available():
        x = x.cuda()
...
...
@@ -111,222 +96,205 @@ def top_k_filter(logits, k, probs=False):
    Used to mask logits such that e^-infinity -> 0 won't contribute to the
    sum of the denominator.
    """
    if k <= 0:
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits)
        return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits)
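For orientation, a minimal usage sketch of top_k_filter with illustrative shapes and hypothetical values (not part of the diff):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 50257)            # hypothetical batch of GPT-2 vocabulary logits
filtered = top_k_filter(logits, k=10)     # everything outside the top 10 becomes -1e10
probs = F.softmax(filtered, dim=-1)       # masked entries get ~zero probability
next_token = torch.multinomial(probs, num_samples=1)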
class ClassificationHead(torch.nn.Module):
    """ Classification Head for the transformer """

    def __init__(self, class_size=5, embed_size=2048):
        super(ClassificationHead, self).__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        # hidden_state = F.relu(self.mlp1(hidden_state))
        # hidden_state = self.mlp2(hidden_state)
        logits = self.mlp(hidden_state)
        return logits
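A small sketch of what this head computes; the sizes below are taken from the sentiment entry of DISCRIMINATOR_MODELS_PARAMS and the input is an illustrative averaged hidden state:

import torch

head = ClassificationHead(class_size=5, embed_size=1024)   # sentiment discriminator sizes
mean_hidden = torch.randn(1, 1024)                          # averaged GPT-2 hidden state (illustrative)
class_logits = head(mean_hidden)                            # shape (1, 5), one logit per class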
def perturb_past(
    past,
    model,
    last,
    unpert_past=None,
    unpert_logits=None,
    accumulated_hidden=None,
    grad_norms=None,
    stepsize=0.01,
    classifier=None,
    label_class=None,
    one_hot_bows_vectors=None,
    loss_type=0,
    num_iterations=3,
    kl_scale=0.01,
    window_length=0,
    horizon_length=1,
    decay=False,
    gamma=1.5,
):
    # initializie perturbation accumulator
    grad_accumulator = [(np.zeros(p.shape).astype("float32")) for p in past]
def perturb_past(past, model, prev, args, classifier, good_index=None, stepsize=0.01,
                 vocab_size=50257, original_probs=None, accumulated_hidden=None,
                 true_past=None, grad_norms=None):
    window_length = args.window_length
    gm_scale, kl_scale = args.gm_scale, args.kl_scale
    one_hot_vectors = []
    for good_list in good_index:
        good_list = list(filter(lambda x: len(x) <= 1, good_list))
        good_list = torch.tensor(good_list).cuda()
        num_good = good_list.shape[0]
        one_hot_good = torch.zeros(num_good, vocab_size).cuda()
        one_hot_good.scatter_(1, good_list, 1)
        one_hot_vectors.append(one_hot_good)

    # Generate inital perturbed past
    past_perturb_orig = [(np.random.uniform(0.0, 0.0, p.shape).astype('float32'))
                         for p in past]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    if decay:
        decay_mask = torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
    if args.decay:
        decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[1:]
    else:
        decay_mask = 1.0

    # TODO fix this comment (SUMANTH)
    # generate a mask if perturbated gradient is based on a past window
    _, _, _, curr_length, _ = past[0].shape

    if curr_length > window_length and window_length > 0:
        ones_key_val_shape = (
            tuple(past[0].shape[:-2])
            + tuple([window_length])
            + tuple(past[0].shape[-1:])
        )
    # Generate a mask is gradient perturbated is based on a past window
    _, _, _, current_length, _ = past[0].shape
        zeros_key_val_shape = (
            tuple(past[0].shape[:-2])
            + tuple([curr_length - window_length])
            + tuple(past[0].shape[-1:])
        )
    if current_length > window_length and window_length > 0:
        ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])
        zeros_key_val_shape = tuple(past[0].shape[:-2]) + tuple([current_length - window_length]) + tuple(past[0].shape[-1:])

        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat(
            (ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2
        ).cuda()
        window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2).cuda()
    else:
        window_mask = torch.ones_like(past[0]).cuda()
    # accumulate perturbations for num_iterations
    loss_per_iter = []
    for i in range(num_iterations):
    for i in range(args.num_iterations):
        print("Iteration ", i + 1)
        past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
        past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
        curr_perturbation = [
            to_var(torch.from_numpy(p_), requires_grad=True) for p_ in grad_accumulator
        ]

        perturbed_past = list(map(add, past, past_perturb))
        # Compute hidden using perturbed past
        curr_pert_past = list(map(add, past, curr_perturbation))
        all_logits, _, all_hidden = model(last, past=curr_pert_past)
        hidden = all_hidden[-1]
        accumulated_hidden += torch.sum(hidden, dim=1).detach()
        logits = all_logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        _, _, _, current_length, _ = past_perturb[0].shape

        # compute loss
        bow_loss = 0.0
        discrim_loss = 0.0
        kl_loss = 0.0

        # _, future_past = model(prev, past=perturbed_past)
        # hidden = model.hidden_states
        if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
            for one_hot_bow in one_hot_bows_vectors:
                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
                bow_loss += -torch.log(torch.sum(bow_logits))
            print(" pplm_bow_loss:", bow_loss.data.cpu().numpy())

        if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
        # Piero modified model call
        logits, _, all_hidden = model(prev, past=perturbed_past)
        hidden = all_hidden[-1]
        new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
        # TODO: Check the layer-norm consistency of this with trained discriminator
        logits = logits[:, -1, :]
        probabs = F.softmax(logits, dim=-1)
        loss = 0.0
        loss_list = []
        if args.loss_type == 1 or args.loss_type == 3:
            for one_hot_good in one_hot_vectors:
                good_logits = torch.mm(probabs, torch.t(one_hot_good))
                loss_word = good_logits
                loss_word = torch.sum(loss_word)
                loss_word = -torch.log(loss_word)
                # loss_word = torch.sum(loss_word) /torch.sum(one_hot_good)
                loss += loss_word
                loss_list.append(loss_word)
            print(" pplm_bow_loss:", loss.data.cpu().numpy())

        if args.loss_type == 2 or args.loss_type == 3:
            ce_loss = torch.nn.CrossEntropyLoss()
            # TODO all there are for (SUMANTH)
            # TODO why we need to do this assignment and not just using unpert_past?
            curr_unpert_past = unpert_past
            # Get the model's token embeddings in order to compute our own embeds from curr_probs:
            new_true_past = true_past
            for i in range(args.horizon_length):
                future_probabs = F.softmax(logits, dim=-1)  # Get softmax
                future_probabs = torch.unsqueeze(future_probabs, dim=1)
                # _, new_true_past = model(future_probabs, past=new_true_past)
                # future_hidden = model.hidden_states  # Get expected hidden states
            # Piero modified model call
            wte = model.resize_token_embeddings()
            # TODO i is never used, why do we need to do this i times instead multiplying
            # torch.sum(unpert_hidden, dim=1) * horizon_length?
            for i in range(horizon_length):
                # TODO the next two lines can be done only one time, and why not using probs instead as they do not change at each iteration?
                curr_probs = F.softmax(logits, dim=-1)  # get softmax
                curr_probs = torch.unsqueeze(curr_probs, dim=1)
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                _, curr_unpert_past, curr_all_hidden = model(
                    past=curr_unpert_past,
                    inputs_embeds=torch.matmul(future_probabs, wte.weight.data)
                )
                _, new_true_past, future_hidden = model(past=new_true_past, inputs_embeds=inputs_embeds)
                # get expected hidden states
                unpert_hidden = curr_all_hidden[-1]
                accumulated_hidden += torch.sum(unpert_hidden, dim=1).detach()
                future_hidden = future_hidden[-1]

            prediction = classifier(accumulated_hidden / (curr_length + 1 + horizon_length))
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(future_hidden, dim=1)

            label = torch.tensor([label_class], device="cuda", dtype=torch.long)
            discrim_loss += ce_loss(prediction, label)
            print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
            predicted_sentiment = classifier(new_accumulated_hidden / (current_length + 1 + args.horizon_length))

        if kl_scale >= 0.0:
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
            unpert_probs = (
                unpert_probs
                + SMALL_CONST * (unpert_probs <= SMALL_CONST).type(torch.FloatTensor).cuda().detach()
            )
            label = torch.tensor([args.label_class], device='cuda', dtype=torch.long)
            discrim_loss = ce_loss(predicted_sentiment, label)
            print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
            loss += discrim_loss
            loss_list.append(discrim_loss)

            correction = SMALL_CONST * (probs <= SMALL_CONST).type(torch.FloatTensor).cuda().detach()
            corrected_probs = probs + correction.detach()

        kl_loss = 0.0
        if kl_scale > 0.0:
            p = (F.softmax(original_probs[:, -1, :], dim=-1))
            p = p + SmallConst * (p <= SmallConst).type(torch.FloatTensor).cuda().detach()
            correction = SmallConst * (probabs <= SmallConst).type(torch.FloatTensor).cuda().detach()
            corrected_probabs = probabs + correction.detach()
            kl_loss = kl_scale * (
                (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
            )
            kl_loss = kl_scale * ((corrected_probabs * (corrected_probabs / p).log()).sum())
            print(' kl_loss', (kl_loss).data.cpu().numpy())
        loss += kl_loss  # + discrim_loss

        loss = bow_loss + discrim_loss + kl_loss
        loss_per_iter.append(loss.data.cpu().numpy())
        print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())

        # compute gradients
        loss.backward()

        # calculate gradient norms
        if grad_norms is not None and loss_type == PPLM_BOW:
        if grad_norms is not None and args.loss_type == 1:
            grad_norms = [
                torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                for index, p_ in enumerate(curr_perturbation)
            ]
            grad_norms = [torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                          for index, p_ in enumerate(past_perturb)]
        else:
            grad_norms = [
                (torch.norm(p_.grad * window_mask) + SMALL_CONST)
                for index, p_ in enumerate(curr_perturbation)
            ]
            grad_norms = [(torch.norm(p_.grad * window_mask) + SmallConst)
                          for index, p_ in enumerate(past_perturb)]

        # normalize gradients
        grad = [
            -stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
            for index, p_ in enumerate(curr_perturbation)
        ]

        # accumulate gradients
        grad_accumulator = list(map(add, grad, grad_accumulator))
        grad = [-stepsize * (p_.grad * window_mask / grad_norms[index] ** args.gamma).data.cpu().numpy()
                for index, p_ in enumerate(past_perturb)]
        past_perturb_orig = list(map(add, grad, past_perturb_orig))

        # reset gradients, just to make sure
        for p_ in curr_perturbation:
        for p_ in past_perturb:
            p_.grad.data.zero_()

        # removing past from the graph
        new_past = []
        for p_ in past:
            new_past.append(p_.detach())
        for p in past:
            new_past.append(p.detach())
        past = new_past

    # apply the accumulated perturbations to the past
    grad_accumulator = [to_var(torch.from_numpy(p_), requires_grad=True) for p_ in grad_accumulator]
    pert_past = list(map(add, past, grad_accumulator))
    past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
    past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
    perturbed_past = list(map(add, past, past_perturb))

    return pert_past, accumulated_hidden, grad_norms, loss_per_iter
    return perturbed_past, new_accumulated_hidden, grad_norms, loss_per_iter
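To summarize the update both versions implement, here is a toy sketch (hypothetical tensor shape, a stand-in loss) of the normalized-gradient step applied to each past key/value tensor:

import torch

stepsize, gamma, small_const = 0.02, 1.5, 1e-15
p = torch.randn(1, 2, 16, 4, 64, requires_grad=True)  # one past tensor (toy shape)
loss = (p ** 2).sum()                                  # stand-in for the PPLM loss
loss.backward()

window_mask = torch.ones_like(p)                       # window_length == 0 -> perturb the whole past
grad_norm = torch.norm(p.grad * window_mask) + small_const
step = -stepsize * (p.grad * window_mask) / grad_norm ** gamma
perturbed_p = (p + step).detach()                      # the real code accumulates this over num_iterations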
def get_classifier(
    name: Optional[str], label_class: Union[str, int], device: Union[str, torch.device]
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    if name is None:
        return None, None
...
...
@@ -337,7 +305,8 @@ def get_classifier(
        embed_size=params['embed_size']).to(device)
    resolved_archive_file = cached_path(params["url"])
    classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
    classifier.eval()
    if isinstance(label_class, str):
...
...
@@ -364,7 +333,8 @@ def get_classifier(
    return classifier, label_id


def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[List[List[int]]]:
    bow_indices = []
    for id_or_path in bag_of_words_ids_or_paths:
        if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
...
...
@@ -372,8 +342,10 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[List[
        else:
            filepath = id_or_path
        with open(filepath, "r") as f:
            words = f.read().split("\n")
        bow_indices.append([TOKENIZER.encode(word, add_prefix_space=True) for word in words])
            words = f.read().strip().split("\n")
        bow_indices.append(
            [TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in words])
    return bow_indices
...
...
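A small sketch of the structure this function returns, using an illustrative in-memory word list instead of a BOW id or file path:

# TOKENIZER is the module-level GPT2Tokenizer loaded above.
words = ["lake", "mountain", "river"]   # illustrative words, as if read from a BoW file
bow_indices = [[TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in words]]
# -> one entry per bag of words; each entry is a list of token-id lists, one per word
#    (multi-token words are later filtered out when building the one-hot vectors).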
@@ -392,327 +364,308 @@ def build_bows_one_hot_vectors(bow_indices):
    return one_hot_bows_vectors


def full_text_generation(
    model,
    context=None,
    num_samples=1,
    device="cuda",
    sample=True,
    discrim=None,
    label_class=None,
    bag_of_words=None,
    length=100,
    grad_length=10000,
    stepsize=0.02,
    num_iterations=3,
    temperature=1.0,
    gm_scale=0.9,
    kl_scale=0.01,
    top_k=10,
    window_length=0,
    horizon_length=1,
    decay=False,
    gamma=1.5,
    **kwargs
):
def latent_perturb(model, args, context=None, sample=True, device='cuda'):
    classifier, class_id = get_classifier(discrim, label_class, device)
    classifier, class_id = get_classifier(args.discrim, args.label_class, device)

    bow_indices = []
    if bag_of_words:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))

    if bag_of_words and classifier:
    # if args.discrim == 'clickbait':
    #     classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
    #     classifier.load_state_dict(torch.load("discrim_models/clickbait_classifierhead.pt"))
    #     classifier.eval()
    #     args.label_class = 1 # clickbaity
    #
    # elif args.discrim == 'sentiment':
    #     classifier = ClassificationHead(class_size=5, embed_size=1024).to(device)
    #     #classifier.load_state_dict(torch.load("discrim_models/sentiment_classifierhead.pt"))
    #     classifier.load_state_dict(torch.load("discrim_models/SST_classifier_head_epoch_16.pt"))
    #     classifier.eval()
    #     if args.label_class < 0:
    #         raise Exception('Wrong class for sentiment, use --label-class 2 for *very positive*, 3 for *very negative*')
    #     #args.label_class = 2 # very pos
    #     #args.label_class = 3 # very neg
    #
    # elif args.discrim == 'toxicity':
    #     classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
    #     classifier.load_state_dict(torch.load("discrim_models/toxicity_classifierhead.pt"))
    #     classifier.eval()
    #     args.label_class = 0 # not toxic
    #
    # elif args.discrim == 'generic':
    #     if args.discrim_weights is None:
    #         raise ValueError('When using a generic discriminator, '
    #                          'discrim_weights need to be specified')
    #     if args.discrim_meta is None:
    #         raise ValueError('When using a generic discriminator, '
    #                          'discrim_meta need to be specified')
    #
    #     with open(args.discrim_meta, 'r') as discrim_meta_file:
    #         meta = json.load(discrim_meta_file)
    #
    #     classifier = ClassificationHead(
    #         class_size=meta['class_size'],
    #         embed_size=meta['embed_size'],
    #         # todo add tokenizer from meta
    #     ).to(device)
    #     classifier.load_state_dict(torch.load(args.discrim_weights))
    #     classifier.eval()
    #     if args.label_class == -1:
    #         args.label_class = meta['default_class']
    #
    # else:
    #     classifier = None

    # Get tokens for the list of positive words
    def list_tokens(word_list):
        token_list = [TOKENIZER.encode(word, add_prefix_space=True) for word in word_list]
        # token_list = []
        # for word in word_list:
        #     token_list.append(TOKENIZER.encode(" " + word))
        return token_list

    # good_index = []
    # if args.bag_of_words:
    #     bags_of_words = args.bag_of_words.split(";")
    #     for wordlist in bags_of_words:
    #         with open(wordlist, "r") as f:
    #             words = f.read().strip()
    #             words = words.split('\n')
    #         good_index.append(list_tokens(words))
    #
    # for good_list in good_index:
    #     good_list = list(filter(lambda x: len(x) <= 1, good_list))
    #     actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]

    good_index = []
    actual_words = None
    if args.bag_of_words:
        good_index = get_bag_of_words_indices(args.bag_of_words.split(";"))
        for good_list in good_index:
            good_list = list(filter(lambda x: len(x) <= 1, good_list))
            actual_words = [(TOKENIZER.decode(ww).strip(), ww) for ww in good_list]

    if args.bag_of_words and classifier:
        print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
        loss_type = PPLM_BOW_DISCRIM
        args.loss_type = PPLM_BOW_DISCRIM

    elif bag_of_words:
        loss_type = PPLM_BOW
    elif args.bag_of_words:
        args.loss_type = PPLM_BOW
        print("Using PPLM-BoW")

    elif classifier is not None:
        loss_type = PPLM_DISCRIM
        args.loss_type = PPLM_DISCRIM
        print("Using PPLM-Discrim")

    else:
        raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")

    unpert_gen_tok_text, _, _ = generate_text_pplm(model=model, context=context,
    original, _, _ = sample_from_hidden(model=model, args=args, context=context,
                                        device=device,
                                        length=length,
                                        perturb=False)
                                        perturb=False, good_index=good_index, classifier=classifier)
    torch.cuda.empty_cache()

    pert_gen_tok_texts = []
    discrim_losses = []
    losses_in_time = []
    perturbed_list = []
    discrim_loss_list = []
    loss_in_time_list = []

    for i in range(num_samples):
        pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
            model=model,
    for i in range(args.num_samples):
        perturbed, discrim_loss, loss_in_time = sample_from_hidden(
            model=model,
            args=args,
            context=context,
            device=device,
            sample=sample,
            perturb=True,
            bow_indices=bow_indices,
            classifier=classifier,
            label_class=class_id,
            loss_type=loss_type,
            length=length,
            grad_length=grad_length,
            stepsize=stepsize,
            num_iterations=num_iterations,
            temperature=temperature,
            gm_scale=gm_scale,
            kl_scale=kl_scale,
            top_k=top_k,
            window_length=window_length,
            horizon_length=horizon_length,
            decay=decay,
            gamma=gamma,
        )
        pert_gen_tok_texts.append(pert_gen_tok_text)
            good_index=good_index,
            classifier=classifier,
        )
        perturbed_list.append(perturbed)
        if classifier is not None:
            discrim_losses.append(discrim_loss.data.cpu().numpy())
        losses_in_time.append(loss_in_time)
            discrim_loss_list.append(discrim_loss.data.cpu().numpy())
        loss_in_time_list.append(loss_in_time)

    torch.cuda.empty_cache()

    return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
    return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words
def generate_text_pplm(
    model,
    context=None,
    past=None,
    device="cuda",
    sample=True,
    perturb=True,
    classifier=None,
    label_class=None,
    bow_indices=None,
    loss_type=0,
    length=100,
    grad_length=10000,
    stepsize=0.02,
    num_iterations=3,
    temperature=1.0,
    gm_scale=0.9,
    kl_scale=0.01,
    top_k=10,
    window_length=0,
    horizon_length=1,
    decay=False,
    gamma=1.5,
):
    output_so_far = (
        torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
        if context
        else None
    )

    # collect one hot vectors for bags of words
    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
def sample_from_hidden(model, args, classifier, context=None, past=None,
                       device='cuda', sample=True, perturb=True, good_index=None):
    output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0) if context else None

    grad_norms = None
    last = None
    unpert_discrim_loss = 0
    loss_in_time = []

    for i in trange(length, ascii=True):
    for i in trange(args.length, ascii=True):

        # Get past/probs for current output, except for last word
        # Note that GPT takes 2 inputs: past + current_token
        # Note that GPT takes 2 inputs: past + current-token
        # Therefore, use everything from before current i/p token to generate relevant past

        # run model forward to obtain unperturbed
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]
            if output_so_far.shape[1] > 1:
                _, past, _ = model(output_so_far[:, :-1])
        if past is None and output is not None:
            prev = output[:, -1:]
            # _, past = model(output[:, :-1])
            # original_probs, true_past = model(output)
            # true_hidden = model.hidden_states

        unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
        unpert_last_hidden = unpert_all_hidden[-1]
            # Piero modified model call
            _, past, _ = model(output[:, :-1])
            original_probs, true_past, unpert_all_hidden = model(output)
            true_hidden = unpert_all_hidden[-1]
        else:
            unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
            unpert_last_hidden = unpert_all_hidden[-1]
            # original_probs, true_past = model(output)
            # true_hidden = model.hidden_states
            # Piero modified model call
            original_probs, true_past, unpert_all_hidden = model(output)
            true_hidden = unpert_all_hidden[-1]

        # Modify the past if necessary
        # check if we are abowe grad max length
        if i >= grad_length:
            current_stepsize = stepsize * 0
        if i >= args.grad_length:
            current_stepsize = args.stepsize * 0
        else:
            current_stepsize = stepsize
            current_stepsize = args.stepsize

        # modify the past if necessary
        if not perturb or num_iterations == 0:
            pert_past = past
        if not perturb or args.num_iterations == 0:
            perturbed_past = past
        else:
            accumulated_hidden = unpert_last_hidden[:, :-1, :]
            # Piero modified model call
            # accumulated_hidden = model.hidden_states[:, :-1, :]
            accumulated_hidden = true_hidden[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            if past is not None:
                pert_past, _, grad_norms, loss_this_iter = perturb_past(
                    past,
                perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
                    past,
                    model,
                    last,
                    unpert_past=unpert_past,
                    unpert_logits=unpert_logits,
                    accumulated_hidden=accumulated_hidden,
                    grad_norms=grad_norms,
                    prev,
                    args,
                    good_index=good_index,
                    stepsize=current_stepsize,
                    original_probs=original_probs,
                    true_past=true_past,
                    accumulated_hidden=accumulated_hidden,
                    classifier=classifier,
                    label_class=label_class,
                    one_hot_bows_vectors=one_hot_bows_vectors,
                    loss_type=loss_type,
                    num_iterations=num_iterations,
                    kl_scale=kl_scale,
                    window_length=window_length,
                    horizon_length=horizon_length,
                    decay=decay,
                    gamma=gamma,
                )
                loss_in_time.append(loss_this_iter)
            else:
                pert_past = past
                    grad_norms=grad_norms
                )
                loss_in_time.append(loss_per_iter)

        pert_logits, past, pert_all_hidden = model(last, past=pert_past)
        pert_logits = pert_logits[:, -1, :] / temperature
        pert_probs = F.softmax(pert_logits, dim=-1)
        # Piero modified model call
        logits, past, pert_all_hidden = model(prev, past=perturbed_past)
        # test_logits = F.softmax(test_logits[:, -1, :], dim=-1)
        # likelywords = torch.topk(test_logits, k=10, dim=-1)
        # print(TOKENIZER.decode(likelywords[1].tolist()[0]))

        # compute the discriminator loss using unperturbed hidden
        if classifier is not None:
            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
            label = torch.tensor([label_class], device="cuda", dtype=torch.long)
            unpert_discrim_loss = torch.nn.CrossEntropyLoss()(prediction, label)
            print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
            ce_loss = torch.nn.CrossEntropyLoss()
            predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
            label = torch.tensor([args.label_class], device='cuda', dtype=torch.long)
            true_discrim_loss = ce_loss(predicted_sentiment, label)
            print("true discrim loss", true_discrim_loss.data.cpu().numpy())
        else:
            unpert_discrim_loss = 0
            true_discrim_loss = 0

        # Piero modified model call
        # hidden = model.hidden_states  # update hidden
        # logits = model.forward_hidden(hidden)
        logits = logits[:, -1, :] / args.temperature  # + SmallConst

        # Fuse the modified model and original model probabilities
        # logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
        log_probs = F.softmax(logits, dim=-1)

        # Fuse the modified model and original model
        if perturb:
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
            pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))
            # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst
            original_probs = F.softmax(original_probs[:, -1, :], dim=-1)
            # likelywords = torch.topk(original_probs, k=10, dim=-1)
            # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
            pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)
            gm_scale = args.gm_scale
            log_probs = ((log_probs ** gm_scale) * (original_probs ** (1 - gm_scale)))  # + SmallConst

            # rescale
            if torch.sum(pert_probs) <= 1:
                pert_probs = pert_probs / torch.sum(pert_probs)
            log_probs = top_k_filter(log_probs, k=args.top_k, probs=True)  # + SmallConst
            if torch.sum(log_probs) <= 1:
                log_probs = log_probs / torch.sum(log_probs)
        else:
            pert_logits = top_k_filter(pert_logits, k=top_k)
            pert_probs = F.softmax(pert_logits, dim=-1)
            logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
            log_probs = F.softmax(logits, dim=-1)

        # sample or greedy
        if sample:
            last = torch.multinomial(pert_probs, num_samples=1)
            # likelywords = torch.topk(log_probs, k=args.top_k, dim=-1)
            # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
            # print(likelywords[0].tolist())
            prev = torch.multinomial(log_probs, num_samples=1)
        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)

        # update context/output_so_far appending the new token
        output_so_far = (
            last if output_so_far is None else torch.cat((output_so_far, last), dim=1)
        )
        print(TOKENIZER.decode(output_so_far.tolist()[0]))
            _, prev = torch.topk(log_probs, k=1, dim=-1)
            # if perturb:
            #     prev = future
        output = prev if output is None else torch.cat((output, prev), dim=1)  # update output
        print(TOKENIZER.decode(output.tolist()[0]))

    return output_so_far, unpert_discrim_loss, loss_in_time
    return output, true_discrim_loss, loss_in_time
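The fusion step above combines the perturbed and unperturbed next-token distributions with a post-norm geometric mean controlled by gm_scale. A minimal sketch with hypothetical distributions:

import torch
import torch.nn.functional as F

gm_scale = 0.9                                            # default from the argument parser below
pert_probs = F.softmax(torch.randn(1, 50257), dim=-1)     # hypothetical perturbed distribution
unpert_probs = F.softmax(torch.randn(1, 50257), dim=-1)   # hypothetical unperturbed distribution

fused = pert_probs ** gm_scale * unpert_probs ** (1 - gm_scale)
fused = fused / fused.sum()                               # rescale so it sums to 1 before sampling
next_token = torch.multinomial(fused, num_samples=1)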
def run_model():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", "-M", type=str, default="gpt2-medium",
                        help="pretrained model name or path to local checkpoint")
    parser.add_argument("--bag_of_words", "-B", type=str, default=None,
                        help="Bags of words used for PPLM-BoW. Either a BOW id (see list in code) or a filepath. Multiple BoWs separated by ;")
    parser.add_argument("--discrim", "-D", type=str, default=None,
                        choices=("clickbait", "sentiment", "toxicity"),
                        help="Discriminator to use for loss-type 2")
    parser.add_argument("--label_class", type=int, default=-1,
                        help="Class label used for the discriminator")
    parser.add_argument("--stepsize", type=float, default=0.02)
    parser.add_argument('--model_path', '-M', type=str, default='gpt2-medium',
                        help='pretrained model name or path to local checkpoint')
    parser.add_argument('--bag-of-words', '-B', type=str, default=None,
                        help='Bags of words used for PPLM-BoW. Multiple BoWs separated by ;')
    parser.add_argument('--discrim', '-D', type=str, default=None,
                        choices=('clickbait', 'sentiment', 'toxicity', 'generic'),
                        help='Discriminator to use for loss-type 2')
    parser.add_argument('--discrim_weights', type=str, default=None,
                        help='Weights for the generic discriminator')
    parser.add_argument('--discrim_meta', type=str, default=None,
                        help='Meta information for the generic discriminator')
    parser.add_argument('--label_class', type=int, default=-1,
                        help='Class label used for the discriminator')
    parser.add_argument('--stepsize', type=float, default=0.02)
    parser.add_argument("--length", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument("--gm_scale", type=float, default=0.9)
    parser.add_argument("--kl_scale", type=float, default=0.01)
    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
    parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
    parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix texts to condition on")
    parser.add_argument("--num_iterations", type=int, default=3)
    parser.add_argument("--grad_length", type=int, default=10000)
    parser.add_argument("--num_samples", type=int, default=1,
                        help="Number of samples to generate from the modified latents")
    parser.add_argument("--horizon_length", type=int, default=1,
                        help="Length of future to optimize over")
    parser.add_argument("--window_length", type=int, default=0,
                        help="Length of past which is being optimized; 0 corresponds to infinite window length")
    parser.add_argument("--decay", action="store_true", help="whether to decay or not")
    parser.add_argument("--gamma", type=float, default=1.5)
    parser.add_argument('--nocuda', action='store_true', help='no cuda')
    parser.add_argument('--uncond', action='store_true', help='Generate from end-of-text as prefix')
    parser.add_argument("--cond_text", type=str, default='The lake', help='Prefix texts to condition on')
    parser.add_argument('--num_iterations', type=int, default=3)
    parser.add_argument('--grad_length', type=int, default=10000)
    parser.add_argument('--num_samples', type=int, default=1,
                        help='Number of samples to generate from the modified latents')
    parser.add_argument('--horizon_length', type=int, default=1,
                        help='Length of future to optimize over')
    # parser.add_argument('--force-token', action='store_true', help='no cuda')
    parser.add_argument('--window_length', type=int, default=0,
                        help='Length of past which is being optimizer; 0 corresponds to infinite window length')
    parser.add_argument('--decay', action='store_true', help='whether to decay or not')
    parser.add_argument('--gamma', type=float, default=1.5)
    parser.add_argument('--colorama', action='store_true', help='no cuda')
    args = parser.parse_args()

    # set Random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # set the device
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    device = 'cpu' if args.nocuda else 'cuda'

    # load pretrained model
    model = GPT2LMHeadModel.from_pretrained(args.model_path, output_hidden_states=True
...
...
@@ -720,63 +673,82 @@ def run_model():
    model.to(device)
    model.eval()

    # freeze GPT-2 weights
    # Freeze GPT-2 weights
    for param in model.parameters():
        param.requires_grad = False
    pass

    # figure out conditioning text
    if args.uncond:
        tokenized_cond_text = TOKENIZER.encode([TOKENIZER.bos_token])
        seq = [[50256, 50256]]
    else:
        raw_text = args.cond_text
        while not raw_text:
            print("Did you forget to add `--cond_text`?")
            print('Did you forget to add `--cond-text`?')
            raw_text = input("Model prompt >>> ")
        tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text)
        seq = [[50256] + TOKENIZER.encode(raw_text)]

    print("= Prefix of sentence =")
    print(TOKENIZER.decode(tokenized_cond_text))
    print()

    collect_gen = dict()
    current_index = 0
    for out in seq:
        # generate unperturbed and perturbed texts
        text = TOKENIZER.decode(out)
        print("=" * 40 + " Prefix of sentence " + "=" * 40)
        print(text)
        print("=" * 80)

        # full_text_generation returns:
        # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
        unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
            model=model, context=tokenized_cond_text, device=device, **vars(args)
        )
        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = latent_perturb(
            model=model, args=args, context=out, device=device
        )

        # untokenize unperturbed text
        unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0])
        text_whole = TOKENIZER.decode(out1.tolist()[0])

        print("=" * 80)
        print("= Unperturbed generated text =")
        print(unpert_gen_text)
        print()

        generated_texts = []
        # iterate through the perturbed texts
        for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
            try:
                # untokenize unperturbed text
                unpert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0])
                print("= Perturbed generated text {} =".format(i + 1))
                print(unpert_gen_text)
                print()
            except:
                pass

        print("=" * 40 + " Whole sentence (Original)" + "=" * 40)
        print(text_whole)
        print("=" * 80)

        out_perturb_copy = out_perturb
        for out_perturb in out_perturb_copy:
            # try:
            #     print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)
            #     text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
            #     print(text_whole)
            #     print("=" * 80)
            # except:
            #     pass
            # collect_gen[current_index] = [out, out_perturb, out1]
            ## Save the prefix, perturbed seq, original seq for each index

            print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)
            keyword_tokens = [aa[-1][0] for aa in actual_words] if actual_words else []
            output_tokens = out_perturb.tolist()[0]
            if args.colorama:
                import colorama
                text_whole = ''
                for out in output_tokens:
                    if out in keyword_tokens:
                        text_whole += '%s%s%s' % (colorama.Fore.GREEN, TOKENIZER.decode([out]), colorama.Style.RESET_ALL)
                    else:
                        text_whole += TOKENIZER.decode([out])
            else:
                text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
            print(text_whole)
            print("=" * 80)
            collect_gen[current_index] = [out, out_perturb, out1]
            current_index = current_index + 1

            # keep the prefix, perturbed seq, original seq for each index
            generated_texts.append(
                (tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
            )

    return generated_texts
    return


if __name__ == "__main__":
if __name__ == '__main__':
    run_model()
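For completeness, a hedged sketch of driving the new entry points programmatically, mirroring what run_model does above; the word-list path is hypothetical, and the model name and prompt are simply the parser defaults:

model = GPT2LMHeadModel.from_pretrained("gpt2-medium", output_hidden_states=True)
model.to("cuda")
model.eval()
for param in model.parameters():
    param.requires_grad = False

cond_text = TOKENIZER.encode(TOKENIZER.bos_token + "The lake")
unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
    model=model,
    context=cond_text,
    device="cuda",
    bag_of_words="path/to/wordlist.txt",   # hypothetical word-list path
    length=50,
    stepsize=0.03,
    num_samples=1,
)
print(TOKENIZER.decode(pert_gen_tok_texts[0].tolist()[0]))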