chenpangpang / transformers
Commit 7fd54b55
Authored Nov 27, 2019 by piero; committed by Julien Chaumond, Dec 03, 2019
Added support for generic discriminators
Parent: b0eaff36
Showing 1 changed file with 51 additions and 26 deletions.
examples/run_pplm.py (+51, -26)
@@ -14,17 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# TODO: add code for training a custom discriminator
 """
 Example command with bag of words:
 python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95

 Example command with discriminator:
-python examples/run_pplm.py -D sentiment --label_class 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
+python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
 """

 import argparse
+import json
 from operator import add
 from typing import List, Optional, Tuple, Union
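With the generic discriminator this commit adds, an analogous invocation would plausibly be as follows (the file names are hypothetical; -D generic, --discrim_weights, --discrim_meta, and --class_label are the flags introduced or renamed further down in this diff):

python examples/run_pplm.py -D generic --discrim_weights my_discrim.pt --discrim_meta my_discrim_meta.json --class_label 1 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95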
@@ -121,7 +120,7 @@ def perturb_past(
         grad_norms=None,
         stepsize=0.01,
         classifier=None,
-        label_class=None,
+        class_label=None,
         one_hot_bows_vectors=None,
         loss_type=0,
         num_iterations=3,
@@ -230,7 +229,7 @@ def perturb_past(
             prediction = classifier(new_accumulated_hidden /
                                     (curr_length + 1 + horizon_length))
-            label = torch.tensor([label_class], device=device,
+            label = torch.tensor([class_label], device=device,
                                  dtype=torch.long)
             discrim_loss = ce_loss(prediction, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
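For reference, the discriminator term above is a plain cross-entropy between the classifier's logits on the averaged hidden state and the target class. A toy, self-contained illustration (the shapes and values are made up, not the real model):

import torch

ce_loss = torch.nn.CrossEntropyLoss()
prediction = torch.randn(1, 5)               # stand-in for classifier logits over 5 classes
label = torch.tensor([3], dtype=torch.long)  # stand-in for class_label
discrim_loss = ce_loss(prediction, label)
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())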
@@ -244,7 +243,8 @@ def perturb_past(
                 unpert_probs + SMALL_CONST *
                 (unpert_probs <= SMALL_CONST).float().to(device).detach()
             )
             correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
             corrected_probs = probs + correction.detach()
             kl_loss = kl_scale * (
                 (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
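A small, self-contained sketch of the stabilized KL term above, which lifts near-zero probabilities before taking the log so the ratio stays finite (SMALL_CONST is defined elsewhere in the file; the 1e-15 value and the tensors here are assumptions for illustration):

import torch

SMALL_CONST = 1e-15  # assumed value for illustration
kl_scale = 0.01
probs = torch.tensor([0.5, 0.5, 0.0])
unpert_probs = torch.tensor([0.4, 0.3, 0.3])

# lift entries at or below SMALL_CONST so log() never sees zero
unpert_probs = (unpert_probs + SMALL_CONST *
                (unpert_probs <= SMALL_CONST).float())
correction = SMALL_CONST * (probs <= SMALL_CONST).float()
corrected_probs = probs + correction
kl_loss = kl_scale * ((corrected_probs *
                       (corrected_probs / unpert_probs).log()).sum())
print(kl_loss)  # small positive scalar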
@@ -273,7 +273,8 @@ def perturb_past(
         # normalize gradients
         grad = [
             -stepsize *
             (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
             for index, p_ in enumerate(curr_perturbation)
         ]
@@ -301,7 +302,7 @@ def perturb_past(
 def get_classifier(
-        name: Optional[str], label_class: Union[str, int],
+        name: Optional[str], class_label: Union[str, int],
         device: str
 ) -> Tuple[Optional[ClassificationHead], Optional[int]]:
     if name is None:
@@ -312,26 +313,29 @@ def get_classifier(
         class_size=params['class_size'],
         embed_size=params['embed_size']
     ).to(device)
-    resolved_archive_file = cached_path(params["url"])
+    if "url" in params:
+        resolved_archive_file = cached_path(params["url"])
+    else:
+        resolved_archive_file = params["path"]
     classifier.load_state_dict(
         torch.load(resolved_archive_file, map_location=device))
     classifier.eval()
-    if isinstance(label_class, str):
-        if label_class in params["class_vocab"]:
-            label_id = params["class_vocab"][label_class]
+    if isinstance(class_label, str):
+        if class_label in params["class_vocab"]:
+            label_id = params["class_vocab"][class_label]
         else:
             label_id = params["default_class"]
-            print("label_class {} not in class_vocab".format(label_class))
+            print("class_label {} not in class_vocab".format(class_label))
             print("available values are: {}".format(params["class_vocab"]))
             print("using default class {}".format(label_id))
-    elif isinstance(label_class, int):
-        if label_class in set(params["class_vocab"].values()):
-            label_id = label_class
+    elif isinstance(class_label, int):
+        if class_label in set(params["class_vocab"].values()):
+            label_id = class_label
         else:
             label_id = params["default_class"]
-            print("label_class {} not in class_vocab".format(label_class))
+            print("class_label {} not in class_vocab".format(class_label))
             print("available values are: {}".format(params["class_vocab"]))
             print("using default class {}".format(label_id))
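As a concrete reading of the lookup above: a string class_label is treated as a key of params["class_vocab"], an int as one of its values, and anything unrecognized falls back to params["default_class"] with a warning. The following standalone distillation of the branches (resolve_label is hypothetical, the vocabulary illustrative, and the warning prints omitted) shows the three cases:

params = {"class_vocab": {"negative": 0, "positive": 1}, "default_class": 1}

def resolve_label(class_label):
    # mirrors the string/int branches of get_classifier above
    if isinstance(class_label, str):
        return params["class_vocab"].get(class_label, params["default_class"])
    elif isinstance(class_label, int):
        if class_label in set(params["class_vocab"].values()):
            return class_label
        return params["default_class"]

print(resolve_label("positive"))  # 1: string resolved via key lookup
print(resolve_label(0))           # 0: int accepted as an existing value
print(resolve_label("joy"))       # 1: unknown label falls back to default_class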
@@ -379,7 +383,7 @@ def full_text_generation(
     device="cuda",
     sample=True,
     discrim=None,
-    label_class=None,
+    class_label=None,
     bag_of_words=None,
     length=100,
     grad_length=10000,
@@ -397,7 +401,7 @@ def full_text_generation(
 ):
     classifier, class_id = get_classifier(
         discrim,
-        label_class,
+        class_label,
         device
     )
@@ -443,7 +447,7 @@ def full_text_generation(
         perturb=True,
         bow_indices=bow_indices,
         classifier=classifier,
-        label_class=class_id,
+        class_label=class_id,
         loss_type=loss_type,
         length=length,
         grad_length=grad_length,
@@ -477,7 +481,7 @@ def generate_text_pplm(
     sample=True,
     perturb=True,
     classifier=None,
-    label_class=None,
+    class_label=None,
     bow_indices=None,
     loss_type=0,
     length=100,
@@ -545,7 +549,7 @@ def generate_text_pplm(
                 grad_norms=grad_norms,
                 stepsize=current_stepsize,
                 classifier=classifier,
-                label_class=label_class,
+                class_label=class_label,
                 one_hot_bows_vectors=one_hot_bows_vectors,
                 loss_type=loss_type,
                 num_iterations=num_iterations,
@@ -567,7 +571,7 @@ def generate_text_pplm(
     if classifier is not None:
         ce_loss = torch.nn.CrossEntropyLoss()
         prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
-        label = torch.tensor([label_class], device=device,
+        label = torch.tensor([class_label], device=device,
                              dtype=torch.long)
         unpert_discrim_loss = ce_loss(prediction, label)
         print(
@@ -613,6 +617,20 @@ def generate_text_pplm(
     return output_so_far, unpert_discrim_loss, loss_in_time


+def set_generic_model_params(discrim_weights, discrim_meta):
+    if discrim_weights is None:
+        raise ValueError('When using a generic discriminator, '
+                         'discrim_weights need to be specified')
+    if discrim_meta is None:
+        raise ValueError('When using a generic discriminator, '
+                         'discrim_meta need to be specified')
+
+    with open(discrim_meta, 'r') as discrim_meta_file:
+        meta = json.load(discrim_meta_file)
+    meta['path'] = discrim_weights
+    DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
+
+
 def run_model():
     parser = argparse.ArgumentParser()
     parser.add_argument(
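set_generic_model_params expects discrim_meta to be a JSON file. Judging from the keys get_classifier reads out of DISCRIMINATOR_MODELS_PARAMS, a plausible meta file would look like the sketch below (all values illustrative; "path" is filled in automatically from --discrim_weights, and the absence of a "url" key is what routes loading through the new params["path"] branch):

{
    "class_size": 2,
    "embed_size": 1024,
    "class_vocab": {"negative": 0, "positive": 1},
    "default_class": 1
}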
@@ -636,11 +654,15 @@ def run_model():
         "-D",
         type=str,
         default=None,
-        choices=("clickbait", "sentiment", "toxicity"),
-        help="Discriminator to use for loss-type 2",
+        choices=("clickbait", "sentiment", "toxicity", "generic"),
+        help="Discriminator to use",
     )
+    parser.add_argument('--discrim_weights', type=str, default=None,
+                        help='Weights for the generic discriminator')
+    parser.add_argument('--discrim_meta', type=str, default=None,
+                        help='Meta information for the generic discriminator')
     parser.add_argument(
-        "--label_class",
+        "--class_label",
         type=int,
         default=-1,
         help="Class label used for the discriminator",
@@ -697,6 +719,9 @@ def run_model():
...
@@ -697,6 +719,9 @@ def run_model():
# set the device
# set the device
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
and
not
args
.
no_cuda
else
"cpu"
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
and
not
args
.
no_cuda
else
"cpu"
if
args
.
discrim
==
'generic'
:
set_generic_model_params
(
args
.
discrim_weights
,
args
.
discrim_meta
)
# load pretrained model
# load pretrained model
model
=
GPT2LMHeadModel
.
from_pretrained
(
model
=
GPT2LMHeadModel
.
from_pretrained
(
args
.
model_path
,
args
.
model_path
,
...
...
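Putting the pieces together, the generic path runs before the model is loaded. A minimal sketch of the resulting call sequence inside run_model(), assuming hypothetical file names and a hypothetical "positive" entry in the meta file's class_vocab:

# registers the meta under DISCRIMINATOR_MODELS_PARAMS['generic'],
# with meta['path'] pointing at the local weights file
set_generic_model_params('my_discrim.pt', 'my_discrim_meta.json')
# the registered params carry no "url" key, so get_classifier loads the
# state dict from params["path"] instead of downloading it
classifier, class_id = get_classifier('generic', 'positive', device)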