chenpangpang / transformers · Commits · Commit 9f693a0c

Authored Nov 27, 2019 by piero
Committed by Julien Chaumond on Dec 03, 2019

Cleaned generate_text_pplm. Identical output as before.
Parent: 61a12f79

Showing 1 changed file with 53 additions and 72 deletions:

examples/run_pplm.py  (+53, −72)
examples/run_pplm.py @ 9f693a0c
```diff
...
@@ -471,59 +471,49 @@ def generate_text_pplm(
     decay=False,
     gamma=1.5,
 ):
-    output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0) if context else None
+    output_so_far = (
+        torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
+        if context
+        else None
+    )

     # collect one hot vectors for bags of words
     one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)

     grad_norms = None
     unpert_discrim_loss = 0
     loss_in_time = []
     for i in trange(length, ascii=True):

         # Get past/probs for current output, except for last word
         # Note that GPT takes 2 inputs: past + current-token
         # Therefore, use everything from before current i/p token to generate relevant past
-        if past is None and output is not None:
-            prev = output[:, -1:]
-            # _, past = model(output[:, :-1])
-            # original_probs, true_past = model(output)
-            # true_hidden = model.hidden_states
-            # Piero modified model call
-            _, past, _ = model(output[:, :-1])
-
-            unpert_logits, unpert_past, unpert_all_hidden = model(output)
-            true_hidden = unpert_all_hidden[-1]
-        else:
-            # original_probs, true_past = model(output)
-            # true_hidden = model.hidden_states
-            # Note that GPT takes 2 inputs: past + current_token
-            # Piero modified model call
-            unpert_logits, unpert_past, unpert_all_hidden = model(output)
-            true_hidden = unpert_all_hidden[-1]
+        # run model forward to obtain unperturbed
+        if past is None and output_so_far is not None:
+            last = output_so_far[:, -1:]
+            _, past, _ = model(output_so_far[:, :-1])
+
+        unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
+        unpert_last_hidden = unpert_all_hidden[-1]

-        # Modify the past if necessary
+        # check if we are abowe grad max length
         if i >= grad_length:
             current_stepsize = stepsize * 0
         else:
             current_stepsize = stepsize

+        # modify the past if necessary
         if not perturb or num_iterations == 0:
-            perturbed_past = past
+            pert_past = past
         else:
-            # Piero modified model call
-            # accumulated_hidden = model.hidden_states[:, :-1, :]
-            accumulated_hidden = true_hidden[:, :-1, :]
+            accumulated_hidden = unpert_last_hidden[:, :-1, :]
             accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

-            perturbed_past, _, grad_norms, loss_per_iter = perturb_past(
+            pert_past, _, grad_norms, loss_this_iter = perturb_past(
                 past,
                 model,
-                prev,
+                last,
                 unpert_past=unpert_past,
                 unpert_logits=unpert_logits,
                 accumulated_hidden=accumulated_hidden,
...
```
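The comments in this hunk describe the caching pattern the cleanup preserves: GPT-2 takes two inputs, a cached `past` and the current token, so the code feeds everything before the current input token once to build `past`, then steps one token at a time. Below is a minimal, self-contained sketch of that pattern. It uses the present-day `transformers` API (`use_cache`, `past_key_values`) rather than the 2019 tuple-returning call signature seen in the diff, and the prompt text is made up.

```python
# Minimal sketch of GPT-2 incremental decoding with a cached `past`,
# the pattern the hunk's comments describe. Uses today's `transformers`
# API (`past_key_values`), not the 2019 tuple-returning model call.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

context = tokenizer.encode("The potato", return_tensors="pt")

with torch.no_grad():
    # Everything before the current token builds the cache once ...
    out = model(context[:, :-1], use_cache=True)
    past = out.past_key_values
    last = context[:, -1:]
    for _ in range(10):
        # ... so each step only feeds the newest token plus the cache.
        out = model(last, past_key_values=past, use_cache=True)
        past = out.past_key_values
        last = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
        context = torch.cat((context, last), dim=1)

print(tokenizer.decode(context[0].tolist()))
```

The cache-building call corresponds to the diff's `_, past, _ = model(output_so_far[:, :-1])`; the per-step call corresponds to `model(last, past=pert_past)` in the second hunk below.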
```diff
@@ -540,68 +530,59 @@ def generate_text_pplm(
                 decay=decay,
                 gamma=gamma,
             )
-            loss_in_time.append(loss_per_iter)
+            loss_in_time.append(loss_this_iter)

-        # Piero modified model call
-        logits, past, pert_all_hidden = model(prev, past=perturbed_past)
-        # test_logits = F.softmax(test_logits[:, -1, :], dim=-1)
-        # likelywords = torch.topk(test_logits, k=10, dim=-1)
-        # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
+        pert_logits, past, pert_all_hidden = model(last, past=pert_past)
+        pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST
+        pert_probs = F.softmax(pert_logits, dim=-1)

         if classifier is not None:
             ce_loss = torch.nn.CrossEntropyLoss()
-            predicted_sentiment = classifier(torch.mean(true_hidden, dim=1))
+            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
             label = torch.tensor([label_class], device='cuda', dtype=torch.long)
-            true_discrim_loss = ce_loss(predicted_sentiment, label)
-            print("true discrim loss", true_discrim_loss.data.cpu().numpy())
+            unpert_discrim_loss = ce_loss(prediction, label)
+            print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
         else:
-            true_discrim_loss = 0
-
-        # Piero modified model call
-        # hidden = model.hidden_states  # update hidden
-        # logits = model.forward_hidden(hidden)
-        logits = logits[:, -1, :] / temperature  # + SMALL_CONST
-        # logits = top_k_filter(logits, k=args.top_k)  # + SMALL_CONST
-        log_probs = F.softmax(logits, dim=-1)
+            unpert_discrim_loss = 0

         # Fuse the modified model and original model
         if perturb:
-            # original_probs = top_k_filter(original_probs[:, -1, :])  #+ SMALL_CONST
-            unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
-            # likelywords = torch.topk(original_probs, k=10, dim=-1)
-            # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
+            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

-            log_probs = ((log_probs ** gm_scale) * (unpert_logits ** (1 - gm_scale)))  # + SMALL_CONST
-            log_probs = top_k_filter(log_probs, k=top_k, probs=True)  # + SMALL_CONST
+            pert_probs = ((pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)))  # + SMALL_CONST
+            pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)  # + SMALL_CONST

-            if torch.sum(log_probs) <= 1:
-                log_probs = log_probs / torch.sum(log_probs)
+            # rescale
+            if torch.sum(pert_probs) <= 1:
+                pert_probs = pert_probs / torch.sum(pert_probs)

         else:
-            logits = top_k_filter(logits, k=top_k)  # + SMALL_CONST
-            log_probs = F.softmax(logits, dim=-1)
+            pert_logits = top_k_filter(pert_logits, k=top_k)  # + SMALL_CONST
+            pert_probs = F.softmax(pert_logits, dim=-1)

         # sample or greedy
         if sample:
-            # likelywords = torch.topk(log_probs, k=args.top_k, dim=-1)
-            # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
-            # print(likelywords[0].tolist())
-            prev = torch.multinomial(log_probs, num_samples=1)
+            last = torch.multinomial(pert_probs, num_samples=1)
         else:
-            _, prev = torch.topk(log_probs, k=1, dim=-1)
-        # if perturb:
-        #     prev = future
-        output = prev if output is None else torch.cat((output, prev), dim=1)  # update output
-        print(TOKENIZER.decode(output.tolist()[0]))
-
-    return output, true_discrim_loss, loss_in_time
+            _, last = torch.topk(pert_probs, k=1, dim=-1)
+
+        # update context/output_so_far appending the new token
+        output_so_far = (
+            last if output_so_far is None
+            else torch.cat((output_so_far, last), dim=1)
+        )
+
+        print(TOKENIZER.decode(output_so_far.tolist()[0]))
+
+    return output_so_far, unpert_discrim_loss, loss_in_time


 def run_model():
...
```
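The `if perturb:` branch above fuses the perturbed and unperturbed next-token distributions by a weighted geometric mean, `pert_probs ** gm_scale * unpert_probs ** (1 - gm_scale)`, then renormalizes, since the product of two distributions raised to fractional powers no longer sums to 1. Here is a small numeric sketch of just that fusion step; the toy distributions and vocabulary size are made up, while the operations and the `gm_scale` name come from the diff.

```python
# Sketch of the geometric-mean fusion from the `if perturb:` branch:
# pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)),
# followed by the rescale guard. Toy 5-token vocabulary, made-up numbers.
import torch

gm_scale = 0.9  # weight on the perturbed distribution, as in run_pplm.py

pert_probs = torch.tensor([[0.50, 0.20, 0.15, 0.10, 0.05]])
unpert_probs = torch.tensor([[0.10, 0.40, 0.30, 0.15, 0.05]])

# Element-wise geometric interpolation; the result is unnormalized.
fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))

# The diff's rescale guard: after the fusion (and top-k masking in the
# real code) the row no longer sums to 1, so renormalize before sampling.
if torch.sum(fused) <= 1:
    fused = fused / torch.sum(fused)

last = torch.multinomial(fused, num_samples=1)
print(fused, last)
```

With `gm_scale` close to 1 the sample follows the perturbed model almost entirely; closer to 0 it falls back to the unperturbed GPT-2 distribution, which is how PPLM trades attribute control against fluency.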
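Both branches also call `top_k_filter`, which is defined elsewhere in examples/run_pplm.py and is not part of this commit. The sketch below is an assumed reconstruction consistent with the two call sites in the diff (raw logits in the greedy branch, `probs=True` after the softmax in the fused branch), not the file's verbatim definition.

```python
import torch

def top_k_filter(logits, k, probs=False):
    """Keep only the k largest entries per row.

    Assumed reconstruction of the helper the diff calls; not the verbatim
    definition from examples/run_pplm.py. With probs=True the input is a
    probability row (post-softmax), so masked entries become 0.0;
    otherwise they become a large negative logit.
    """
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    # Smallest surviving value per row, broadcast to the full row.
    batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
    if probs:
        return torch.where(logits < batch_mins, torch.zeros_like(logits), logits)
    return torch.where(logits < batch_mins, torch.full_like(logits, -1e10), logits)

# Example: keep the top 2 of 5 probabilities, then renormalize as the diff does.
p = torch.tensor([[0.05, 0.40, 0.10, 0.30, 0.15]])
p = top_k_filter(p, k=2, probs=True)
p = p / torch.sum(p)
print(p)  # tensor([[0.0000, 0.5714, 0.0000, 0.4286, 0.0000]]) approximately
```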