Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
08c6e456
Commit
08c6e456
authored
Nov 27, 2019
by
piero
Committed by
Julien Chaumond
Dec 03, 2019
Browse files
Cleaned full_text_generation. Identical output as before.
parent
6c9c1317
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
84 deletions
+19
-84
examples/run_pplm.py
examples/run_pplm.py
+19
-84
No files found.
examples/run_pplm.py
View file @
08c6e456
...
@@ -401,74 +401,6 @@ def full_text_generation(
...
@@ -401,74 +401,6 @@ def full_text_generation(
device
device
)
)
# if args.discrim == 'clickbait':
# classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
# classifier.load_state_dict(torch.load("discrim_models/clickbait_classifierhead.pt"))
# classifier.eval()
# args.label_class = 1 # clickbaity
#
# elif args.discrim == 'sentiment':
# classifier = ClassificationHead(class_size=5, embed_size=1024).to(device)
# #classifier.load_state_dict(torch.load("discrim_models/sentiment_classifierhead.pt"))
# classifier.load_state_dict(torch.load("discrim_models/SST_classifier_head_epoch_16.pt"))
# classifier.eval()
# if args.label_class < 0:
# raise Exception('Wrong class for sentiment, use --label-class 2 for *very positive*, 3 for *very negative*')
# #args.label_class = 2 # very pos
# #args.label_class = 3 # very neg
#
# elif args.discrim == 'toxicity':
# classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
# classifier.load_state_dict(torch.load("discrim_models/toxicity_classifierhead.pt"))
# classifier.eval()
# args.label_class = 0 # not toxic
#
# elif args.discrim == 'generic':
# if args.discrim_weights is None:
# raise ValueError('When using a generic discriminator, '
# 'discrim_weights need to be specified')
# if args.discrim_meta is None:
# raise ValueError('When using a generic discriminator, '
# 'discrim_meta need to be specified')
#
# with open(args.discrim_meta, 'r') as discrim_meta_file:
# meta = json.load(discrim_meta_file)
#
# classifier = ClassificationHead(
# class_size=meta['class_size'],
# embed_size=meta['embed_size'],
# # todo add tokenizer from meta
# ).to(device)
# classifier.load_state_dict(torch.load(args.discrim_weights))
# classifier.eval()
# if args.label_class == -1:
# args.label_class = meta['default_class']
#
# else:
# classifier = None
# Get tokens for the list of positive words
def
list_tokens
(
word_list
):
token_list
=
[
TOKENIZER
.
encode
(
word
,
add_prefix_space
=
True
)
for
word
in
word_list
]
# token_list = []
# for word in word_list:
# token_list.append(TOKENIZER.encode(" " + word))
return
token_list
# good_index = []
# if args.bag_of_words:
# bags_of_words = args.bag_of_words.split(";")
# for wordlist in bags_of_words:
# with open(wordlist, "r") as f:
# words = f.read().strip()
# words = words.split('\n')
# good_index.append(list_tokens(words))
#
# for good_list in good_index:
# good_list = list(filter(lambda x: len(x) <= 1, good_list))
# actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]
bow_indices
=
[]
bow_indices
=
[]
if
bag_of_words
:
if
bag_of_words
:
bow_indices
=
get_bag_of_words_indices
(
bag_of_words
.
split
(
";"
))
bow_indices
=
get_bag_of_words_indices
(
bag_of_words
.
split
(
";"
))
...
@@ -486,9 +418,9 @@ def full_text_generation(
...
@@ -486,9 +418,9 @@ def full_text_generation(
print
(
"Using PPLM-Discrim"
)
print
(
"Using PPLM-Discrim"
)
else
:
else
:
raise
Exception
(
"Specify either
--
bag
_
of
_
words
(-B)
or
--
discrim
(-D)
"
)
raise
Exception
(
"Specify either
a
bag
of
words or
a
discrim
inator
"
)
original
,
_
,
_
=
generate_text_pplm
(
unpert_gen_tok_text
,
_
,
_
=
generate_text_pplm
(
model
=
model
,
model
=
model
,
context
=
context
,
context
=
context
,
device
=
device
,
device
=
device
,
...
@@ -497,12 +429,12 @@ def full_text_generation(
...
@@ -497,12 +429,12 @@ def full_text_generation(
)
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
pert
urbed_list
=
[]
pert
_gen_tok_texts
=
[]
discrim_loss
_list
=
[]
discrim_loss
es
=
[]
loss_in_time
_list
=
[]
loss
es
_in_time
=
[]
for
i
in
range
(
num_samples
):
for
i
in
range
(
num_samples
):
pert
urbed
,
discrim_loss
,
loss_in_time
=
generate_text_pplm
(
pert
_gen_tok_text
,
discrim_loss
,
loss_in_time
=
generate_text_pplm
(
model
=
model
,
model
=
model
,
context
=
context
,
context
=
context
,
device
=
device
,
device
=
device
,
...
@@ -525,14 +457,14 @@ def full_text_generation(
...
@@ -525,14 +457,14 @@ def full_text_generation(
decay
=
decay
,
decay
=
decay
,
gamma
=
gamma
,
gamma
=
gamma
,
)
)
pert
urbed_list
.
append
(
perturbed
)
pert
_gen_tok_texts
.
append
(
pert_gen_tok_text
)
if
classifier
is
not
None
:
if
classifier
is
not
None
:
discrim_loss
_list
.
append
(
discrim_loss
.
data
.
cpu
().
numpy
())
discrim_loss
es
.
append
(
discrim_loss
.
data
.
cpu
().
numpy
())
loss_in_time
_list
.
append
(
loss_in_time
)
loss
es
_in_time
.
append
(
loss_in_time
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
return
original
,
perturbed_list
,
discrim_loss
_list
,
loss_in_time
_list
return
unpert_gen_tok_text
,
pert_gen_tok_texts
,
discrim_loss
es
,
loss
es
_in_time
def
generate_text_pplm
(
def
generate_text_pplm
(
...
@@ -821,11 +753,14 @@ def run_model():
...
@@ -821,11 +753,14 @@ def run_model():
generated_texts
=
[]
generated_texts
=
[]
bow_words
=
set
()
bow_word_ids
=
set
()
bow_indices
=
get_bag_of_words_indices
(
args
.
bag_of_words
.
split
(
";"
))
if
args
.
bag_of_words
and
args
.
colorama
:
for
bow_list
in
bow_indices
:
bow_indices
=
get_bag_of_words_indices
(
args
.
bag_of_words
.
split
(
";"
))
filtered
=
list
(
filter
(
lambda
x
:
len
(
x
)
<=
1
,
bow_list
))
for
single_bow_list
in
bow_indices
:
bow_words
.
update
(
w
[
0
]
for
w
in
filtered
)
# filtering all words in the list composed of more than 1 token
filtered
=
list
(
filter
(
lambda
x
:
len
(
x
)
<=
1
,
single_bow_list
))
# w[0] because we are sure w has only 1 item because previous fitler
bow_word_ids
.
update
(
w
[
0
]
for
w
in
filtered
)
# iterate through the perturbed texts
# iterate through the perturbed texts
for
i
,
pert_gen_tok_text
in
enumerate
(
pert_gen_tok_texts
):
for
i
,
pert_gen_tok_text
in
enumerate
(
pert_gen_tok_texts
):
...
@@ -836,7 +771,7 @@ def run_model():
...
@@ -836,7 +771,7 @@ def run_model():
pert_gen_text
=
''
pert_gen_text
=
''
for
word_id
in
pert_gen_tok_text
.
tolist
()[
0
]:
for
word_id
in
pert_gen_tok_text
.
tolist
()[
0
]:
if
word_id
in
bow_words
:
if
word_id
in
bow_word
_id
s
:
pert_gen_text
+=
'{}{}{}'
.
format
(
pert_gen_text
+=
'{}{}{}'
.
format
(
colorama
.
Fore
.
RED
,
colorama
.
Fore
.
RED
,
TOKENIZER
.
decode
([
word_id
]),
TOKENIZER
.
decode
([
word_id
]),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment