chenpangpang / transformers · Commit 6c9c1317

More cleanup for run_model. Identical output as before.

Authored Nov 27, 2019 by piero; committed by Julien Chaumond on Dec 03, 2019.
Parent: 7ffe47c8

1 changed file: examples/run_pplm.py (+174, -140)
@@ -39,7 +39,6 @@ from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
 from transformers.modeling_gpt2 import GPT2LMHeadModel

 PPLM_BOW = 1
 PPLM_DISCRIM = 2
 PPLM_BOW_DISCRIM = 3
@@ -129,8 +128,7 @@ def perturb_past(
         decay=False,
         gamma=1.5,
 ):
-#def perturb_past(past, model, prev, classifier, good_index=None,
+# def perturb_past(past, model, prev, classifier, good_index=None,
 #        stepsize=0.01, vocab_size=50257,
 #        original_probs=None, accumulated_hidden=None, true_past=None,
 #        grad_norms=None):
@@ -237,7 +235,7 @@ def perturb_past(
                                              future_hidden, dim=1)

             predicted_sentiment = classifier(new_accumulated_hidden /
                                              (current_length + 1 + horizon_length))
             label = torch.tensor([label_class], device='cuda',
                                  dtype=torch.long)
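Note: the hunk above scores an average hidden state. The running sum of hidden states is divided by the number of positions it covers (the current context, plus the one freshly generated token, plus the extrapolated horizon) before the discriminator head sees it. A minimal sketch of that step, assuming a toy linear head and made-up sizes (none of this is the repo's code):

import torch
import torch.nn as nn

hidden_size = 1024       # GPT2-medium hidden size
num_classes = 2          # e.g. a binary sentiment head
classifier = nn.Linear(hidden_size, num_classes)  # stand-in discriminator

current_length = 8       # context tokens summed so far
horizon_length = 1       # extrapolated future steps
accumulated_hidden = torch.randn(1, hidden_size)             # running sum
future_hidden = torch.randn(1, horizon_length, hidden_size)  # horizon states

# same shape algebra as the hunk: add the horizon sum, then average
new_accumulated_hidden = accumulated_hidden + torch.sum(future_hidden, dim=1)
predicted_sentiment = classifier(
    new_accumulated_hidden / (current_length + 1 + horizon_length))
print(predicted_sentiment.shape)  # torch.Size([1, 2])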
@@ -349,6 +347,13 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
         bow_indices.append(
             [TOKENIZER.encode(word.strip(), add_prefix_space=True)
              for word in words])
+
+    #bow_words = set()
+    #for bow_list in bow_indices:
+    #    bow_list = list(filter(lambda x: len(x) <= 1, bow_list))
+    #    bow_words.update(
+    #        (TOKENIZER.decode(word).strip(), word) for word in bow_list)
+
     return bow_indices
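For context, `add_prefix_space=True` makes each bag-of-words entry tokenize the way it would mid-sentence (GPT-2 merges the leading space into the token id), and the commented-out block added here hints at a later filter down to single-token words. A small sketch under those assumptions (the word list is illustrative, not from the repo):

from transformers import GPT2Tokenizer

TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2")
words = ["lake", "ocean", "waterfall"]

# encode with a leading space, as in the hunk above
bow = [TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in words]

# keep only words that map to exactly one token id, as the comment suggests
single_token_bow = list(filter(lambda x: len(x) <= 1, bow))
print(single_token_bow)  # one-element id lists; ids depend on the vocab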
@@ -368,28 +373,28 @@ def build_bows_one_hot_vectors(bow_indices):
 def full_text_generation(
         model,
         context=None,
         num_samples=1,
         device="cuda",
         sample=True,
         discrim=None,
         label_class=None,
         bag_of_words=None,
         length=100,
         grad_length=10000,
         stepsize=0.02,
         num_iterations=3,
         temperature=1.0,
         gm_scale=0.9,
         kl_scale=0.01,
         top_k=10,
         window_length=0,
         horizon_length=1,
         decay=False,
         gamma=1.5,
         **kwargs
 ):
     classifier, class_id = get_classifier(
         discrim,
         label_class,
@@ -465,15 +470,9 @@ def full_text_generation(
...
@@ -465,15 +470,9 @@ def full_text_generation(
# actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]
# actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]
bow_indices
=
[]
bow_indices
=
[]
actual_words
=
None
if
bag_of_words
:
if
bag_of_words
:
bow_indices
=
get_bag_of_words_indices
(
bag_of_words
.
split
(
";"
))
bow_indices
=
get_bag_of_words_indices
(
bag_of_words
.
split
(
";"
))
for
good_list
in
bow_indices
:
good_list
=
list
(
filter
(
lambda
x
:
len
(
x
)
<=
1
,
good_list
))
actual_words
=
[(
TOKENIZER
.
decode
(
ww
).
strip
(),
ww
)
for
ww
in
good_list
]
if
bag_of_words
and
classifier
:
if
bag_of_words
and
classifier
:
print
(
"Both PPLM-BoW and PPLM-Discrim are on. This is not optimized."
)
print
(
"Both PPLM-BoW and PPLM-Discrim are on. This is not optimized."
)
loss_type
=
PPLM_BOW_DISCRIM
loss_type
=
PPLM_BOW_DISCRIM
@@ -533,8 +532,7 @@ def full_text_generation(
     torch.cuda.empty_cache()

-    return original, perturbed_list, discrim_loss_list, loss_in_time_list, actual_words
+    return original, perturbed_list, discrim_loss_list, loss_in_time_list


 def generate_text_pplm(
@@ -611,25 +609,25 @@ def generate_text_pplm(
             accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

             perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
                 past,
                 model,
                 prev,
                 unpert_past=unpert_past,
                 unpert_logits=unpert_logits,
                 accumulated_hidden=accumulated_hidden,
                 grad_norms=grad_norms,
                 stepsize=current_stepsize,
                 classifier=classifier,
                 label_class=label_class,
                 one_hot_bows_vectors=one_hot_bows_vectors,
                 loss_type=loss_type,
                 num_iterations=num_iterations,
                 kl_scale=kl_scale,
                 window_length=window_length,
                 horizon_length=horizon_length,
                 decay=decay,
                 gamma=gamma,
             )
             loss_in_time.append(loss_per_iter)

         # Piero modified model call
@@ -666,7 +664,7 @@ def generate_text_pplm(
         # print(TOKENIZER.decode(likelywords[1].tolist()[0]))

         log_probs = ((log_probs ** gm_scale) * (
                 unpert_logits ** (1 - gm_scale)))  # + SmallConst
         log_probs = top_k_filter(log_probs, k=top_k,
                                  probs=True)  # + SmallConst
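The fused distribution above is a geometric mean controlled by gm_scale: the perturbed scores weighted by gm_scale, the unperturbed ones by (1 - gm_scale), multiplied together and then top-k filtered. A toy illustration with hand-picked probabilities (an assumption for clarity; the repo's tensors mix probs and logits here):

import torch

gm_scale = 0.9
pert_probs = torch.tensor([[0.7, 0.2, 0.1]])    # from the perturbed pass
unpert_probs = torch.tensor([[0.2, 0.5, 0.3]])  # from the unperturbed pass

# geometric-mean fusion, then renormalize to a proper distribution
fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))
fused = fused / fused.sum(dim=-1, keepdim=True)
print(fused)  # dominated by the perturbed distribution at gm_scale=0.9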
@@ -696,53 +694,88 @@ def generate_text_pplm(

 def run_model():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_path', '-M', type=str, default='gpt2-medium',
-                        help='pretrained model name or path to local checkpoint')
-    parser.add_argument('--bag-of-words', '-B', type=str, default=None,
-                        help='Bags of words used for PPLM-BoW. Multiple BoWs separated by ;')
-    parser.add_argument('--discrim', '-D', type=str, default=None,
-                        choices=('clickbait', 'sentiment', 'toxicity', 'generic'),
-                        help='Discriminator to use for loss-type 2')
-    parser.add_argument('--discrim_weights', type=str, default=None,
-                        help='Weights for the generic discriminator')
-    parser.add_argument('--discrim_meta', type=str, default=None,
-                        help='Meta information for the generic discriminator')
-    parser.add_argument('--label_class', type=int, default=-1,
-                        help='Class label used for the discriminator')
-    parser.add_argument('--stepsize', type=float, default=0.02)
+    parser.add_argument(
+        "--model_path",
+        "-M",
+        type=str,
+        default="gpt2-medium",
+        help="pretrained model name or path to local checkpoint",
+    )
+    parser.add_argument(
+        "--bag_of_words",
+        "-B",
+        type=str,
+        default=None,
+        help="Bags of words used for PPLM-BoW. "
+             "Either a BOW id (see list in code) or a filepath. "
+             "Multiple BoWs separated by ;",
+    )
+    parser.add_argument(
+        "--discrim",
+        "-D",
+        type=str,
+        default=None,
+        choices=("clickbait", "sentiment", "toxicity"),
+        help="Discriminator to use for loss-type 2",
+    )
+    parser.add_argument(
+        "--label_class",
+        type=int,
+        default=-1,
+        help="Class label used for the discriminator",
+    )
+    parser.add_argument("--stepsize", type=float, default=0.02)
     parser.add_argument("--length", type=int, default=100)
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument("--top_k", type=int, default=10)
     parser.add_argument("--gm_scale", type=float, default=0.9)
     parser.add_argument("--kl_scale", type=float, default=0.01)
-    parser.add_argument('--nocuda', action='store_true', help='no cuda')
-    parser.add_argument('--uncond', action='store_true',
-                        help='Generate from end-of-text as prefix')
-    parser.add_argument("--cond_text", type=str, default='The lake',
-                        help='Prefix texts to condition on')
-    parser.add_argument('--num_iterations', type=int, default=3)
-    parser.add_argument('--grad_length', type=int, default=10000)
-    parser.add_argument('--num_samples', type=int, default=1,
-                        help='Number of samples to generate from the modified latents')
-    parser.add_argument('--horizon_length', type=int, default=1,
-                        help='Length of future to optimize over')
-    # parser.add_argument('--force-token', action='store_true', help='no cuda')
-    parser.add_argument('--window_length', type=int, default=0,
-                        help='Length of past which is being optimizer; 0 corresponds to infinite window length')
-    parser.add_argument('--decay', action='store_true',
-                        help='whether to decay or not')
-    parser.add_argument('--gamma', type=float, default=1.5)
-    parser.add_argument('--colorama', action='store_true', help='no cuda')
+    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
+    parser.add_argument(
+        "--uncond", action="store_true",
+        help="Generate from end-of-text as prefix"
+    )
+    parser.add_argument(
+        "--cond_text", type=str, default="The lake",
+        help="Prefix texts to condition on"
+    )
+    parser.add_argument("--num_iterations", type=int, default=3)
+    parser.add_argument("--grad_length", type=int, default=10000)
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=1,
+        help="Number of samples to generate from the modified latents",
+    )
+    parser.add_argument(
+        "--horizon_length",
+        type=int,
+        default=1,
+        help="Length of future to optimize over",
+    )
+    parser.add_argument(
+        "--window_length",
+        type=int,
+        default=0,
+        help="Length of past which is being optimized; "
+             "0 corresponds to infinite window length",
+    )
+    parser.add_argument(
+        "--decay", action="store_true", help="whether to decay or not"
+    )
+    parser.add_argument("--gamma", type=float, default=1.5)
+    parser.add_argument(
+        "--colorama", action="store_true", help="colors keywords"
+    )

     args = parser.parse_args()

+    # set Random seed
     torch.manual_seed(args.seed)
     np.random.seed(args.seed)

-    device = 'cpu' if args.nocuda else 'cuda'
+    # set the device
+    device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

+    # load pretrained model
     model = GPT2LMHeadModel.from_pretrained(
         args.model_path,
         output_hidden_states=True
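The net effect on the CLI: `--bag-of-words` becomes `--bag_of_words`, `--nocuda` becomes `--no_cuda`, the `generic` discriminator choice and its `--discrim_weights`/`--discrim_meta` companions go away, and `--colorama` gets an accurate help string. A quick, self-contained smoke test of the renamed flags (a stripped-down parser, not the script itself; the sample values are arbitrary):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--bag_of_words", "-B", type=str, default=None)
parser.add_argument("--no_cuda", action="store_true")
parser.add_argument("--cond_text", type=str, default="The lake")
parser.add_argument("--num_samples", type=int, default=1)

# parse a sample command line using the new underscore spellings
args = parser.parse_args(
    ["--bag_of_words", "military", "--cond_text", "The potato", "--no_cuda"])
print(args.bag_of_words, args.no_cuda, args.cond_text, args.num_samples)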
@@ -753,76 +786,77 @@ def run_model():
     # Freeze GPT-2 weights
     for param in model.parameters():
         param.requires_grad = False
-    pass

+    # figure out conditioning text
     if args.uncond:
-        seq = [[50256, 50256]]
+        tokenized_cond_text = TOKENIZER.encode([TOKENIZER.bos_token])
     else:
         raw_text = args.cond_text
         while not raw_text:
-            print('Did you forget to add `--cond-text`? ')
+            print("Did you forget to add `--cond_text`? ")
             raw_text = input("Model prompt >>> ")
-        seq = [[50256] + TOKENIZER.encode(raw_text)]
+        tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text)

-    collect_gen = dict()
-    current_index = 0
-    for tokenized_cond_text in seq:
-        text = TOKENIZER.decode(tokenized_cond_text)
-        print("=" * 40 + " Prefix of sentence " + "=" * 40)
-        print(text)
-        print("=" * 80)
+    print("= Prefix of sentence =")
+    print(TOKENIZER.decode(tokenized_cond_text))
+    print()

-        out1, out_perturb, discrim_loss_list, loss_in_time_list, actual_words = full_text_generation(
-            model=model,
-            context=tokenized_cond_text,
-            device=device,
-            **vars(args)
-        )
+    # generate unperturbed and perturbed texts

-        text_whole = TOKENIZER.decode(out1.tolist()[0])
-        print("=" * 80)
-        print("=" * 40 + " Whole sentence (Original)" + "=" * 40)
-        print(text_whole)
-        print("=" * 80)
+    # full_text_generation returns:
+    # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
+    unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
+        model=model,
+        context=tokenized_cond_text,
+        device=device,
+        **vars(args)
+    )

-        out_perturb_copy = out_perturb
-        for out_perturb in out_perturb_copy:
-            # try:
-            #     print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)
-            #     text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
-            #     print(text_whole)
-            #     print("=" * 80)
-            # except:
-            #     pass
-            # collect_gen[current_index] = [out, out_perturb, out1]
-            ## Save the prefix, perturbed seq, original seq for each index
-            print("=" * 40 + " Whole sentence (Perturbed)" + "=" * 40)
-            keyword_tokens = [aa[-1][0] for aa in actual_words] if actual_words else []
-            output_tokens = out_perturb.tolist()[0]
-            try:
-                if args.colorama:
-                    import colorama
-                    text_whole = ''
-                    for tokenized_cond_text in output_tokens:
-                        if tokenized_cond_text in keyword_tokens:
-                            text_whole += '%s%s%s' % (
-                                colorama.Fore.GREEN,
-                                TOKENIZER.decode([tokenized_cond_text]),
-                                colorama.Style.RESET_ALL)
-                        else:
-                            text_whole += TOKENIZER.decode([tokenized_cond_text])
-                else:
-                    text_whole = TOKENIZER.decode(out_perturb.tolist()[0])
-                print(text_whole)
-                print("=" * 80)
-                collect_gen[current_index] = [tokenized_cond_text, out_perturb, out1]
-                current_index = current_index + 1
-            except:
-                pass
+    # untokenize unperturbed text
+    unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0])
+
+    print("=" * 80)
+    print("= Unperturbed generated text =")
+    print(unpert_gen_text)
+    print()
+
+    generated_texts = []
+
+    bow_words = set()
+    bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
+    for bow_list in bow_indices:
+        filtered = list(filter(lambda x: len(x) <= 1, bow_list))
+        bow_words.update(w[0] for w in filtered)
+
+    # iterate through the perturbed texts
+    for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
+        # untokenize unperturbed text
+        if args.colorama:
+            import colorama
+
+            pert_gen_text = ''
+            for word_id in pert_gen_tok_text.tolist()[0]:
+                if word_id in bow_words:
+                    pert_gen_text += '{}{}{}'.format(
+                        colorama.Fore.RED,
+                        TOKENIZER.decode([word_id]),
+                        colorama.Style.RESET_ALL)
+                else:
+                    pert_gen_text += TOKENIZER.decode([word_id])
+        else:
+            pert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0])
+
+        print("= Perturbed generated text {} =".format(i + 1))
+        print(pert_gen_text)
+        print()
+
+        # keep the prefix, perturbed seq, original seq for each index
+        generated_texts.append(
+            (tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
+        )

     return
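The rewritten loop above highlights bag-of-words hits token by token. The same logic in isolation, assuming a made-up decode callable in place of TOKENIZER.decode (colorama's Fore.RED and Style.RESET_ALL are the real API):

import colorama

def highlight(token_ids, bow_word_ids, decode):
    # decode one token at a time; wrap bag-of-words hits in red
    out = ""
    for word_id in token_ids:
        piece = decode([word_id])
        if word_id in bow_word_ids:
            out += "{}{}{}".format(
                colorama.Fore.RED, piece, colorama.Style.RESET_ALL)
        else:
            out += piece
    return out

# toy stand-in for TOKENIZER.decode: token id -> " tok<id>"
print(highlight([1, 2, 3], {2}, lambda ids: " tok{}".format(ids[0])))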