chenpangpang / transformers · Commits · 34a83faa
"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "36ea8784a875bde21c88f84dfb99475b6e8187e8"
Commit 34a83faa
Authored Nov 25, 2019 by Piero Molino
Committed by Julien Chaumond on Dec 03, 2019

Let's make PPLM great again

Parent: d5faa74c
Changes: 1 changed file with 452 additions and 480 deletions

examples/run_pplm.py  +452 −480
examples/run_pplm.py — view file @ 34a83faa
#! /usr/bin/env python3
# coding=utf-8

# Copyright 2018 The Uber AI Team Authors.
#
...
@@ -37,10 +38,12 @@ from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel

PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
-SmallConst = 1e-15
TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")

BAG_OF_WORDS_ARCHIVE_MAP = {

(The previous revision's SmallConst spelling is dropped; SMALL_CONST is used throughout the rest of the file.)
...
@@ -65,7 +68,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
        "default_class": 1,
    },
    "sentiment": {
-        "url": "http://s.yosinski.com/SST_classifier_head.pt",
+        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/sentiment_classifierhead.pt",
        "class_size": 5,
        "embed_size": 1024,
        "class_vocab": {"very_positive": 2, "very_negative": 3},
...
@@ -81,6 +84,30 @@ DISCRIMINATOR_MODELS_PARAMS = {
}


def to_var(x, requires_grad=False, volatile=False):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad, volatile=volatile)


def top_k_filter(logits, k, probs=False):
    """
    Masks everything but the k top entries as -infinity (1e10).
    Used to mask logits such that e^-infinity -> 0 won't contribute to the
    sum of the denominator.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            return torch.where(logits < batch_mins,
                               torch.ones_like(logits) * 0.0,
                               logits)
        return torch.where(logits < batch_mins,
                           torch.ones_like(logits) * -1e10,
                           logits)


class ClassificationHead(torch.nn.Module):
    """ Classification Head for the transformer """
...
@@ -99,234 +126,175 @@ class ClassificationHead(torch.nn.Module):
        return logits


This hunk deletes the old module-level copies of to_var and top_k_filter (now defined near the top of the file) and rewrites perturb_past. The previous signature, perturb_past(past, model, prev, args, classifier, good_index=None, stepsize=0.01, vocab_size=50257, original_probs=None, accumulated_hidden=None, true_past=None, grad_norms=None), pulled most settings from the args namespace and built its own one_hot_vectors from good_index; the rewritten function takes every hyperparameter as an explicit argument and consumes precomputed one_hot_bows_vectors:

def perturb_past(
        past,
        model,
        last,
        unpert_past=None,
        unpert_logits=None,
        accumulated_hidden=None,
        grad_norms=None,
        stepsize=0.01,
        classifier=None,
        label_class=None,
        one_hot_bows_vectors=None,
        loss_type=0,
        num_iterations=3,
        kl_scale=0.01,
        window_length=0,
        horizon_length=1,
        decay=False,
        gamma=1.5,
):
    # initialize perturbation accumulator
    grad_accumulator = [
        (np.zeros(p.shape).astype("float32"))
        for p in past
    ]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    if decay:
        decay_mask = torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
    else:
        decay_mask = 1.0

    # generate a mask if the perturbed gradient is based on a past window
    _, _, _, curr_length, _ = past[0].shape

    if curr_length > window_length and window_length > 0:
        ones_key_val_shape = (
            tuple(past[0].shape[:-2])
            + tuple([window_length])
            + tuple(past[0].shape[-1:])
        )

        zeros_key_val_shape = (
            tuple(past[0].shape[:-2])
            + tuple([curr_length - window_length])
            + tuple(past[0].shape[-1:])
        )

        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat(
            (ones_mask, torch.zeros(zeros_key_val_shape)),
            dim=-2).cuda()
    else:
        window_mask = torch.ones_like(past[0]).cuda()

    # accumulate perturbations for num_iterations
    loss_per_iter = []
    for i in range(num_iterations):
        print("Iteration ", i + 1)
        curr_perturbation = [
            to_var(torch.from_numpy(p_), requires_grad=True)
            for p_ in grad_accumulator
        ]

        # Compute hidden using perturbed past
        curr_pert_past = list(map(add, past, curr_perturbation))
        all_logits, _, all_hidden = model(last, past=curr_pert_past)
        hidden = all_hidden[-1]
        accumulated_hidden += torch.sum(hidden, dim=1).detach()
        logits = all_logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        # compute loss
        bow_loss = 0.0
        discrim_loss = 0.0
        kl_loss = 0.0

        if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
            for one_hot_bow in one_hot_bows_vectors:
                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
                bow_loss += -torch.log(torch.sum(bow_logits))
            print(" pplm_bow_loss:", bow_loss.data.cpu().numpy())

        if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
            ce_loss = torch.nn.CrossEntropyLoss()
            # TODO why we need to do this assignment and not just using unpert_past?
            curr_unpert_past = unpert_past

            # TODO i is never used, why do we need to do this i times instead multiplying
            # torch.sum(unpert_hidden, dim=1) * horizon_length?
            for i in range(horizon_length):
                # TODO the next two lines can be done only one time, and why not using
                # probs instead as they do not change at each iteration?
                curr_probs = F.softmax(logits, dim=-1)  # get softmax
                wte = model.resize_token_embeddings()
                curr_probs = torch.unsqueeze(curr_probs, dim=1)
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                _, curr_unpert_past, curr_all_hidden = model(
                    past=curr_unpert_past,
                    inputs_embeds=inputs_embeds
                )
                unpert_hidden = curr_all_hidden[-1]
                accumulated_hidden += torch.sum(unpert_hidden, dim=1).detach()

            prediction = classifier(
                accumulated_hidden / (curr_length + 1 + horizon_length)
            )
            label = torch.tensor([label_class], device="cuda", dtype=torch.long)
            discrim_loss += ce_loss(prediction, label)
            print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())

        if kl_scale >= 0.0:
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
            unpert_probs = (
                unpert_probs + SMALL_CONST *
                (unpert_probs <= SMALL_CONST).type(torch.FloatTensor).cuda().detach()
            )
            correction = SMALL_CONST * (probs <= SMALL_CONST).type(
                torch.FloatTensor).cuda().detach()
            corrected_probs = probs + correction.detach()
            kl_loss = kl_scale * (
                (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
            )
            print(' kl_loss', (kl_loss).data.cpu().numpy())

        loss = bow_loss + discrim_loss + kl_loss
        loss_per_iter.append(loss.data.cpu().numpy())
        print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())

        # compute gradients
        loss.backward()

        # calculate gradient norms
        if grad_norms is not None and loss_type == PPLM_BOW:
            grad_norms = [
                torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                for index, p_ in enumerate(curr_perturbation)
            ]
        else:
            grad_norms = [
                (torch.norm(p_.grad * window_mask) + SMALL_CONST)
                for index, p_ in enumerate(curr_perturbation)
            ]

        # normalize gradients
        grad = [
            -stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
            for index, p_ in enumerate(curr_perturbation)
        ]

        # accumulate gradients
        grad_accumulator = list(map(add, grad, grad_accumulator))

        # reset gradients, just to make sure
        for p_ in curr_perturbation:
            p_.grad.data.zero_()

        # removing past from the graph
        new_past = []
        for p_ in past:
            new_past.append(p_.detach())
        past = new_past

    # apply the accumulated perturbations to the past
    grad_accumulator = [
        to_var(torch.from_numpy(p_), requires_grad=True)
        for p_ in grad_accumulator
    ]
    pert_past = list(map(add, past, grad_accumulator))

    return pert_past, accumulated_hidden, grad_norms, loss_per_iter
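The heart of perturb_past is easy to lose in the refactor, so here is a minimal, hedged sketch of the same idea stripped of the GPT-2 machinery: gradient steps on an additive perturbation of a frozen activation, driven by the same -log(sum(probs * one_hot_bow)) objective. The sizes, the linear readout, and the token ids below are illustrative assumptions, not part of the script.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, hidden_size = 50, 16
readout = torch.nn.Linear(hidden_size, vocab_size)   # frozen stand-in for the LM head
for p in readout.parameters():
    p.requires_grad = False

h = torch.randn(1, hidden_size)                       # unperturbed activation
one_hot_bow = torch.zeros(vocab_size)
one_hot_bow[[3, 7, 11]] = 1.0                         # tokens we want to promote

delta = torch.zeros_like(h, requires_grad=True)       # accumulated perturbation
stepsize, num_iterations = 0.05, 3
for _ in range(num_iterations):
    probs = F.softmax(readout(h + delta), dim=-1)
    bow_loss = -torch.log(torch.sum(probs * one_hot_bow))   # same loss shape as above
    bow_loss.backward()
    grad = delta.grad / (torch.norm(delta.grad) + 1e-15)    # normalized, like grad_norms
    with torch.no_grad():
        delta -= stepsize * grad                             # descend on the loss
    delta.grad.zero_()

print(F.softmax(readout(h + delta), dim=-1)[0, [3, 7, 11]].sum())  # BoW mass grows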
def get_classifier(
        name: Optional[str],
        label_class: Union[str, int],
        device: Union[str, torch.device]
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    if name is None:
        return None, None
...
@@ -337,7 +305,8 @@ def get_classifier(
        embed_size=params['embed_size']
    ).to(device)
    resolved_archive_file = cached_path(params["url"])
    classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    if isinstance(label_class, str):
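A short usage sketch of the function above, assuming the sentiment entry of DISCRIMINATOR_MODELS_PARAMS is reachable (the weights are fetched through cached_path); the label resolution in the comment follows the class_vocab mapping defined earlier in the file.

# Hedged sketch, not part of the script:
classifier, class_id = get_classifier(
    name="sentiment",               # key into DISCRIMINATOR_MODELS_PARAMS
    label_class="very_positive",    # mapped to 2 via "class_vocab"
    device="cpu",
)
# classifier is a ClassificationHead with class_size=5 and embed_size=1024;
# class_id is the integer label fed to the cross-entropy loss in perturb_past.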
...
@@ -364,7 +333,8 @@ def get_classifier(
    return classifier, label_id


def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[List[List[int]]]:
    bow_indices = []
    for id_or_path in bag_of_words_ids_or_paths:
        if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
...
@@ -372,8 +342,10 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[List[
        else:
            filepath = id_or_path
        with open(filepath, "r") as f:
-            words = f.read().split("\n")
+            words = f.read().strip().split("\n")
-        bow_indices.append([TOKENIZER.encode(word, add_prefix_space=True) for word in words])
+        bow_indices.append(
+            [TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in words])
    return bow_indices
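The body of build_bows_one_hot_vectors is collapsed in this diff, so the following is only an approximation of how the indices produced above become one-hot rows and are scored inside perturb_past; the token ids and the uniform dummy distribution are made up for illustration.

import torch

vocab_size = 50257
single_bow = [[1170], [3290], [5044]]                  # one token id per single-token word
bow_ids = torch.tensor(single_bow)
one_hot_bow = torch.zeros(len(single_bow), vocab_size)
one_hot_bow.scatter_(1, bow_ids, 1)                    # one row per BoW word

probs = torch.full((1, vocab_size), 1.0 / vocab_size)  # dummy next-token distribution
bow_logits = torch.mm(probs, torch.t(one_hot_bow))     # probability mass on each BoW word
bow_loss = -torch.log(torch.sum(bow_logits))           # same form as in perturb_past
print(bow_loss)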
...
@@ -392,327 +364,308 @@ def build_bows_one_hot_vectors(bow_indices):
    return one_hot_bows_vectors


The args-driven latent_perturb(model, args, context=None, sample=True, device='cuda') becomes full_text_generation; the large block of commented-out per-discriminator setup and the good_index bookkeeping it carried are gone, replaced by get_classifier and get_bag_of_words_indices:

def full_text_generation(
        model,
        context=None,
        num_samples=1,
        device="cuda",
        sample=True,
        discrim=None,
        label_class=None,
        bag_of_words=None,
        length=100,
        grad_length=10000,
        stepsize=0.02,
        num_iterations=3,
        temperature=1.0,
        gm_scale=0.9,
        kl_scale=0.01,
        top_k=10,
        window_length=0,
        horizon_length=1,
        decay=False,
        gamma=1.5,
        **kwargs
):
    classifier, class_id = get_classifier(discrim, label_class, device)

    bow_indices = []
    if bag_of_words:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))

    if bag_of_words and classifier:
        print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
        loss_type = PPLM_BOW_DISCRIM

    elif bag_of_words:
        loss_type = PPLM_BOW
        print("Using PPLM-BoW")

    elif classifier is not None:
        loss_type = PPLM_DISCRIM
        print("Using PPLM-Discrim")

    else:
        raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")

    unpert_gen_tok_text, _, _ = generate_text_pplm(
        model=model,
        context=context,
        device=device,
        length=length,
        perturb=False
    )
    torch.cuda.empty_cache()

    pert_gen_tok_texts = []
    discrim_losses = []
    losses_in_time = []

    for i in range(num_samples):
        pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
            model=model,
            context=context,
            device=device,
            sample=sample,
            perturb=True,
            bow_indices=bow_indices,
            classifier=classifier,
            label_class=class_id,
            loss_type=loss_type,
            length=length,
            grad_length=grad_length,
            stepsize=stepsize,
            num_iterations=num_iterations,
            temperature=temperature,
            gm_scale=gm_scale,
            kl_scale=kl_scale,
            top_k=top_k,
            window_length=window_length,
            horizon_length=horizon_length,
            decay=decay,
            gamma=gamma,
        )
        pert_gen_tok_texts.append(pert_gen_tok_text)
        if classifier is not None:
            discrim_losses.append(discrim_loss.data.cpu().numpy())
        losses_in_time.append(loss_in_time)

    torch.cuda.empty_cache()

    return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
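A hedged sketch of driving the new entry point directly from Python instead of the CLI. my_words.txt is a hypothetical newline-separated word list (any key of BAG_OF_WORDS_ARCHIVE_MAP, collapsed in this diff, would also work), and model is assumed to be the frozen GPT2LMHeadModel loaded the same way run_model does below.

# Hedged usage sketch, not part of the script:
cond_text = TOKENIZER.encode(TOKENIZER.bos_token + "The lake")
unpert_tok, pert_toks, _, _ = full_text_generation(
    model=model,                  # frozen GPT2LMHeadModel with output_hidden_states=True
    context=cond_text,
    bag_of_words="my_words.txt",  # hypothetical word list
    length=60,
    num_samples=1,
    stepsize=0.02,
    device="cuda",
)
print(TOKENIZER.decode(pert_toks[0].tolist()[0]))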
sample_from_hidden becomes generate_text_pplm; output / prev / true_past / original_probs are renamed to output_so_far / last / unpert_past / unpert_logits, and every hyperparameter is passed explicitly instead of being read from args:

def generate_text_pplm(
        model,
        context=None,
        past=None,
        device="cuda",
        sample=True,
        perturb=True,
        classifier=None,
        label_class=None,
        bow_indices=None,
        loss_type=0,
        length=100,
        grad_length=10000,
        stepsize=0.02,
        num_iterations=3,
        temperature=1.0,
        gm_scale=0.9,
        kl_scale=0.01,
        top_k=10,
        window_length=0,
        horizon_length=1,
        decay=False,
        gamma=1.5,
):
    output_so_far = (
        torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
        if context
        else None
    )

    # collect one hot vectors for bags of words
    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)

    grad_norms = None
    last = None
    unpert_discrim_loss = 0
    loss_in_time = []
    for i in trange(length, ascii=True):

        # Get past/probs for current output, except for last word
        # Note that GPT takes 2 inputs: past + current_token

        # run model forward to obtain unperturbed
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]
            if output_so_far.shape[1] > 1:
                _, past, _ = model(output_so_far[:, :-1])

            unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
            unpert_last_hidden = unpert_all_hidden[-1]
        else:
            unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
            unpert_last_hidden = unpert_all_hidden[-1]

        # check if we are above grad max length
        if i >= grad_length:
            current_stepsize = stepsize * 0
        else:
            current_stepsize = stepsize

        # modify the past if necessary
        if not perturb or num_iterations == 0:
            pert_past = past

        else:
            accumulated_hidden = unpert_last_hidden[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            if past is not None:
                pert_past, _, grad_norms, loss_this_iter = perturb_past(
                    past,
                    model,
                    last,
                    unpert_past=unpert_past,
                    unpert_logits=unpert_logits,
                    accumulated_hidden=accumulated_hidden,
                    grad_norms=grad_norms,
                    stepsize=current_stepsize,
                    classifier=classifier,
                    label_class=label_class,
                    one_hot_bows_vectors=one_hot_bows_vectors,
                    loss_type=loss_type,
                    num_iterations=num_iterations,
                    kl_scale=kl_scale,
                    window_length=window_length,
                    horizon_length=horizon_length,
                    decay=decay,
                    gamma=gamma,
                )
                loss_in_time.append(loss_this_iter)
            else:
                pert_past = past

        pert_logits, past, pert_all_hidden = model(last, past=pert_past)
        pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST
        pert_probs = F.softmax(pert_logits, dim=-1)

        # compute the discriminator loss using unperturbed hidden
        if classifier is not None:
            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
            label = torch.tensor([label_class], device="cuda", dtype=torch.long)
            unpert_discrim_loss = torch.nn.CrossEntropyLoss()(prediction, label)
            print(
                "unperturbed discrim loss",
                unpert_discrim_loss.data.cpu().numpy()
            )
        else:
            unpert_discrim_loss = 0

        # Fuse the modified model and original model probabilities
        if perturb:
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

            pert_probs = ((pert_probs ** gm_scale) * (
                unpert_probs ** (1 - gm_scale)))  # + SMALL_CONST
            pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)  # + SMALL_CONST

            # rescale
            if torch.sum(pert_probs) <= 1:
                pert_probs = pert_probs / torch.sum(pert_probs)

        else:
            pert_logits = top_k_filter(pert_logits, k=top_k)  # + SMALL_CONST
            pert_probs = F.softmax(pert_logits, dim=-1)

        # sample or greedy
        if sample:
            last = torch.multinomial(pert_probs, num_samples=1)

        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)

        # update context/output_so_far appending the new token
        output_so_far = (
            last if output_so_far is None
            else torch.cat((output_so_far, last), dim=1)
        )

        print(TOKENIZER.decode(output_so_far.tolist()[0]))

    return output_so_far, unpert_discrim_loss, loss_in_time
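The fusion step near the end of generate_text_pplm is what keeps perturbed sampling fluent; the toy numbers below show the geometric interpolation between the perturbed and unperturbed next-token distributions, plus the rescale guard, in isolation.

import torch

pert_probs = torch.tensor([[0.70, 0.20, 0.10]])     # made-up perturbed distribution
unpert_probs = torch.tensor([[0.10, 0.30, 0.60]])   # made-up unperturbed distribution
gm_scale = 0.9

fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))
if torch.sum(fused) <= 1:            # same rescale guard as in generate_text_pplm
    fused = fused / torch.sum(fused)
print(fused)
# stays close to pert_probs, but tokens that are unlikely under the original
# model are pulled back down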
run_model keeps the same overall flow, but the argparse block is reformatted (double quotes, one argument per call), --nocuda becomes --no_cuda, the 'generic' discriminator options (--discrim_weights, --discrim_meta) and the --colorama highlighting are dropped, the device is chosen with torch.cuda.is_available(), and generation goes through full_text_generation:

def run_model():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        "-M",
        type=str,
        default="gpt2-medium",
        help="pretrained model name or path to local checkpoint",
    )
    parser.add_argument(
        "--bag_of_words",
        "-B",
        type=str,
        default=None,
        help="Bags of words used for PPLM-BoW. Either a BOW id (see list in code) or a filepath. Multiple BoWs separated by ;",
    )
    parser.add_argument(
        "--discrim",
        "-D",
        type=str,
        default=None,
        choices=("clickbait", "sentiment", "toxicity"),
        help="Discriminator to use for loss-type 2",
    )
    parser.add_argument(
        "--label_class",
        type=int,
        default=-1,
        help="Class label used for the discriminator",
    )
    parser.add_argument("--stepsize", type=float, default=0.02)
    parser.add_argument("--length", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument("--gm_scale", type=float, default=0.9)
    parser.add_argument("--kl_scale", type=float, default=0.01)
    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
    parser.add_argument(
        "--uncond", action="store_true",
        help="Generate from end-of-text as prefix"
    )
    parser.add_argument(
        "--cond_text", type=str, default="The lake",
        help="Prefix texts to condition on"
    )
    parser.add_argument("--num_iterations", type=int, default=3)
    parser.add_argument("--grad_length", type=int, default=10000)
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="Number of samples to generate from the modified latents",
    )
    parser.add_argument(
        "--horizon_length",
        type=int,
        default=1,
        help="Length of future to optimize over",
    )
    parser.add_argument(
        "--window_length",
        type=int,
        default=0,
        help="Length of past which is being optimized; "
             "0 corresponds to infinite window length",
    )
    parser.add_argument("--decay", action="store_true", help="whether to decay or not")
    parser.add_argument("--gamma", type=float, default=1.5)
    args = parser.parse_args()

    # set Random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # set the device
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    # load pretrained model
    model = GPT2LMHeadModel.from_pretrained(
        args.model_path,
        output_hidden_states=True
...
@@ -720,63 +673,82 @@ def run_model():
    model.to(device)
    model.eval()

    # Freeze GPT-2 weights
    for param in model.parameters():
        param.requires_grad = False

    # figure out conditioning text
    if args.uncond:
        tokenized_cond_text = TOKENIZER.encode([TOKENIZER.bos_token])
    else:
        raw_text = args.cond_text
        while not raw_text:
            print("Did you forget to add `--cond_text`? ")
            raw_text = input("Model prompt >>> ")
        tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text)

    print("= Prefix of sentence =")
    print(TOKENIZER.decode(tokenized_cond_text))
    print()

    # generate unperturbed and perturbed texts

    # full_text_generation returns:
    # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
    unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
        model=model, context=tokenized_cond_text, device=device, **vars(args)
    )

    # untokenize unperturbed text
    unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0])

    print("=" * 80)
    print("= Unperturbed generated text =")
    print(unpert_gen_text)
    print()

    generated_texts = []

    # iterate through the perturbed texts
    for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
        try:
            # untokenize perturbed text
            pert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0])
            print("= Perturbed generated text {} =".format(i + 1))
            print(pert_gen_text)
            print()
        except:
            pass

        # keep the prefix, perturbed seq, original seq for each index
        generated_texts.append(
            (tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
        )

    return generated_texts


if __name__ == "__main__":
    run_model()
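Two example invocations, using only flags defined in run_model above (the word-list path is hypothetical; label class 2 corresponds to very_positive in the sentiment head's class_vocab):

    python examples/run_pplm.py -B my_words.txt --cond_text "The lake" --length 50 --stepsize 0.03 --num_samples 3

    python examples/run_pplm.py -D sentiment --label_class 2 --cond_text "The lake" --length 50 --stepsize 0.04 --kl_scale 0.01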