chenpangpang / transformers
Commit 48a05026
Authored May 28, 2020 by prajjwal1; committed by Julien Chaumond, Jun 04, 2020

Removed deprecated use of the Variable API from the PPLM example

Parent: 12d0eb5f
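Background: torch.autograd.Variable has been a deprecated no-op wrapper since PyTorch 0.4 (tensors track requires_grad themselves, and the volatile flag no longer has any effect; torch.no_grad() replaces it). The sketch below is illustrative only, not part of the diff; it contrasts the removed to_var() helper with the in-place pattern the example now uses.

    # Standalone sketch; p_ and device mirror names used in the PPLM example.
    import numpy as np
    import torch

    p_ = np.zeros((2, 3), dtype=np.float32)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Old (deprecated): wrap the tensor in torch.autograd.Variable, e.g. via the
    # removed to_var() helper:
    #     x = Variable(torch.from_numpy(p_).cuda(), requires_grad=True, volatile=False)

    # New: the tensor carries the autograd flag directly, no wrapper needed.
    x = torch.from_numpy(p_).requires_grad_(True).to(device=device)
    print(x.requires_grad)  # True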
Showing 1 changed file with 17 additions and 24 deletions
examples/text-generation/pplm/run_pplm.py
@@ -31,7 +31,6 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch.autograd import Variable
 from tqdm import trange
 
 from pplm_classification_head import ClassificationHead
@@ -76,14 +75,6 @@ DISCRIMINATOR_MODELS_PARAMS = {
 }
 
 
-def to_var(x, requires_grad=False, volatile=False, device="cuda"):
-    if torch.cuda.is_available() and device == "cuda":
-        x = x.cuda()
-    elif device != "cuda":
-        x = x.to(device)
-    return Variable(x, requires_grad=requires_grad, volatile=volatile)
-
-
 def top_k_filter(logits, k, probs=False):
     """
     Masks everything but the k top entries as -infinity (1e10).
@@ -156,9 +147,7 @@ def perturb_past(
     new_accumulated_hidden = None
     for i in range(num_iterations):
         print("Iteration ", i + 1)
-        curr_perturbation = [
-            to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator
-        ]
+        curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
 
         # Compute hidden using perturbed past
         perturbed_past = list(map(add, past, curr_perturbation))
@@ -247,7 +236,7 @@ def perturb_past(
         past = new_past
 
     # apply the accumulated perturbations to the past
-    grad_accumulator = [to_var(torch.from_numpy(p_), requires_grad=True, device=device) for p_ in grad_accumulator]
+    grad_accumulator = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
    pert_past = list(map(add, past, grad_accumulator))
 
     return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
@@ -266,7 +255,7 @@ def get_classifier(
     elif "path" in params:
         resolved_archive_file = params["path"]
     else:
-        raise ValueError("Either url or path have to be specified " "in the discriminator model parameters")
+        raise ValueError("Either url or path have to be specified in the discriminator model parameters")
 
     classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
     classifier.eval()
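The old error message above, like the ones in the next two hunks, relied on Python's implicit concatenation of adjacent string literals (a leftover of automatic line wrapping); the commit merges each pair into a single literal, so the runtime strings are unchanged. A quick standalone check, for illustration only:

    # Adjacent string literals concatenate at compile time, so the old and new
    # error messages are byte-for-byte identical.
    old = "Either url or path have to be specified " "in the discriminator model parameters"
    new = "Either url or path have to be specified in the discriminator model parameters"
    assert old == new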
@@ -569,9 +558,9 @@ def generate_text_pplm(
 
 def set_generic_model_params(discrim_weights, discrim_meta):
     if discrim_weights is None:
-        raise ValueError("When using a generic discriminator, " "discrim_weights need to be specified")
+        raise ValueError("When using a generic discriminator, discrim_weights need to be specified")
     if discrim_meta is None:
-        raise ValueError("When using a generic discriminator, " "discrim_meta need to be specified")
+        raise ValueError("When using a generic discriminator, discrim_meta need to be specified")
 
     with open(discrim_meta, "r") as discrim_meta_file:
         meta = json.load(discrim_meta_file)
@@ -619,7 +608,7 @@ def run_pplm_example(
 
     if discrim is not None:
         pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
-        print("discrim = {}, pretrained_model set " "to discriminator's = {}".format(discrim, pretrained_model))
+        print("discrim = {}, pretrained_model set to discriminator's = {}".format(discrim, pretrained_model))
 
     # load pretrained model
     model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
@@ -706,7 +695,7 @@ def run_pplm_example(
                 for word_id in pert_gen_tok_text.tolist()[0]:
                     if word_id in bow_word_ids:
                         pert_gen_text += "{}{}{}".format(
-                            colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL
+                            colorama.Fore.RED, tokenizer.decode([word_id]), colorama.Style.RESET_ALL,
                         )
                     else:
                         pert_gen_text += tokenizer.decode([word_id])
@@ -744,9 +733,11 @@ if __name__ == "__main__":
         "-B",
         type=str,
         default=None,
-        help="Bags of words used for PPLM-BoW. "
-        "Either a BOW id (see list in code) or a filepath. "
-        "Multiple BoWs separated by ;",
+        help=(
+            "Bags of words used for PPLM-BoW. "
+            "Either a BOW id (see list in code) or a filepath. "
+            "Multiple BoWs separated by ;"
+        ),
     )
     parser.add_argument(
         "--discrim",
@@ -756,9 +747,11 @@ if __name__ == "__main__":
         choices=("clickbait", "sentiment", "toxicity", "generic"),
         help="Discriminator to use",
     )
-    parser.add_argument("--discrim_weights", type=str, default=None, help="Weights for the generic discriminator")
     parser.add_argument(
-        "--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator"
+        "--discrim_weights", type=str, default=None, help="Weights for the generic discriminator",
+    )
+    parser.add_argument(
+        "--discrim_meta", type=str, default=None, help="Meta information for the generic discriminator",
     )
     parser.add_argument(
         "--class_label", type=int, default=-1, help="Class label used for the discriminator",
@@ -774,7 +767,7 @@ if __name__ == "__main__":
         "--window_length",
         type=int,
         default=0,
-        help="Length of past which is being optimized; " "0 corresponds to infinite window length",
+        help="Length of past which is being optimized; 0 corresponds to infinite window length",
     )
     parser.add_argument(
         "--horizon_length", type=int, default=1, help="Length of future to optimize over",