transformers · Commit a59fdd16

generate_text_pplm now works with batch_size > 1

Authored Dec 01, 2019 by Piero Molino
Committed by Julien Chaumond, Dec 03, 2019
Parent: 893d0d64
1 changed file with 8 additions and 6 deletions: examples/run_pplm.py (+8, -6)

--- a/examples/run_pplm.py
+++ b/examples/run_pplm.py
@@ -231,7 +231,8 @@ def perturb_past(
             prediction = classifier(new_accumulated_hidden /
                                     (curr_length + 1 + horizon_length))

-            label = torch.tensor([class_label], device=device,
+            label = torch.tensor(prediction.shape[0] * [class_label],
+                                 device=device,
                                  dtype=torch.long)
             discrim_loss = ce_loss(prediction, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
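
With batch_size > 1 the classifier returns one row of logits per sequence, so prediction has shape (batch_size, n_classes), and torch.nn.CrossEntropyLoss requires one target per row; the old single-element label tensor only matched a batch of one. A minimal sketch of the list-replication idiom this hunk introduces (not part of the commit; the shapes and values are illustrative):

import torch

batch_size, n_classes = 4, 5
class_label = 2

# Stand-in for classifier(new_accumulated_hidden / ...): one logit row per sequence.
prediction = torch.randn(batch_size, n_classes)

# Old behaviour: torch.tensor([class_label]) has shape (1,), and
# CrossEntropyLoss raises a batch-size mismatch whenever batch_size > 1.

# New behaviour: replicate the class label once per batch row.
label = torch.tensor(prediction.shape[0] * [class_label], dtype=torch.long)
print(label)        # tensor([2, 2, 2, 2])
print(label.shape)  # torch.Size([4])

ce_loss = torch.nn.CrossEntropyLoss()
discrim_loss = ce_loss(prediction, label)  # now works for any batch size
print(discrim_loss.item())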
@@ -508,11 +509,12 @@ def generate_text_pplm(
         gm_scale=0.9,
         kl_scale=0.01,
 ):
-    output_so_far = (
-        torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
-        if context
-        else None
-    )
+    output_so_far = None
+    if context:
+        context_t = torch.tensor(context, device=device, dtype=torch.long)
+        while len(context_t.shape) < 2:
+            context_t = context_t.unsqueeze(0)
+        output_so_far = context_t

     # collect one hot vectors for bags of words
     one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
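
This hunk normalises the context to a batched 2-D (batch, sequence) tensor: a plain token list gains a batch dimension, while an already-batched list of lists passes through unchanged. The old unconditional .unsqueeze(0) would have turned a 2-D batch into a spurious rank-3 (1, batch, sequence) tensor. A minimal sketch of the same logic (not from the commit; the helper name and token ids are hypothetical):

import torch

def to_batched(context, device="cpu"):
    # Hypothetical helper mirroring the new generate_text_pplm logic:
    # unsqueeze until the tensor has a (batch, sequence) shape.
    context_t = torch.tensor(context, device=device, dtype=torch.long)
    while len(context_t.shape) < 2:
        context_t = context_t.unsqueeze(0)
    return context_t

print(to_batched([5, 7, 11]).shape)          # torch.Size([1, 3])
print(to_batched([[5, 7], [11, 13]]).shape)  # torch.Size([2, 2]) -- left as-is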