Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7469d03b
Commit
7469d03b
authored
Nov 26, 2019
by
w4nderlust
Committed by
Julien Chaumond
Dec 03, 2019
Browse files
Fixed a minor bug when running training on CUDA
parent
0b51fba2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
27 deletions
+30
-27
examples/run_pplm_discrim_train.py
examples/run_pplm_discrim_train.py
+30
-27
No files found.
examples/run_pplm_discrim_train.py
View file @
7469d03b
...
@@ -18,6 +18,7 @@ import torch.utils.data as data
...
@@ -18,6 +18,7 @@ import torch.utils.data as data
from
nltk.tokenize.treebank
import
TreebankWordDetokenizer
from
nltk.tokenize.treebank
import
TreebankWordDetokenizer
from
torchtext
import
data
as
torchtext_data
from
torchtext
import
data
as
torchtext_data
from
torchtext
import
datasets
from
torchtext
import
datasets
from
transformers
import
GPT2Tokenizer
,
GPT2LMHeadModel
from
transformers
import
GPT2Tokenizer
,
GPT2LMHeadModel
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
...
@@ -89,7 +90,7 @@ class Discriminator(torch.nn.Module):
...
@@ -89,7 +90,7 @@ class Discriminator(torch.nn.Module):
if
self
.
cached_mode
:
if
self
.
cached_mode
:
avg_hidden
=
x
.
to
(
device
)
avg_hidden
=
x
.
to
(
device
)
else
:
else
:
avg_hidden
=
self
.
avg_representation
(
x
)
avg_hidden
=
self
.
avg_representation
(
x
.
to
(
device
)
)
logits
=
self
.
classifier_head
(
avg_hidden
)
logits
=
self
.
classifier_head
(
avg_hidden
)
probs
=
F
.
log_softmax
(
logits
,
dim
=-
1
)
probs
=
F
.
log_softmax
(
logits
,
dim
=-
1
)
...
@@ -203,7 +204,7 @@ def evaluate_performance(data_loader, discriminator):
...
@@ -203,7 +204,7 @@ def evaluate_performance(data_loader, discriminator):
def
predict
(
input_sentence
,
model
,
classes
,
cached
=
False
):
def
predict
(
input_sentence
,
model
,
classes
,
cached
=
False
):
input_t
=
model
.
tokenizer
.
encode
(
input_sentence
)
input_t
=
model
.
tokenizer
.
encode
(
input_sentence
)
input_t
=
torch
.
tensor
([
input_t
],
dtype
=
torch
.
long
)
input_t
=
torch
.
tensor
([
input_t
],
dtype
=
torch
.
long
,
device
=
device
)
if
cached
:
if
cached
:
input_t
=
model
.
avg_representation
(
input_t
)
input_t
=
model
.
avg_representation
(
input_t
)
...
@@ -428,7 +429,8 @@ def train_discriminator(
...
@@ -428,7 +429,8 @@ def train_discriminator(
with
open
(
dataset_fp
)
as
f
:
with
open
(
dataset_fp
)
as
f
:
csv_reader
=
csv
.
reader
(
f
,
delimiter
=
'
\t
'
)
csv_reader
=
csv
.
reader
(
f
,
delimiter
=
'
\t
'
)
for
row
in
csv_reader
:
for
row
in
csv_reader
:
classes
.
add
(
row
[
0
])
if
row
:
classes
.
add
(
row
[
0
])
idx2class
=
sorted
(
classes
)
idx2class
=
sorted
(
classes
)
class2idx
=
{
c
:
i
for
i
,
c
in
enumerate
(
idx2class
)}
class2idx
=
{
c
:
i
for
i
,
c
in
enumerate
(
idx2class
)}
...
@@ -444,30 +446,31 @@ def train_discriminator(
...
@@ -444,30 +446,31 @@ def train_discriminator(
with
open
(
dataset_fp
)
as
f
:
with
open
(
dataset_fp
)
as
f
:
csv_reader
=
csv
.
reader
(
f
,
delimiter
=
'
\t
'
)
csv_reader
=
csv
.
reader
(
f
,
delimiter
=
'
\t
'
)
for
i
,
row
in
enumerate
(
csv_reader
):
for
i
,
row
in
enumerate
(
csv_reader
):
label
=
row
[
0
]
if
row
:
text
=
row
[
1
]
label
=
row
[
0
]
text
=
row
[
1
]
try
:
seq
=
discriminator
.
tokenizer
.
encode
(
text
)
try
:
if
(
len
(
seq
)
<
max_length_seq
):
seq
=
discriminator
.
tokenizer
.
encode
(
text
)
seq
=
torch
.
tensor
(
if
(
len
(
seq
)
<
max_length_seq
):
[
50256
]
+
seq
,
seq
=
torch
.
tensor
(
device
=
device
,
[
50256
]
+
seq
,
dtype
=
torch
.
long
device
=
device
,
)
dtype
=
torch
.
long
)
else
:
print
(
"Line {} is longer than maximum length {}"
.
format
(
else
:
i
,
max_length_seq
print
(
"Line {} is longer than maximum length {}"
.
format
(
))
i
,
max_length_seq
continue
))
continue
x
.
append
(
seq
)
y
.
append
(
class2idx
[
label
])
x
.
append
(
seq
)
y
.
append
(
class2idx
[
label
])
except
:
print
(
"Error tokenizing line {}, skipping it"
.
format
(
i
))
except
:
pass
print
(
"Error tokenizing line {}, skipping it"
.
format
(
i
))
pass
full_dataset
=
Dataset
(
x
,
y
)
full_dataset
=
Dataset
(
x
,
y
)
train_size
=
int
(
0.9
*
len
(
full_dataset
))
train_size
=
int
(
0.9
*
len
(
full_dataset
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment