chenpangpang / transformers · Commits

Commit 08c6e456
authored Nov 27, 2019 by piero
committed by Julien Chaumond, Dec 03, 2019

Cleaned full_text_generation. Identical output as before.

parent 6c9c1317
Showing 1 changed file with 19 additions and 84 deletions

examples/run_pplm.py  +19  -84
@@ -401,74 +401,6 @@ def full_text_generation(
         device
     )

-    # if args.discrim == 'clickbait':
-    #     classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
-    #     classifier.load_state_dict(torch.load("discrim_models/clickbait_classifierhead.pt"))
-    #     classifier.eval()
-    #     args.label_class = 1 # clickbaity
-    #
-    # elif args.discrim == 'sentiment':
-    #     classifier = ClassificationHead(class_size=5, embed_size=1024).to(device)
-    #     #classifier.load_state_dict(torch.load("discrim_models/sentiment_classifierhead.pt"))
-    #     classifier.load_state_dict(torch.load("discrim_models/SST_classifier_head_epoch_16.pt"))
-    #     classifier.eval()
-    #     if args.label_class < 0:
-    #         raise Exception('Wrong class for sentiment, use --label-class 2 for *very positive*, 3 for *very negative*')
-    #     #args.label_class = 2 # very pos
-    #     #args.label_class = 3 # very neg
-    #
-    # elif args.discrim == 'toxicity':
-    #     classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
-    #     classifier.load_state_dict(torch.load("discrim_models/toxicity_classifierhead.pt"))
-    #     classifier.eval()
-    #     args.label_class = 0 # not toxic
-    #
-    # elif args.discrim == 'generic':
-    #     if args.discrim_weights is None:
-    #         raise ValueError('When using a generic discriminator, '
-    #                          'discrim_weights need to be specified')
-    #     if args.discrim_meta is None:
-    #         raise ValueError('When using a generic discriminator, '
-    #                          'discrim_meta need to be specified')
-    #
-    #     with open(args.discrim_meta, 'r') as discrim_meta_file:
-    #         meta = json.load(discrim_meta_file)
-    #
-    #     classifier = ClassificationHead(
-    #         class_size=meta['class_size'],
-    #         embed_size=meta['embed_size'],
-    #         # todo add tokenizer from meta
-    #     ).to(device)
-    #     classifier.load_state_dict(torch.load(args.discrim_weights))
-    #     classifier.eval()
-    #     if args.label_class == -1:
-    #         args.label_class = meta['default_class']
-    #
-    # else:
-    #     classifier = None
-
-    # Get tokens for the list of positive words
-    def list_tokens(word_list):
-        token_list = [TOKENIZER.encode(word, add_prefix_space=True) for word in word_list]
-        # token_list = []
-        # for word in word_list:
-        #     token_list.append(TOKENIZER.encode(" " + word))
-        return token_list
-
-    # good_index = []
-    # if args.bag_of_words:
-    #     bags_of_words = args.bag_of_words.split(";")
-    #     for wordlist in bags_of_words:
-    #         with open(wordlist, "r") as f:
-    #             words = f.read().strip()
-    #             words = words.split('\n')
-    #         good_index.append(list_tokens(words))
-    #
-    #     for good_list in good_index:
-    #         good_list = list(filter(lambda x: len(x) <= 1, good_list))
-    #         actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]

     bow_indices = []
     if bag_of_words:
         bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
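Note that the new call site relies on a get_bag_of_words_indices helper whose definition is not part of this diff. As a rough sketch of what such a helper does, reconstructed only from the removed list_tokens code above (the model name is an assumption, and the real function in run_pplm.py may also resolve named, bundled word lists and differ in details):

# Rough sketch only: approximates the removed list_tokens logic; not the
# actual get_bag_of_words_indices from run_pplm.py.
from transformers import GPT2Tokenizer

TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")

def get_bag_of_words_indices(bag_of_words_paths):
    bow_indices = []
    for wordlist_path in bag_of_words_paths:
        with open(wordlist_path, "r") as f:
            words = f.read().strip().split("\n")
        # encode each word with a leading space so it matches the tokens
        # GPT-2's BPE produces for words appearing mid-sentence
        bow_indices.append(
            [TOKENIZER.encode(word, add_prefix_space=True) for word in words]
        )
    return bow_indices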
@@ -486,9 +418,9 @@ def full_text_generation(
         print("Using PPLM-Discrim")

     else:
-        raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)")
+        raise Exception("Specify either a bag of words or a discriminator")

-    original, _, _ = generate_text_pplm(
+    unpert_gen_tok_text, _, _ = generate_text_pplm(
         model=model,
         context=context,
         device=device,
@@ -497,12 +429,12 @@ def full_text_generation(
     )
     torch.cuda.empty_cache()

-    perturbed_list = []
-    discrim_loss_list = []
-    loss_in_time_list = []
+    pert_gen_tok_texts = []
+    discrim_losses = []
+    losses_in_time = []

     for i in range(num_samples):
-        perturbed, discrim_loss, loss_in_time = generate_text_pplm(
+        pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
             model=model,
             context=context,
             device=device,
@@ -525,14 +457,14 @@ def full_text_generation(
             decay=decay,
             gamma=gamma,
         )
-        perturbed_list.append(perturbed)
+        pert_gen_tok_texts.append(pert_gen_tok_text)

         if classifier is not None:
-            discrim_loss_list.append(discrim_loss.data.cpu().numpy())
-        loss_in_time_list.append(loss_in_time)
+            discrim_losses.append(discrim_loss.data.cpu().numpy())
+        losses_in_time.append(loss_in_time)

         torch.cuda.empty_cache()

-    return original, perturbed_list, discrim_loss_list, loss_in_time_list
+    return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time


 def generate_text_pplm(
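One pattern worth noting in this hunk: each sample's discriminator loss is stored as discrim_loss.data.cpu().numpy() rather than as the tensor itself. A minimal, self-contained sketch, independent of run_pplm.py, of why that matters when accumulating per-sample losses:

# Minimal sketch (not from run_pplm.py): storing a detached CPU copy of a loss
# keeps the accumulated list from holding on to autograd graphs and GPU memory
# across iterations.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

losses = []
for _ in range(3):
    x = torch.randn(4, device=device, requires_grad=True)
    loss = (x ** 2).mean()
    # .detach() is the modern spelling of the .data access used in the diff
    losses.append(loss.detach().cpu().numpy())

if torch.cuda.is_available():
    torch.cuda.empty_cache()

print(losses)  # plain numpy values, no graphs attached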
@@ -821,11 +753,14 @@ def run_model():
     generated_texts = []

-    bow_words = set()
+    bow_word_ids = set()
     if args.bag_of_words and args.colorama:
         bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
-        for bow_list in bow_indices:
-            filtered = list(filter(lambda x: len(x) <= 1, bow_list))
-            bow_words.update(w[0] for w in filtered)
+        for single_bow_list in bow_indices:
+            # filtering all words in the list composed of more than 1 token
+            filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
+            # w[0] because we are sure w has only 1 item because previous fitler
+            bow_word_ids.update(w[0] for w in filtered)

     # iterate through the perturbed texts
     for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
@@ -836,7 +771,7 @@ def run_model():
             pert_gen_text = ''
             for word_id in pert_gen_tok_text.tolist()[0]:
-                if word_id in bow_words:
+                if word_id in bow_word_ids:
                     pert_gen_text += '{}{}{}'.format(
                         colorama.Fore.RED,
                         TOKENIZER.decode([word_id]),
...
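The highlighting loop above is cut off by the diff view after the TOKENIZER.decode([word_id]) argument. A small, runnable sketch of the same colorama pattern, with the plain-text branch for tokens outside the bag of words filled in as an assumption (the model name and example inputs are illustrative only, not taken from run_pplm.py):

# Illustrative sketch of the colorama highlighting idiom shown above; the else
# branch and the example inputs are assumptions.
import colorama
from transformers import GPT2Tokenizer

colorama.init()
TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")

# pretend these single-token ids came from get_bag_of_words_indices
bow_word_ids = set(TOKENIZER.encode(" science"))

pert_gen_text = ''
for word_id in TOKENIZER.encode("The science team met again today"):
    if word_id in bow_word_ids:
        # bag-of-words tokens are wrapped in red
        pert_gen_text += '{}{}{}'.format(
            colorama.Fore.RED,
            TOKENIZER.decode([word_id]),
            colorama.Style.RESET_ALL,
        )
    else:
        pert_gen_text += TOKENIZER.decode([word_id])

print(pert_gen_text)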