chenpangpang / transformers · Commits

Commit afc7dcd9
Authored Nov 27, 2019 by piero
Committed by Julien Chaumond, Dec 03, 2019

    Now run_pplm works on cpu. Identical output as before (when using gpu).

Parent: 61399e5a

Changes: 1 changed file with 16 additions and 12 deletions

examples/run_pplm.py  +16  −12
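The change threads an explicit device argument through the helpers in run_pplm.py so that tensors are created on, or moved to, the requested device instead of being hard-wired to CUDA. A minimal sketch of the device selection this enables (illustrative only, not a line from the commit):

    import torch

    # Prefer CUDA when present; otherwise run everything on CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'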
diff --git a/examples/run_pplm.py b/examples/run_pplm.py
--- a/examples/run_pplm.py
+++ b/examples/run_pplm.py
@@ -84,9 +84,11 @@ DISCRIMINATOR_MODELS_PARAMS = {
 }
 
-def to_var(x, requires_grad=False, volatile=False):
-    if torch.cuda.is_available():
+def to_var(x, requires_grad=False, volatile=False, device='cuda'):
+    if torch.cuda.is_available() and device == 'cuda':
         x = x.cuda()
+    elif device != 'cuda':
+        x = x.to(device)
     return Variable(x, requires_grad=requires_grad, volatile=volatile)
@@ -182,7 +184,7 @@ def perturb_past(
     for i in range(num_iterations):
         print("Iteration ", i + 1)
         curr_perturbation = [
-            to_var(torch.from_numpy(p_), requires_grad=True)
+            to_var(torch.from_numpy(p_), requires_grad=True, device=device)
             for p_ in grad_accumulator
         ]
@@ -290,7 +292,7 @@ def perturb_past(
     # apply the accumulated perturbations to the past
     grad_accumulator = [
-        to_var(torch.from_numpy(p_), requires_grad=True)
+        to_var(torch.from_numpy(p_), requires_grad=True, device=device)
         for p_ in grad_accumulator
     ]
     pert_past = list(map(add, past, grad_accumulator))
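torch.from_numpy always produces a CPU tensor (sharing memory with the NumPy array), which is why the device has to be threaded through to_var explicitly rather than inferred from the input. For example:

    import numpy as np
    import torch

    p_ = np.zeros((2, 3), dtype=np.float32)
    t = torch.from_numpy(p_)   # always a CPU tensor, regardless of hardware
    t = t.to('cpu')            # no-op on CPU runs; .to('cuda') would copy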
@@ -300,7 +302,7 @@ def perturb_past(
 def get_classifier(
         name: Optional[str],
         label_class: Union[str, int],
-        device: Union[str, torch.device]
+        device: str
 ) -> Tuple[Optional[ClassificationHead], Optional[int]]:
     if name is None:
         return None, None
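Narrowing the annotation from Union[str, torch.device] to str matches how the new code actually uses the argument: plain string comparisons like device == 'cuda'. A caller holding a torch.device could normalize first (illustrative, not from the commit):

    import torch

    dev = torch.device('cuda:0')
    device = dev.type   # 'cuda' — a plain string, safe for the equality checks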
@@ -355,16 +357,16 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
     return bow_indices
 
 
-def build_bows_one_hot_vectors(bow_indices):
+def build_bows_one_hot_vectors(bow_indices, device='cuda'):
     if bow_indices is None:
         return None
 
     one_hot_bows_vectors = []
     for single_bow in bow_indices:
         single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
-        single_bow = torch.tensor(single_bow).cuda()
+        single_bow = torch.tensor(single_bow).to(device)
         num_words = single_bow.shape[0]
-        one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).cuda()
+        one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).to(device)
         one_hot_bow.scatter_(1, single_bow, 1)
         one_hot_bows_vectors.append(one_hot_bow)
     return one_hot_bows_vectors
@@ -425,7 +427,8 @@ def full_text_generation(
         length=length,
         perturb=False
     )
-    torch.cuda.empty_cache()
+    if device == 'cuda':
+        torch.cuda.empty_cache()
 
     pert_gen_tok_texts = []
     discrim_losses = []
@@ -460,7 +463,8 @@ def full_text_generation(
             discrim_losses.append(discrim_loss.data.cpu().numpy())
         losses_in_time.append(loss_in_time)
 
-    torch.cuda.empty_cache()
+    if device == 'cuda':
+        torch.cuda.empty_cache()
 
     return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
@@ -496,7 +500,7 @@ def generate_text_pplm(
     )
 
     # collect one hot vectors for bags of words
-    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
+    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, device)
 
     grad_norms = None
     last = None
@@ -563,7 +567,7 @@ def generate_text_pplm(
     if classifier is not None:
         ce_loss = torch.nn.CrossEntropyLoss()
         prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
-        label = torch.tensor([label_class], device='cuda',
+        label = torch.tensor([label_class], device=device,
                              dtype=torch.long)
         unpert_discrim_loss = ce_loss(prediction, label)
         print(
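Taken together, the edits follow one pattern: create tensors with device=device, move data with .to(device), and touch the CUDA allocator only when actually running on CUDA. A condensed, standalone sketch of that pattern (values assumed for illustration):

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    label = torch.tensor([0], device=device, dtype=torch.long)  # was device='cuda'
    if device == 'cuda':
        torch.cuda.empty_cache()  # skip the CUDA allocator entirely on CPU runs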