Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9f693a0c
Commit
9f693a0c
authored
Nov 27, 2019
by
piero
Committed by
Julien Chaumond
Dec 03, 2019
Browse files
Cleaned generate_text_pplm. Identical output as before.
parent
61a12f79
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
72 deletions
+53
-72
examples/run_pplm.py
examples/run_pplm.py
+53
-72
No files found.
examples/run_pplm.py
View file @
9f693a0c
...
...
@@ -471,59 +471,49 @@ def generate_text_pplm(
decay
=
False
,
gamma
=
1.5
,
):
output
=
torch
.
tensor
(
context
,
device
=
device
,
dtype
=
torch
.
long
).
unsqueeze
(
0
)
if
context
else
None
output_so_far
=
(
torch
.
tensor
(
context
,
device
=
device
,
dtype
=
torch
.
long
).
unsqueeze
(
0
)
if
context
else
None
)
# collect one hot vectors for bags of words
one_hot_bows_vectors
=
build_bows_one_hot_vectors
(
bow_indices
)
grad_norms
=
None
unpert_discrim_loss
=
0
loss_in_time
=
[]
for
i
in
trange
(
length
,
ascii
=
True
):
# Get past/probs for current output, except for last word
# Note that GPT takes 2 inputs: past + current-token
# Therefore, use everything from before current i/p token to generate relevant past
if
past
is
None
and
output
is
not
None
:
prev
=
output
[:,
-
1
:]
# _, past = model(output[:, :-1])
# original_probs, true_past = model(output)
# true_hidden = model.hidden_states
# Piero modified model call
_
,
past
,
_
=
model
(
output
[:,
:
-
1
])
unpert_logits
,
unpert_past
,
unpert_all_hidden
=
model
(
output
)
true_hidden
=
unpert_all_hidden
[
-
1
]
else
:
# original_probs, true_past = model(output)
# true_hidden = model.hidden_states
# Note that GPT takes 2 inputs: past + current_token
# Piero modified model call
unpert_logits
,
unpert_past
,
unpert_all_hidden
=
model
(
output
)
true_hidden
=
unpert_all_hidden
[
-
1
]
# run model forward to obtain unperturbed
if
past
is
None
and
output_so_far
is
not
None
:
last
=
output_so_far
[:,
-
1
:]
_
,
past
,
_
=
model
(
output_so_far
[:,
:
-
1
])
# Modify the past if necessary
unpert_logits
,
unpert_past
,
unpert_all_hidden
=
model
(
output_so_far
)
unpert_last_hidden
=
unpert_all_hidden
[
-
1
]
# check if we are above grad max length
if
i
>=
grad_length
:
current_stepsize
=
stepsize
*
0
else
:
current_stepsize
=
stepsize
# modify the past if necessary
if
not
perturb
or
num_iterations
==
0
:
pert
urbed
_past
=
past
pert_past
=
past
else
:
# Piero modified model call
# accumulated_hidden = model.hidden_states[:, :-1, :]
accumulated_hidden
=
true_hidden
[:,
:
-
1
,
:]
accumulated_hidden
=
unpert_last_hidden
[:,
:
-
1
,
:]
accumulated_hidden
=
torch
.
sum
(
accumulated_hidden
,
dim
=
1
)
pert
urbed
_past
,
_
,
grad_norms
,
loss_
per
_iter
=
perturb_past
(
pert_past
,
_
,
grad_norms
,
loss_
this
_iter
=
perturb_past
(
past
,
model
,
prev
,
last
,
unpert_past
=
unpert_past
,
unpert_logits
=
unpert_logits
,
accumulated_hidden
=
accumulated_hidden
,
...
...
@@ -540,68 +530,59 @@ def generate_text_pplm(
decay
=
decay
,
gamma
=
gamma
,
)
loss_in_time
.
append
(
loss_
per
_iter
)
loss_in_time
.
append
(
loss_
this
_iter
)
# Piero modified model call
logits
,
past
,
pert_all_hidden
=
model
(
prev
,
past
=
perturbed_past
)
# test_logits = F.softmax(test_logits[:, -1, :], dim=-1)
# likelywords = torch.topk(test_logits, k=10, dim=-1)
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
pert_logits
,
past
,
pert_all_hidden
=
model
(
last
,
past
=
pert_past
)
pert_logits
=
pert_logits
[:,
-
1
,
:]
/
temperature
# + SMALL_CONST
pert_probs
=
F
.
softmax
(
pert_logits
,
dim
=-
1
)
if
classifier
is
not
None
:
ce_loss
=
torch
.
nn
.
CrossEntropyLoss
()
predict
ed_sentiment
=
classifier
(
torch
.
mean
(
true
_hidden
,
dim
=
1
))
predict
ion
=
classifier
(
torch
.
mean
(
unpert_last
_hidden
,
dim
=
1
))
label
=
torch
.
tensor
([
label_class
],
device
=
'cuda'
,
dtype
=
torch
.
long
)
true_discrim_loss
=
ce_loss
(
predicted_sentiment
,
label
)
print
(
"true discrim loss"
,
true_discrim_loss
.
data
.
cpu
().
numpy
())
unpert_discrim_loss
=
ce_loss
(
prediction
,
label
)
print
(
"unperturbed discrim loss"
,
unpert_discrim_loss
.
data
.
cpu
().
numpy
()
)
else
:
true_discrim_loss
=
0
# Piero modified model call
# hidden = model.hidden_states # update hidden
# logits = model.forward_hidden(hidden)
logits
=
logits
[:,
-
1
,
:]
/
temperature
# + SMALL_CONST
# logits = top_k_filter(logits, k=args.top_k) # + SMALL_CONST
log_probs
=
F
.
softmax
(
logits
,
dim
=-
1
)
unpert_discrim_loss
=
0
# Fuse the modified model and original model
if
perturb
:
# original_probs = top_k_filter(original_probs[:, -1, :]) #+ SMALL_CONST
unpert_logits
=
F
.
softmax
(
unpert_logits
[:,
-
1
,
:],
dim
=-
1
)
# likelywords = torch.topk(original_probs, k=10, dim=-1)
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
unpert_probs
=
F
.
softmax
(
unpert_logits
[:,
-
1
,
:],
dim
=-
1
)
log_probs
=
((
log_probs
**
gm_scale
)
*
(
unpert_logits
**
(
1
-
gm_scale
)))
# + SMALL_CONST
log_probs
=
top_k_filter
(
log_probs
,
k
=
top_k
,
pert_probs
=
((
pert_probs
**
gm_scale
)
*
(
unpert_probs
**
(
1
-
gm_scale
)))
# + SMALL_CONST
pert_probs
=
top_k_filter
(
pert_probs
,
k
=
top_k
,
probs
=
True
)
# + SMALL_CONST
if
torch
.
sum
(
log_probs
)
<=
1
:
log_probs
=
log_probs
/
torch
.
sum
(
log_probs
)
# rescale
if
torch
.
sum
(
pert_probs
)
<=
1
:
pert_probs
=
pert_probs
/
torch
.
sum
(
pert_probs
)
else
:
logits
=
top_k_filter
(
logits
,
k
=
top_k
)
# + SMALL_CONST
log
_probs
=
F
.
softmax
(
logits
,
dim
=-
1
)
pert_
logits
=
top_k_filter
(
pert_
logits
,
k
=
top_k
)
# + SMALL_CONST
pert
_probs
=
F
.
softmax
(
pert_
logits
,
dim
=-
1
)
# sample or greedy
if
sample
:
# likelywords = torch.topk(log_probs, k=args.top_k, dim=-1)
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
# print(likelywords[0].tolist())
prev
=
torch
.
multinomial
(
log_probs
,
num_samples
=
1
)
last
=
torch
.
multinomial
(
pert_probs
,
num_samples
=
1
)
else
:
_
,
prev
=
torch
.
topk
(
log_probs
,
k
=
1
,
dim
=-
1
)
# if perturb:
# prev = future
output
=
prev
if
output
is
None
else
torch
.
cat
((
output
,
prev
),
dim
=
1
)
# update output
print
(
TOKENIZER
.
decode
(
output
.
tolist
()[
0
]))
return
output
,
true_discrim_loss
,
loss_in_time
_
,
last
=
torch
.
topk
(
pert_probs
,
k
=
1
,
dim
=-
1
)
# update context/output_so_far appending the new token
output_so_far
=
(
last
if
output_so_far
is
None
else
torch
.
cat
((
output_so_far
,
last
),
dim
=
1
)
)
print
(
TOKENIZER
.
decode
(
output_so_far
.
tolist
()[
0
]))
return
output_so_far
,
unpert_discrim_loss
,
loss_in_time
def
run_model
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment