Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
61a12f79
Commit
61a12f79
authored
Nov 27, 2019
by
piero
Committed by
Julien Chaumond
Dec 03, 2019
Browse files
Renamed SmallConst to SMALL_CONST and introduced BIG_CONST. Identical output as before.
parent
ef47b2c0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
12 deletions
+13
-12
examples/run_pplm.py
examples/run_pplm.py
+13
-12
No files found.
examples/run_pplm.py
View file @
61a12f79
...
...
@@ -43,7 +43,7 @@ PPLM_BOW = 1
PPLM_DISCRIM
=
2
PPLM_BOW_DISCRIM
=
3
SMALL_CONST
=
1e-15
SmallConst
=
1e
-15
BIG_CONST
=
1e
10
TOKENIZER
=
GPT2Tokenizer
.
from_pretrained
(
"gpt2-medium"
)
BAG_OF_WORDS_ARCHIVE_MAP
=
{
...
...
@@ -104,7 +104,8 @@ def top_k_filter(logits, k, probs=False):
if
probs
:
return
torch
.
where
(
logits
<
batch_mins
,
torch
.
ones_like
(
logits
)
*
0.0
,
logits
)
return
torch
.
where
(
logits
<
batch_mins
,
torch
.
ones_like
(
logits
)
*
-
1e10
,
return
torch
.
where
(
logits
<
batch_mins
,
torch
.
ones_like
(
logits
)
*
-
BIG_CONST
,
logits
)
...
...
@@ -137,7 +138,7 @@ def perturb_past(
accumulated_hidden
=
0
if
decay
:
decay_mask
=
torch
.
arange
(
0.
,
1.0
+
S
mallConst
,
1.0
/
(
window_length
))[
decay_mask
=
torch
.
arange
(
0.
,
1.0
+
S
MALL_CONST
,
1.0
/
(
window_length
))[
1
:]
else
:
decay_mask
=
1.0
...
...
@@ -233,9 +234,9 @@ def perturb_past(
kl_loss
=
0.0
if
kl_scale
>
0.0
:
p
=
(
F
.
softmax
(
unpert_logits
[:,
-
1
,
:],
dim
=-
1
))
p
=
p
+
S
mallConst
*
(
p
<=
SmallConst
).
type
(
p
=
p
+
S
MALL_CONST
*
(
p
<=
SMALL_CONST
).
type
(
torch
.
FloatTensor
).
cuda
().
detach
()
correction
=
S
mallConst
*
(
probabs
<=
S
mallConst
).
type
(
correction
=
S
MALL_CONST
*
(
probabs
<=
S
MALL_CONST
).
type
(
torch
.
FloatTensor
).
cuda
().
detach
()
corrected_probabs
=
probabs
+
correction
.
detach
()
kl_loss
=
kl_scale
*
(
...
...
@@ -254,7 +255,7 @@ def perturb_past(
for
index
,
p_
in
enumerate
(
past_perturb
)]
else
:
grad_norms
=
[(
torch
.
norm
(
p_
.
grad
*
window_mask
)
+
S
mallConst
)
for
grad_norms
=
[(
torch
.
norm
(
p_
.
grad
*
window_mask
)
+
S
MALL_CONST
)
for
index
,
p_
in
enumerate
(
past_perturb
)]
grad
=
[
...
...
@@ -560,31 +561,31 @@ def generate_text_pplm(
# Piero modified model call
# hidden = model.hidden_states # update hidden
# logits = model.forward_hidden(hidden)
logits
=
logits
[:,
-
1
,
:]
/
temperature
# + S
mallConst
logits
=
logits
[:,
-
1
,
:]
/
temperature
# + S
MALL_CONST
# logits = top_k_filter(logits, k=args.top_k) # + S
mallConst
# logits = top_k_filter(logits, k=args.top_k) # + S
MALL_CONST
log_probs
=
F
.
softmax
(
logits
,
dim
=-
1
)
# Fuse the modified model and original model
if
perturb
:
# original_probs = top_k_filter(original_probs[:, -1, :]) #+ S
mallConst
# original_probs = top_k_filter(original_probs[:, -1, :]) #+ S
MALL_CONST
unpert_logits
=
F
.
softmax
(
unpert_logits
[:,
-
1
,
:],
dim
=-
1
)
# likelywords = torch.topk(original_probs, k=10, dim=-1)
# print(TOKENIZER.decode(likelywords[1].tolist()[0]))
log_probs
=
((
log_probs
**
gm_scale
)
*
(
unpert_logits
**
(
1
-
gm_scale
)))
# + S
mallConst
unpert_logits
**
(
1
-
gm_scale
)))
# + S
MALL_CONST
log_probs
=
top_k_filter
(
log_probs
,
k
=
top_k
,
probs
=
True
)
# + S
mallConst
probs
=
True
)
# + S
MALL_CONST
if
torch
.
sum
(
log_probs
)
<=
1
:
log_probs
=
log_probs
/
torch
.
sum
(
log_probs
)
else
:
logits
=
top_k_filter
(
logits
,
k
=
top_k
)
# + S
mallConst
logits
=
top_k_filter
(
logits
,
k
=
top_k
)
# + S
MALL_CONST
log_probs
=
F
.
softmax
(
logits
,
dim
=-
1
)
if
sample
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment