Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
afc7dcd9
Commit
afc7dcd9
authored
Nov 27, 2019
by
piero
Committed by
Julien Chaumond
Dec 03, 2019
Browse files
Now run_pplm works on cpu. Identical output as before (when using gpu).
parent
61399e5a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
12 deletions
+16
-12
examples/run_pplm.py
examples/run_pplm.py
+16
-12
No files found.
examples/run_pplm.py
View file @
afc7dcd9
...
...
@@ -84,9 +84,11 @@ DISCRIMINATOR_MODELS_PARAMS = {
}
def
to_var
(
x
,
requires_grad
=
False
,
volatile
=
False
):
if
torch
.
cuda
.
is_available
():
def
to_var
(
x
,
requires_grad
=
False
,
volatile
=
False
,
device
=
'cuda'
):
if
torch
.
cuda
.
is_available
()
and
device
==
'cuda'
:
x
=
x
.
cuda
()
elif
device
!=
'cuda'
:
x
=
x
.
to
(
device
)
return
Variable
(
x
,
requires_grad
=
requires_grad
,
volatile
=
volatile
)
...
...
@@ -182,7 +184,7 @@ def perturb_past(
for
i
in
range
(
num_iterations
):
print
(
"Iteration "
,
i
+
1
)
curr_perturbation
=
[
to_var
(
torch
.
from_numpy
(
p_
),
requires_grad
=
True
)
to_var
(
torch
.
from_numpy
(
p_
),
requires_grad
=
True
,
device
=
device
)
for
p_
in
grad_accumulator
]
...
...
@@ -290,7 +292,7 @@ def perturb_past(
# apply the accumulated perturbations to the past
grad_accumulator
=
[
to_var
(
torch
.
from_numpy
(
p_
),
requires_grad
=
True
)
to_var
(
torch
.
from_numpy
(
p_
),
requires_grad
=
True
,
device
=
device
)
for
p_
in
grad_accumulator
]
pert_past
=
list
(
map
(
add
,
past
,
grad_accumulator
))
...
...
@@ -300,7 +302,7 @@ def perturb_past(
def
get_classifier
(
name
:
Optional
[
str
],
label_class
:
Union
[
str
,
int
],
device
:
Union
[
str
,
torch
.
device
]
device
:
str
)
->
Tuple
[
Optional
[
ClassificationHead
],
Optional
[
int
]]:
if
name
is
None
:
return
None
,
None
...
...
@@ -355,16 +357,16 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
return
bow_indices
def
build_bows_one_hot_vectors
(
bow_indices
):
def
build_bows_one_hot_vectors
(
bow_indices
,
device
=
'cuda'
):
if
bow_indices
is
None
:
return
None
one_hot_bows_vectors
=
[]
for
single_bow
in
bow_indices
:
single_bow
=
list
(
filter
(
lambda
x
:
len
(
x
)
<=
1
,
single_bow
))
single_bow
=
torch
.
tensor
(
single_bow
).
cuda
(
)
single_bow
=
torch
.
tensor
(
single_bow
).
to
(
device
)
num_words
=
single_bow
.
shape
[
0
]
one_hot_bow
=
torch
.
zeros
(
num_words
,
TOKENIZER
.
vocab_size
).
cuda
(
)
one_hot_bow
=
torch
.
zeros
(
num_words
,
TOKENIZER
.
vocab_size
).
to
(
device
)
one_hot_bow
.
scatter_
(
1
,
single_bow
,
1
)
one_hot_bows_vectors
.
append
(
one_hot_bow
)
return
one_hot_bows_vectors
...
...
@@ -425,7 +427,8 @@ def full_text_generation(
length
=
length
,
perturb
=
False
)
torch
.
cuda
.
empty_cache
()
if
device
==
'cuda'
:
torch
.
cuda
.
empty_cache
()
pert_gen_tok_texts
=
[]
discrim_losses
=
[]
...
...
@@ -460,7 +463,8 @@ def full_text_generation(
discrim_losses
.
append
(
discrim_loss
.
data
.
cpu
().
numpy
())
losses_in_time
.
append
(
loss_in_time
)
torch
.
cuda
.
empty_cache
()
if
device
==
'cuda'
:
torch
.
cuda
.
empty_cache
()
return
unpert_gen_tok_text
,
pert_gen_tok_texts
,
discrim_losses
,
losses_in_time
...
...
@@ -496,7 +500,7 @@ def generate_text_pplm(
)
# collect one hot vectors for bags of words
one_hot_bows_vectors
=
build_bows_one_hot_vectors
(
bow_indices
)
one_hot_bows_vectors
=
build_bows_one_hot_vectors
(
bow_indices
,
device
)
grad_norms
=
None
last
=
None
...
...
@@ -563,7 +567,7 @@ def generate_text_pplm(
if
classifier
is
not
None
:
ce_loss
=
torch
.
nn
.
CrossEntropyLoss
()
prediction
=
classifier
(
torch
.
mean
(
unpert_last_hidden
,
dim
=
1
))
label
=
torch
.
tensor
([
label_class
],
device
=
'cuda'
,
label
=
torch
.
tensor
([
label_class
],
device
=
device
,
dtype
=
torch
.
long
)
unpert_discrim_loss
=
ce_loss
(
prediction
,
label
)
print
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment