Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
61399e5a
Commit
61399e5a
authored
Nov 27, 2019
by
piero
Committed by
Julien Chaumond
Dec 03, 2019
Browse files
Cleaned perturb_past. Identical output as before.
parent
ffc29354
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
112 additions
and
92 deletions
+112
-92
examples/run_pplm.py
examples/run_pplm.py
+112
-92
No files found.
examples/run_pplm.py
View file @
61399e5a
...
...
@@ -112,7 +112,7 @@ def top_k_filter(logits, k, probs=False):
def
perturb_past
(
past
,
model
,
prev
,
last
,
unpert_past
=
None
,
unpert_logits
=
None
,
accumulated_hidden
=
None
,
...
...
@@ -128,156 +128,174 @@ def perturb_past(
horizon_length
=
1
,
decay
=
False
,
gamma
=
1.5
,
device
=
'cuda'
):
# Generate inital perturbed past
past_perturb_orig
=
[
(
np
.
random
.
uniform
(
0.0
,
0.0
,
p
.
shape
).
astype
(
'float32'
))
for
p
in
past
]
grad_accumulator
=
[
(
np
.
zeros
(
p
.
shape
).
astype
(
"float32"
))
for
p
in
past
]
if
accumulated_hidden
is
None
:
accumulated_hidden
=
0
if
decay
:
decay_mask
=
torch
.
arange
(
0.
,
1.0
+
SMALL_CONST
,
1.0
/
(
window_length
))[
1
:]
decay_mask
=
torch
.
arange
(
0.
,
1.0
+
SMALL_CONST
,
1.0
/
(
window_length
)
)[
1
:]
else
:
decay_mask
=
1.0
# TODO fix this comment (SUMANTH)
# Generate a mask is gradient perturbated is based on a past window
_
,
_
,
_
,
curr
ent
_length
,
_
=
past
[
0
].
shape
_
,
_
,
_
,
curr_length
,
_
=
past
[
0
].
shape
if
current_length
>
window_length
and
window_length
>
0
:
ones_key_val_shape
=
tuple
(
past
[
0
].
shape
[:
-
2
])
+
tuple
(
[
window_length
])
+
tuple
(
past
[
0
].
shape
[
-
1
:])
if
curr_length
>
window_length
and
window_length
>
0
:
ones_key_val_shape
=
(
tuple
(
past
[
0
].
shape
[:
-
2
])
+
tuple
([
window_length
])
+
tuple
(
past
[
0
].
shape
[
-
1
:])
)
zeros_key_val_shape
=
tuple
(
past
[
0
].
shape
[:
-
2
])
+
tuple
(
[
current_length
-
window_length
])
+
tuple
(
past
[
0
].
shape
[
-
1
:])
zeros_key_val_shape
=
(
tuple
(
past
[
0
].
shape
[:
-
2
])
+
tuple
([
curr_length
-
window_length
])
+
tuple
(
past
[
0
].
shape
[
-
1
:])
)
ones_mask
=
torch
.
ones
(
ones_key_val_shape
)
ones_mask
=
decay_mask
*
ones_mask
.
permute
(
0
,
1
,
2
,
4
,
3
)
ones_mask
=
ones_mask
.
permute
(
0
,
1
,
2
,
4
,
3
)
window_mask
=
torch
.
cat
((
ones_mask
,
torch
.
zeros
(
zeros_key_val_shape
)),
dim
=-
2
).
cuda
()
window_mask
=
torch
.
cat
(
(
ones_mask
,
torch
.
zeros
(
zeros_key_val_shape
)),
dim
=-
2
).
to
(
device
)
else
:
window_mask
=
torch
.
ones_like
(
past
[
0
]).
cuda
(
)
window_mask
=
torch
.
ones_like
(
past
[
0
]).
to
(
device
)
# accumulate perturbations for num_iterations
loss_per_iter
=
[]
new_accumulated_hidden
=
None
for
i
in
range
(
num_iterations
):
print
(
"Iteration "
,
i
+
1
)
past_perturb
=
[
torch
.
from_numpy
(
p_
)
for
p_
in
past_perturb_orig
]
past_perturb
=
[
to_var
(
p_
,
requires_grad
=
True
)
for
p_
in
past_perturb
]
perturbed_past
=
list
(
map
(
add
,
past
,
past_perturb
))
_
,
_
,
_
,
current_length
,
_
=
past_perturb
[
0
].
shape
# _, future_past = model(prev, past=perturbed_past)
# hidden = model.hidden_states
# Piero modified model call
logits
,
_
,
all_hidden
=
model
(
prev
,
past
=
perturbed_past
)
curr_perturbation
=
[
to_var
(
torch
.
from_numpy
(
p_
),
requires_grad
=
True
)
for
p_
in
grad_accumulator
]
# Compute hidden using perturbed past
perturbed_past
=
list
(
map
(
add
,
past
,
curr_perturbation
))
_
,
_
,
_
,
curr_length
,
_
=
curr_perturbation
[
0
].
shape
all_logits
,
_
,
all_hidden
=
model
(
last
,
past
=
perturbed_past
)
hidden
=
all_hidden
[
-
1
]
new_accumulated_hidden
=
accumulated_hidden
+
torch
.
sum
(
hidden
,
dim
=
1
).
detach
()
new_accumulated_hidden
=
accumulated_hidden
+
torch
.
sum
(
hidden
,
dim
=
1
).
detach
()
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
logits
=
all_logits
[:,
-
1
,
:]
probs
=
F
.
softmax
(
logits
,
dim
=-
1
)
# TODO: Check the layer-norm consistency of this with trained discriminator
logits
=
logits
[:,
-
1
,
:]
probabs
=
F
.
softmax
(
logits
,
dim
=-
1
)
loss
=
0.0
loss_list
=
[]
if
loss_type
==
1
or
loss_type
==
3
:
for
one_hot_good
in
one_hot_bows_vectors
:
good_logits
=
torch
.
mm
(
probabs
,
torch
.
t
(
one_hot_good
))
loss_word
=
good_logits
loss_word
=
torch
.
sum
(
loss_word
)
loss_word
=
-
torch
.
log
(
loss_word
)
# loss_word = torch.sum(loss_word) /torch.sum(one_hot_good)
loss
+=
loss_word
loss_list
.
append
(
loss_word
)
if
loss_type
==
PPLM_BOW
or
loss_type
==
PPLM_BOW_DISCRIM
:
for
one_hot_bow
in
one_hot_bows_vectors
:
bow_logits
=
torch
.
mm
(
probs
,
torch
.
t
(
one_hot_bow
))
bow_loss
=
-
torch
.
log
(
torch
.
sum
(
bow_logits
))
loss
+=
bow_loss
loss_list
.
append
(
bow_loss
)
print
(
" pplm_bow_loss:"
,
loss
.
data
.
cpu
().
numpy
())
if
loss_type
==
2
or
loss_type
==
3
:
ce_loss
=
torch
.
nn
.
CrossEntropyLoss
()
new_true_past
=
unpert_past
for
i
in
range
(
horizon_length
):
future_probabs
=
F
.
softmax
(
logits
,
dim
=-
1
)
# Get softmax
future_probabs
=
torch
.
unsqueeze
(
future_probabs
,
dim
=
1
)
# _, new_true_past = model(future_probabs, past=new_true_past)
# future_hidden = model.hidden_states # Get expected hidden states
# Piero modified model call
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
curr_unpert_past
=
unpert_past
curr_probs
=
torch
.
unsqueeze
(
probs
,
dim
=
1
)
wte
=
model
.
resize_token_embeddings
()
inputs_embeds
=
torch
.
matmul
(
future_probabs
,
wte
.
weight
.
data
)
_
,
new_true_past
,
future_hidden
=
model
(
past
=
new_true_past
,
for
_
in
range
(
horizon_length
):
inputs_embeds
=
torch
.
matmul
(
curr_probs
,
wte
.
weight
.
data
)
_
,
curr_unpert_past
,
curr_all_hidden
=
model
(
past
=
curr_unpert_past
,
inputs_embeds
=
inputs_embeds
)
future_hidden
=
future_hidden
[
-
1
]
curr_hidden
=
curr_all_hidden
[
-
1
]
new_accumulated_hidden
=
new_accumulated_hidden
+
torch
.
sum
(
fut
ur
e
_hidden
,
dim
=
1
)
c
ur
r
_hidden
,
dim
=
1
)
predict
ed_sentiment
=
classifier
(
new_accumulated_hidden
/
(
curr
ent
_length
+
1
+
horizon_length
))
predict
ion
=
classifier
(
new_accumulated_hidden
/
(
curr_length
+
1
+
horizon_length
))
label
=
torch
.
tensor
([
label_class
],
device
=
'cuda'
,
label
=
torch
.
tensor
([
label_class
],
device
=
device
,
dtype
=
torch
.
long
)
discrim_loss
=
ce_loss
(
predict
ed_sentiment
,
label
)
discrim_loss
=
ce_loss
(
predict
ion
,
label
)
print
(
" pplm_discrim_loss:"
,
discrim_loss
.
data
.
cpu
().
numpy
())
loss
+=
discrim_loss
loss_list
.
append
(
discrim_loss
)
kl_loss
=
0.0
if
kl_scale
>
0.0
:
p
=
(
F
.
softmax
(
unpert_logits
[:,
-
1
,
:],
dim
=-
1
))
p
=
p
+
SMALL_CONST
*
(
p
<=
SMALL_CONST
).
type
(
torch
.
FloatTensor
).
cuda
().
detach
()
correction
=
SMALL_CONST
*
(
probabs
<=
SMALL_CONST
).
type
(
torch
.
FloatTensor
).
cuda
().
detach
()
corrected_probabs
=
probabs
+
correction
.
detach
()
unpert_probs
=
F
.
softmax
(
unpert_logits
[:,
-
1
,
:],
dim
=-
1
)
unpert_probs
=
(
unpert_probs
+
SMALL_CONST
*
(
unpert_probs
<=
SMALL_CONST
).
float
().
to
(
device
).
detach
()
)
correction
=
SMALL_CONST
*
(
probs
<=
SMALL_CONST
).
float
().
to
(
device
).
detach
()
corrected_probs
=
probs
+
correction
.
detach
()
kl_loss
=
kl_scale
*
(
(
corrected_probabs
*
(
corrected_probabs
/
p
).
log
()).
sum
())
print
(
' kl_loss'
,
(
kl_loss
).
data
.
cpu
().
numpy
())
loss
+=
kl_loss
# + discrim_loss
(
corrected_probs
*
(
corrected_probs
/
unpert_probs
).
log
()).
sum
()
)
print
(
' kl_loss'
,
kl_loss
.
data
.
cpu
().
numpy
())
loss
+=
kl_loss
loss_per_iter
.
append
(
loss
.
data
.
cpu
().
numpy
())
print
(
' pplm_loss'
,
(
loss
-
kl_loss
).
data
.
cpu
().
numpy
())
# compute gradients
loss
.
backward
()
if
grad_norms
is
not
None
and
loss_type
==
1
:
# calculate gradient norms
if
grad_norms
is
not
None
and
loss_type
==
PPLM_BOW
:
grad_norms
=
[
torch
.
max
(
grad_norms
[
index
],
torch
.
norm
(
p_
.
grad
*
window_mask
))
for
index
,
p_
in
enumerate
(
past_perturb
)
]
for
index
,
p_
in
enumerate
(
curr_perturbation
)
]
else
:
grad_norms
=
[(
torch
.
norm
(
p_
.
grad
*
window_mask
)
+
SMALL_CONST
)
for
index
,
p_
in
enumerate
(
past_perturb
)]
grad_norms
=
[
(
torch
.
norm
(
p_
.
grad
*
window_mask
)
+
SMALL_CONST
)
for
index
,
p_
in
enumerate
(
curr_perturbation
)
]
# normalize gradients
grad
=
[
-
stepsize
*
(
p_
.
grad
*
window_mask
/
grad_norms
[
index
]
**
gamma
).
data
.
cpu
().
numpy
()
for
index
,
p_
in
enumerate
(
past
_perturb
)
]
past_perturb_orig
=
list
(
map
(
add
,
grad
,
past_perturb_orig
))
-
stepsize
*
(
p_
.
grad
*
window_mask
/
grad_norms
[
index
]
**
gamma
).
data
.
cpu
().
numpy
()
for
index
,
p_
in
enumerate
(
curr
_perturb
ation
)
]
for
p_
in
past_perturb
:
# accumulate gradient
grad_accumulator
=
list
(
map
(
add
,
grad
,
grad_accumulator
))
# reset gradients, just to make sure
for
p_
in
curr_perturbation
:
p_
.
grad
.
data
.
zero_
()
# removing past from the graph
new_past
=
[]
for
p
in
past
:
new_past
.
append
(
p
.
detach
())
for
p_
in
past
:
new_past
.
append
(
p_
.
detach
())
past
=
new_past
past_perturb
=
[
torch
.
from_numpy
(
p_
)
for
p_
in
past_perturb_orig
]
past_perturb
=
[
to_var
(
p_
,
requires_grad
=
True
)
for
p_
in
past_perturb
]
perturbed_past
=
list
(
map
(
add
,
past
,
past_perturb
))
# apply the accumulated perturbations to the past
grad_accumulator
=
[
to_var
(
torch
.
from_numpy
(
p_
),
requires_grad
=
True
)
for
p_
in
grad_accumulator
]
pert_past
=
list
(
map
(
add
,
past
,
grad_accumulator
))
return
pert
urbed
_past
,
new_accumulated_hidden
,
grad_norms
,
loss_per_iter
return
pert_past
,
new_accumulated_hidden
,
grad_norms
,
loss_per_iter
def
get_classifier
(
...
...
@@ -532,6 +550,7 @@ def generate_text_pplm(
horizon_length
=
horizon_length
,
decay
=
decay
,
gamma
=
gamma
,
device
=
device
)
loss_in_time
.
append
(
loss_this_iter
)
else
:
...
...
@@ -662,7 +681,8 @@ def run_model():
parser
.
add_argument
(
"--decay"
,
action
=
"store_true"
,
help
=
"whether to decay or not"
)
parser
.
add_argument
(
"--gamma"
,
type
=
float
,
default
=
1.5
)
parser
.
add_argument
(
"--colorama"
,
action
=
"store_true"
,
help
=
"colors keywords"
)
parser
.
add_argument
(
"--colorama"
,
action
=
"store_true"
,
help
=
"colors keywords"
)
args
=
parser
.
parse_args
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment