Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9f693a0c
Commit
9f693a0c
authored
Nov 27, 2019
by
piero
Committed by
Julien Chaumond
Dec 03, 2019
Browse files
Cleaned generate_text_pplm. Identical output as before.
parent
61a12f79
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
72 deletions
+53
-72
examples/run_pplm.py
examples/run_pplm.py
+53
-72
No files found.
examples/run_pplm.py
View file @
9f693a0c
...
@@ -471,59 +471,49 @@ def generate_text_pplm(
...
@@ -471,59 +471,49 @@ def generate_text_pplm(
decay
=
False
,
decay
=
False
,
gamma
=
1.5
,
gamma
=
1.5
,
):
):
output
=
torch
.
tensor
(
context
,
device
=
device
,
dtype
=
torch
.
long
).
unsqueeze
(
output_so_far
=
(
0
)
if
context
else
None
torch
.
tensor
(
context
,
device
=
device
,
dtype
=
torch
.
long
).
unsqueeze
(
0
)
if
context
else
None
)
# collect one hot vectors for bags of words
# collect one hot vectors for bags of words
one_hot_bows_vectors
=
build_bows_one_hot_vectors
(
bow_indices
)
one_hot_bows_vectors
=
build_bows_one_hot_vectors
(
bow_indices
)
grad_norms
=
None
grad_norms
=
None
unpert_discrim_loss
=
0
loss_in_time
=
[]
loss_in_time
=
[]
for
i
in
trange
(
length
,
ascii
=
True
):
for
i
in
trange
(
length
,
ascii
=
True
):
# Get past/probs for current output, except for last word
# Get past/probs for current output, except for last word
# Note that GPT takes 2 inputs: past + current-token
# Note that GPT takes 2 inputs: past + current_token
# Therefore, use everything from before current i/p token to generate relevant past
if
past
is
None
and
output
is
not
None
:
prev
=
output
[:,
-
1
:]
# _, past = model(output[:, :-1])
# original_probs, true_past = model(output)
# true_hidden = model.hidden_states
# Piero modified model call
_
,
past
,
_
=
model
(
output
[:,
:
-
1
])
unpert_logits
,
unpert_past
,
unpert_all_hidden
=
model
(
output
)
true_hidden
=
unpert_all_hidden
[
-
1
]
else
:
# run model forward to obtain unperturbed
# original_probs, true_past = model(output)
if
past
is
None
and
output_so_far
is
not
None
:
# true_hidden = model.hidden_states
last
=
output_so_far
[:,
-
1
:]
_
,
past
,
_
=
model
(
output_so_far
[:,
:
-
1
])
# Piero modified model call
unpert_logits
,
unpert_past
,
unpert_all_hidden
=
model
(
output_so_far
)
unpert_logits
,
unpert_past
,
unpert_all_hidden
=
model
(
output
)
unpert_last_hidden
=
unpert_all_hidden
[
-
1
]
true_hidden
=
unpert_all_hidden
[
-
1
]
# Modify the past if necessary
# check if we are abowe grad max length
if
i
>=
grad_length
:
if
i
>=
grad_length
:
current_stepsize
=
stepsize
*
0
current_stepsize
=
stepsize
*
0
else
:
else
:
current_stepsize
=
stepsize
current_stepsize
=
stepsize
# modify the past if necessary
if
not
perturb
or
num_iterations
==
0
:
if
not
perturb
or
num_iterations
==
0
:
pert
urbed
_past
=
past
pert_past
=
past
else
:
else
:
# Piero modified model call
accumulated_hidden
=
unpert_last_hidden
[:,
:
-
1
,
:]
# accumulated_hidden = model.hidden_states[:, :-1, :]
accumulated_hidden
=
true_hidden
[:,
:
-
1
,
:]
accumulated_hidden
=
torch
.
sum
(
accumulated_hidden
,
dim
=
1
)
accumulated_hidden
=
torch
.
sum
(
accumulated_hidden
,
dim
=
1
)
pert
urbed
_past
,
_
,
grad_norms
,
loss_
per
_iter
=
perturb_past
(
pert_past
,
_
,
grad_norms
,
loss_
this
_iter
=
perturb_past
(
past
,
past
,
model
,
model
,
prev
,
last
,
unpert_past
=
unpert_past
,
unpert_past
=
unpert_past
,
unpert_logits
=
unpert_logits
,
unpert_logits
=
unpert_logits
,
accumulated_hidden
=
accumulated_hidden
,
accumulated_hidden
=
accumulated_hidden
,
...
@@ -540,68 +530,59 @@ def generate_text_pplm(
...
@@ -540,68 +530,59 @@ def generate_text_pplm(
decay
=
decay
,
decay
=
decay
,
gamma
=
gamma
,
gamma
=
gamma
,
)
)
loss_in_time
.
append
(
loss_
per
_iter
)
loss_in_time
.
append
(
loss_
this
_iter
)
# Piero modified model call
pert_logits
,
past
,
pert_all_hidden
=
model
(
last
,
past
=
pert_past
)
logits
,
past
,
pert_all_hidden
=
model
(
prev
,
past
=
perturbed_past
)
pert_logits
=
pert_logits
[:,
-
1
,
:]
/
temperature
# + SMALL_CONST
# test_logits = F.softmax(test_logits[:, -1, :], dim=-1)
pert_probs
=
F
.
softmax
(
pert_logits
,
dim
=-
1
)
# likelywords = torch.topk(test_logits, k=10, dim=-1)
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
if
classifier
is
not
None
:
if
classifier
is
not
None
:
ce_loss
=
torch
.
nn
.
CrossEntropyLoss
()
ce_loss
=
torch
.
nn
.
CrossEntropyLoss
()
predict
ed_sentiment
=
classifier
(
torch
.
mean
(
true
_hidden
,
dim
=
1
))
predict
ion
=
classifier
(
torch
.
mean
(
unpert_last
_hidden
,
dim
=
1
))
label
=
torch
.
tensor
([
label_class
],
device
=
'cuda'
,
label
=
torch
.
tensor
([
label_class
],
device
=
'cuda'
,
dtype
=
torch
.
long
)
dtype
=
torch
.
long
)
true_discrim_loss
=
ce_loss
(
predicted_sentiment
,
label
)
unpert_discrim_loss
=
ce_loss
(
prediction
,
label
)
print
(
"true discrim loss"
,
true_discrim_loss
.
data
.
cpu
().
numpy
())
print
(
"unperturbed discrim loss"
,
unpert_discrim_loss
.
data
.
cpu
().
numpy
()
)
else
:
else
:
true_discrim_loss
=
0
unpert_discrim_loss
=
0
# Piero modified model call
# hidden = model.hidden_states # update hidden
# logits = model.forward_hidden(hidden)
logits
=
logits
[:,
-
1
,
:]
/
temperature
# + SMALL_CONST
# logits = top_k_filter(logits, k=args.top_k) # + SMALL_CONST
log_probs
=
F
.
softmax
(
logits
,
dim
=-
1
)
# Fuse the modified model and original model
# Fuse the modified model and original model
if
perturb
:
if
perturb
:
# original_probs = top_k_filter(original_probs[:, -1, :]) #+ SMALL_CONST
unpert_probs
=
F
.
softmax
(
unpert_logits
[:,
-
1
,
:],
dim
=-
1
)
unpert_logits
=
F
.
softmax
(
unpert_logits
[:,
-
1
,
:],
dim
=-
1
)
# likelywords = torch.topk(original_probs, k=10, dim=-1)
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
log_probs
=
((
log_probs
**
gm_scale
)
*
(
unpert_logits
**
(
1
-
gm_scale
)))
# + SMALL_CONST
log_probs
=
top_k_filter
(
log_probs
,
k
=
top_k
,
pert_probs
=
((
pert_probs
**
gm_scale
)
*
(
unpert_probs
**
(
1
-
gm_scale
)))
# + SMALL_CONST
pert_probs
=
top_k_filter
(
pert_probs
,
k
=
top_k
,
probs
=
True
)
# + SMALL_CONST
probs
=
True
)
# + SMALL_CONST
if
torch
.
sum
(
log_probs
)
<=
1
:
# rescale
log_probs
=
log_probs
/
torch
.
sum
(
log_probs
)
if
torch
.
sum
(
pert_probs
)
<=
1
:
pert_probs
=
pert_probs
/
torch
.
sum
(
pert_probs
)
else
:
else
:
logits
=
top_k_filter
(
logits
,
k
=
top_k
)
# + SMALL_CONST
pert_
logits
=
top_k_filter
(
pert_
logits
,
k
=
top_k
)
# + SMALL_CONST
log
_probs
=
F
.
softmax
(
logits
,
dim
=-
1
)
pert
_probs
=
F
.
softmax
(
pert_
logits
,
dim
=-
1
)
# sample or greedy
if
sample
:
if
sample
:
# likelywords = torch.topk(log_probs, k=args.top_k, dim=-1)
last
=
torch
.
multinomial
(
pert_probs
,
num_samples
=
1
)
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
# print(likelywords[0].tolist())
prev
=
torch
.
multinomial
(
log_probs
,
num_samples
=
1
)
else
:
else
:
_
,
prev
=
torch
.
topk
(
log_probs
,
k
=
1
,
dim
=-
1
)
_
,
last
=
torch
.
topk
(
pert_probs
,
k
=
1
,
dim
=-
1
)
# if perturb:
# prev = future
# update context/output_so_far appending the new token
output
=
prev
if
output
is
None
else
torch
.
cat
((
output
,
prev
),
output_so_far
=
(
dim
=
1
)
# update output
last
if
output_so_far
is
None
print
(
TOKENIZER
.
decode
(
output
.
tolist
()[
0
]))
else
torch
.
cat
((
output_so_far
,
last
),
dim
=
1
)
)
return
output
,
true_discrim_loss
,
loss_in_time
print
(
TOKENIZER
.
decode
(
output_so_far
.
tolist
()[
0
]))
return
output_so_far
,
unpert_discrim_loss
,
loss_in_time
def
run_model
():
def
run_model
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment