chenpangpang / transformers · Commits · 61399e5a

Commit 61399e5a, authored Nov 27, 2019 by piero, committed by Julien Chaumond on Dec 03, 2019
Parent: ffc29354

Cleaned perturb_past. Identical output as before.

Showing 1 changed file with 112 additions and 92 deletions:

examples/run_pplm.py (+112 / -92)
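For orientation: perturb_past implements PPLM's activation-steering step. It starts from a zero perturbation of the model's cached key/value "past", runs the model on the perturbed past, scores the output with a bag-of-words and/or discriminator loss plus a KL term, backpropagates into the perturbation, takes a norm-scaled gradient step, and repeats for num_iterations before returning the perturbed past. The toy sketch below paraphrases that control flow with a single tensor and a made-up quadratic loss standing in for the model and the PPLM losses; it is an illustration, not code from run_pplm.py.

```python
import torch

# Toy, self-contained paraphrase of the perturb_past control flow shown in the
# diff below. The real function perturbs GPT-2 key/value "past" tensors and uses
# BoW/discriminator/KL losses; here a single small tensor and a made-up quadratic
# loss stand in for both.
def perturb_past_sketch(past, num_iterations=3, stepsize=0.02, gamma=1.5, small_const=1e-15):
    grad_accumulator = torch.zeros_like(past)                  # accumulated perturbation, starts at zero
    for _ in range(num_iterations):
        curr_perturbation = grad_accumulator.clone().requires_grad_(True)
        perturbed_past = past + curr_perturbation              # apply the perturbation so far
        loss = ((perturbed_past - 1.0) ** 2).sum()             # stand-in for the PPLM losses
        loss.backward()                                        # gradients w.r.t. the perturbation
        grad_norm = torch.norm(curr_perturbation.grad) + small_const
        step = -stepsize * curr_perturbation.grad / grad_norm ** gamma
        grad_accumulator = grad_accumulator + step             # accumulate the normalized step
    return past + grad_accumulator                             # the perturbed past ("pert_past")

print(perturb_past_sketch(torch.zeros(2, 3)))
```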
@@ -112,7 +112,7 @@ def top_k_filter(logits, k, probs=False):
 def perturb_past(
     past,
     model,
-    prev,
+    last,
     unpert_past=None,
     unpert_logits=None,
     accumulated_hidden=None,
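The large hunk that follows unpacks past[0].shape into five dimensions and builds window_mask so that only a window of positions in the cached keys/values receives a perturbation, optionally with a linearly rising decay weight. The check below uses a dummy tensor shaped like a GPT-2 cached key/value block, (2, batch, heads, seq_len, head_dim), which is an assumption about the layout this version of the library used; the mask arithmetic itself mirrors the diff.

```python
import torch

SMALL_CONST = 1e-15   # value chosen here just for the sketch
window_length = 5

# Dummy GPT-2-style cached key/value block: (2, batch, heads, seq_len, head_dim).
# The exact layout is an assumption based on how the diff unpacks past[0].shape.
past_0 = torch.zeros(2, 1, 12, 20, 64)
_, _, _, curr_length, _ = past_0.shape

# Decay mask: window_length weights rising roughly linearly up to 1.0.
decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / window_length)[1:]
print(decay_mask)          # approximately tensor([0.2, 0.4, 0.6, 0.8, 1.0])

# Mask that keeps the perturbation on window_length of the curr_length positions
# and zeroes it elsewhere, following the shape arithmetic in the hunk.
ones_key_val_shape = tuple(past_0.shape[:-2]) + (window_length,) + tuple(past_0.shape[-1:])
zeros_key_val_shape = tuple(past_0.shape[:-2]) + (curr_length - window_length,) + tuple(past_0.shape[-1:])

ones_mask = torch.ones(ones_key_val_shape)
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)   # broadcast the decay along the window axis
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)                # restore the original axis order

window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2)
print(window_mask.shape)   # torch.Size([2, 1, 12, 20, 64])
```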
@@ -128,156 +128,174 @@ def perturb_past(
     horizon_length=1,
     decay=False,
     gamma=1.5,
+    device='cuda',
 ):
     # Generate inital perturbed past
-    past_perturb_orig = [
-        (np.random.uniform(0.0, 0.0, p.shape).astype('float32'))
-        for p in past
-    ]
+    grad_accumulator = [
+        (np.zeros(p.shape).astype("float32"))
+        for p in past
+    ]
     if accumulated_hidden is None:
         accumulated_hidden = 0

     if decay:
-        decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
+        decay_mask = torch.arange(
+            0.,
+            1.0 + SMALL_CONST,
+            1.0 / (window_length)
+        )[1:]
     else:
         decay_mask = 1.0

+    # TODO fix this comment (SUMANTH)
     # Generate a mask is gradient perturbated is based on a past window
-    _, _, _, current_length, _ = past[0].shape
-    if current_length > window_length and window_length > 0:
+    _, _, _, curr_length, _ = past[0].shape
+    if curr_length > window_length and window_length > 0:
-        ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple(
-            [window_length]) + tuple(
-            past[0].shape[-1:])
-        zeros_key_val_shape = tuple(past[0].shape[:-2]) + tuple(
-            [current_length - window_length]) + tuple(
-            past[0].shape[-1:])
+        ones_key_val_shape = (
+            tuple(past[0].shape[:-2])
+            + tuple([window_length])
+            + tuple(past[0].shape[-1:])
+        )
+
+        zeros_key_val_shape = (
+            tuple(past[0].shape[:-2])
+            + tuple([curr_length - window_length])
+            + tuple(past[0].shape[-1:])
+        )

         ones_mask = torch.ones(ones_key_val_shape)
         ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
         ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

-        window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)),
-                                dim=-2).cuda()
+        window_mask = torch.cat(
+            (ones_mask, torch.zeros(zeros_key_val_shape)),
+            dim=-2).to(device)
     else:
-        window_mask = torch.ones_like(past[0]).cuda()
+        window_mask = torch.ones_like(past[0]).to(device)
+    # accumulate perturbations for num_iterations
     loss_per_iter = []
+    new_accumulated_hidden = None
     for i in range(num_iterations):
         print("Iteration ", i + 1)
-        past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
-        past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
-
-        perturbed_past = list(map(add, past, past_perturb))
-
-        _, _, _, current_length, _ = past_perturb[0].shape
-
-        # _, future_past = model(prev, past=perturbed_past)
-        # hidden = model.hidden_states
-        # Piero modified model call
-        logits, _, all_hidden = model(prev, past=perturbed_past)
+        curr_perturbation = [
+            to_var(torch.from_numpy(p_), requires_grad=True)
+            for p_ in grad_accumulator
+        ]
+
+        # Compute hidden using perturbed past
+        perturbed_past = list(map(add, past, curr_perturbation))
+        _, _, _, curr_length, _ = curr_perturbation[0].shape
+        all_logits, _, all_hidden = model(last, past=perturbed_past)
         hidden = all_hidden[-1]
-        new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
-
-        # TODO: Check the layer-norm consistency of this with trained discriminator
-        logits = logits[:, -1, :]
-        probabs = F.softmax(logits, dim=-1)
+        new_accumulated_hidden = accumulated_hidden + torch.sum(
+            hidden,
+            dim=1
+        ).detach()
+        # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
+        logits = all_logits[:, -1, :]
+        probs = F.softmax(logits, dim=-1)
         loss = 0.0
         loss_list = []
-        if loss_type == 1 or loss_type == 3:
-            for one_hot_good in one_hot_bows_vectors:
-                good_logits = torch.mm(probabs, torch.t(one_hot_good))
-                loss_word = good_logits
-                loss_word = torch.sum(loss_word)
-                loss_word = -torch.log(loss_word)
-                # loss_word = torch.sum(loss_word) /torch.sum(one_hot_good)
-                loss += loss_word
-                loss_list.append(loss_word)
+        if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
+            for one_hot_bow in one_hot_bows_vectors:
+                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
+                bow_loss = -torch.log(torch.sum(bow_logits))
+                loss += bow_loss
+                loss_list.append(bow_loss)
             print(" pplm_bow_loss:", loss.data.cpu().numpy())
         if loss_type == 2 or loss_type == 3:
             ce_loss = torch.nn.CrossEntropyLoss()
-            new_true_past = unpert_past
-            for i in range(horizon_length):
-                future_probabs = F.softmax(logits, dim=-1)  # Get softmax
-                future_probabs = torch.unsqueeze(future_probabs, dim=1)
-
-                # Piero modified model call
-                wte = model.resize_token_embeddings()
-                inputs_embeds = torch.matmul(future_probabs, wte.weight.data)
-                # _, new_true_past = model(future_probabs, past=new_true_past)
-                _, new_true_past, future_hidden = model(
-                    past=new_true_past,
-                    inputs_embeds=inputs_embeds
-                )
-                # future_hidden = model.hidden_states  # Get expected hidden states
-                future_hidden = future_hidden[-1]
-
-                new_accumulated_hidden = new_accumulated_hidden + torch.sum(
-                    future_hidden, dim=1)
-
-            predicted_sentiment = classifier(new_accumulated_hidden / (
-                    current_length + 1 + horizon_length))
-
-            label = torch.tensor([label_class], device='cuda',
-                                 dtype=torch.long)
-            discrim_loss = ce_loss(predicted_sentiment, label)
+            # TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
+            curr_unpert_past = unpert_past
+            curr_probs = torch.unsqueeze(probs, dim=1)
+            wte = model.resize_token_embeddings()
+            for _ in range(horizon_length):
+                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
+                _, curr_unpert_past, curr_all_hidden = model(
+                    past=curr_unpert_past,
+                    inputs_embeds=inputs_embeds
+                )
+                curr_hidden = curr_all_hidden[-1]
+                new_accumulated_hidden = new_accumulated_hidden + torch.sum(
+                    curr_hidden, dim=1)
+
+            prediction = classifier(new_accumulated_hidden /
+                                    (curr_length + 1 + horizon_length))
+
+            label = torch.tensor([label_class], device=device,
+                                 dtype=torch.long)
+            discrim_loss = ce_loss(prediction, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
             loss += discrim_loss
             loss_list.append(discrim_loss)
         kl_loss = 0.0
         if kl_scale > 0.0:
-            p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
-            p = p + SMALL_CONST * (p <= SMALL_CONST).type(
-                torch.FloatTensor).cuda().detach()
-            correction = SMALL_CONST * (probabs <= SMALL_CONST).type(
-                torch.FloatTensor).cuda().detach()
-            corrected_probabs = probabs + correction.detach()
-            kl_loss = kl_scale * (
-                (corrected_probabs * (corrected_probabs / p).log()).sum())
-            print(' kl_loss', (kl_loss).data.cpu().numpy())
-            loss += kl_loss  # + discrim_loss
+            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
+            unpert_probs = (
+                unpert_probs + SMALL_CONST *
+                (unpert_probs <= SMALL_CONST).float().to(device).detach()
+            )
+            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
+            corrected_probs = probs + correction.detach()
+            kl_loss = kl_scale * (
+                (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
+            )
+            print(' kl_loss', kl_loss.data.cpu().numpy())
+            loss += kl_loss
         loss_per_iter.append(loss.data.cpu().numpy())
         print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())

+        # compute gradients
         loss.backward()
-        if grad_norms is not None and loss_type == 1:
+        # calculate gradient norms
+        if grad_norms is not None and loss_type == PPLM_BOW:
             grad_norms = [
                 torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
-                for index, p_ in enumerate(past_perturb)
+                for index, p_ in enumerate(curr_perturbation)
             ]
         else:
-            grad_norms = [(torch.norm(p_.grad * window_mask) + SMALL_CONST) for
-                          index, p_ in enumerate(past_perturb)]
+            grad_norms = [
+                (torch.norm(p_.grad * window_mask) + SMALL_CONST)
+                for index, p_ in enumerate(curr_perturbation)
+            ]

+        # normalize gradients
         grad = [
-            -stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
-            for index, p_ in enumerate(past_perturb)
-        ]
-        past_perturb_orig = list(map(add, grad, past_perturb_orig))
-
-        for p_ in past_perturb:
+            -stepsize *
+            (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
+            for index, p_ in enumerate(curr_perturbation)
+        ]
+
+        # accumulate gradient
+        grad_accumulator = list(map(add, grad, grad_accumulator))
+
+        # reset gradients, just to make sure
+        for p_ in curr_perturbation:
             p_.grad.data.zero_()

+        # removing past from the graph
         new_past = []
-        for p in past:
-            new_past.append(p.detach())
+        for p_ in past:
+            new_past.append(p_.detach())
         past = new_past
-    past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
-    past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
-    perturbed_past = list(map(add, past, past_perturb))
-
-    return perturbed_past, new_accumulated_hidden, grad_norms, loss_per_iter
+    # apply the accumulated perturbations to the past
+    grad_accumulator = [
+        to_var(torch.from_numpy(p_), requires_grad=True)
+        for p_ in grad_accumulator
+    ]
+    pert_past = list(map(add, past, grad_accumulator))
+
+    return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter


 def get_classifier(
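Two of the steering losses computed inside the loop above are easy to check by hand: the bag-of-words term is the negative log of the probability mass the perturbed model puts on the topic words, and the KL term (scaled by kl_scale) keeps the perturbed next-token distribution close to the unperturbed one, with a SMALL_CONST floor to avoid log(0). The numbers below are made up for illustration; the expressions follow the new code in the hunk above.

```python
import torch

SMALL_CONST = 1e-15
kl_scale = 0.01            # made-up scale for the sketch

# Toy next-token distributions over a 6-word vocabulary (made-up numbers).
probs = torch.tensor([[0.40, 0.25, 0.15, 0.10, 0.06, 0.04]])         # perturbed model
unpert_probs = torch.tensor([[0.35, 0.30, 0.15, 0.10, 0.06, 0.04]])  # unperturbed model

# Bag-of-words loss: rows of one_hot_bow select the "topic" words (here ids 4 and 5).
one_hot_bow = torch.zeros(2, 6)
one_hot_bow[0, 4] = 1.0
one_hot_bow[1, 5] = 1.0
bow_logits = torch.mm(probs, torch.t(one_hot_bow))   # probability mass on each topic word
bow_loss = -torch.log(torch.sum(bow_logits))          # little mass on the topic -> large loss
print(bow_loss)                                        # -log(0.10), about 2.30

# KL term keeping the perturbed distribution close to the unperturbed one,
# with the same SMALL_CONST floor against log(0) as in the diff.
unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().detach()
correction = SMALL_CONST * (probs <= SMALL_CONST).float().detach()
corrected_probs = probs + correction.detach()
kl_loss = kl_scale * (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
print(kl_loss)                                         # a small positive number here
```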
@@ -532,6 +550,7 @@ def generate_text_pplm(
                 horizon_length=horizon_length,
                 decay=decay,
                 gamma=gamma,
+                device=device
             )
             loss_in_time.append(loss_this_iter)
         else:
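The added device=device argument here, together with the .cuda() to .to(device) changes inside perturb_past, is what lets the perturbation step run on whatever device the caller selected. A typical way to pick that value (the flag handling in run_pplm.py is not shown in this diff, so treat this as an assumption about intent):

```python
import torch

# Typical way to pick the device that generate_text_pplm threads into perturb_past;
# the actual option handling in run_pplm.py is not shown in this diff (assumption).
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.ones(2, 3).to(device)   # works on CPU-only machines, unlike a hard-coded .cuda()
print(x.device)
```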
@@ -562,7 +581,7 @@ def generate_text_pplm(
             pert_probs = ((pert_probs ** gm_scale) * (
                     unpert_probs ** (1 - gm_scale)))  # + SMALL_CONST
             pert_probs = top_k_filter(pert_probs, k=top_k,
                                       probs=True)  # + SMALL_CONST
             # rescale
             if torch.sum(pert_probs) <= 1:
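For context, the lines in this hunk fuse the perturbed and unperturbed next-token distributions with a geometric mean controlled by gm_scale and then renormalize, since the element-wise product no longer sums to one. A toy version with made-up numbers:

```python
import torch

gm_scale = 0.9   # made-up value for the sketch

# Made-up perturbed and unperturbed next-token distributions.
pert_probs = torch.tensor([[0.7, 0.2, 0.1]])
unpert_probs = torch.tensor([[0.4, 0.4, 0.2]])

# Geometric-mean fusion: gm_scale = 1 keeps only the perturbed model,
# gm_scale = 0 keeps only the unperturbed one.
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))

# The element-wise product is unnormalized, hence the "rescale" step in the hunk.
if torch.sum(pert_probs) <= 1:
    pert_probs = pert_probs / torch.sum(pert_probs)

print(pert_probs, pert_probs.sum())   # sums to 1.0 again
```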
@@ -662,7 +681,8 @@ def run_model():
     parser.add_argument("--decay", action="store_true",
                         help="whether to decay or not")
     parser.add_argument("--gamma", type=float, default=1.5)
     parser.add_argument("--colorama", action="store_true", help="colors keywords")
     args = parser.parse_args()
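The last hunk touches the script's argument parser; the snippet below reproduces just the three options visible here to show how they parse (run_pplm.py defines many more arguments that this diff does not show):

```python
import argparse

# Minimal reproduction of the three options visible in this hunk; the rest of the
# run_pplm.py command line is not part of this commit and is omitted.
parser = argparse.ArgumentParser()
parser.add_argument("--decay", action="store_true", help="whether to decay or not")
parser.add_argument("--gamma", type=float, default=1.5)
parser.add_argument("--colorama", action="store_true", help="colors keywords")

args = parser.parse_args(["--gamma", "1.0", "--colorama"])
print(args)   # Namespace(colorama=True, decay=False, gamma=1.0)
```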