Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
75904dae
"runner/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "e2252d0fc6ea5c410b1ac4fa0a722beda78b3431"
Commit
75904dae
authored
Nov 29, 2019
by
w4nderlust
Committed by
Julien Chaumond
Dec 03, 2019
Browse files
Removed global variable device
parent
7fd54b55
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
19 deletions
+28
-19
examples/run_pplm_discrim_train.py
examples/run_pplm_discrim_train.py
+28
-19
No files found.
examples/run_pplm_discrim_train.py
View file @
75904dae
...
@@ -25,7 +25,6 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
...
@@ -25,7 +25,6 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
np
.
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
EPSILON
=
1e-10
EPSILON
=
1e-10
device
=
"cpu"
example_sentence
=
"This is incredible! I love it, this is the best chicken I have ever had."
example_sentence
=
"This is incredible! I love it, this is the best chicken I have ever had."
max_length_seq
=
100
max_length_seq
=
100
...
@@ -55,7 +54,8 @@ class Discriminator(torch.nn.Module):
...
@@ -55,7 +54,8 @@ class Discriminator(torch.nn.Module):
self
,
self
,
class_size
,
class_size
,
pretrained_model
=
"gpt2-medium"
,
pretrained_model
=
"gpt2-medium"
,
cached_mode
=
False
cached_mode
=
False
,
device
=
'cpu'
):
):
super
(
Discriminator
,
self
).
__init__
()
super
(
Discriminator
,
self
).
__init__
()
self
.
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
pretrained_model
)
self
.
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
pretrained_model
)
...
@@ -66,6 +66,7 @@ class Discriminator(torch.nn.Module):
...
@@ -66,6 +66,7 @@ class Discriminator(torch.nn.Module):
embed_size
=
self
.
embed_size
embed_size
=
self
.
embed_size
)
)
self
.
cached_mode
=
cached_mode
self
.
cached_mode
=
cached_mode
self
.
device
=
device
def
get_classifier
(
self
):
def
get_classifier
(
self
):
return
self
.
classifier_head
return
self
.
classifier_head
...
@@ -78,7 +79,7 @@ class Discriminator(torch.nn.Module):
...
@@ -78,7 +79,7 @@ class Discriminator(torch.nn.Module):
def
avg_representation
(
self
,
x
):
def
avg_representation
(
self
,
x
):
mask
=
x
.
ne
(
0
).
unsqueeze
(
2
).
repeat
(
mask
=
x
.
ne
(
0
).
unsqueeze
(
2
).
repeat
(
1
,
1
,
self
.
embed_size
1
,
1
,
self
.
embed_size
).
float
().
to
(
device
).
detach
()
).
float
().
to
(
self
.
device
).
detach
()
hidden
,
_
=
self
.
encoder
.
transformer
(
x
)
hidden
,
_
=
self
.
encoder
.
transformer
(
x
)
masked_hidden
=
hidden
*
mask
masked_hidden
=
hidden
*
mask
avg_hidden
=
torch
.
sum
(
masked_hidden
,
dim
=
1
)
/
(
avg_hidden
=
torch
.
sum
(
masked_hidden
,
dim
=
1
)
/
(
...
@@ -88,9 +89,9 @@ class Discriminator(torch.nn.Module):
...
@@ -88,9 +89,9 @@ class Discriminator(torch.nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
if
self
.
cached_mode
:
if
self
.
cached_mode
:
avg_hidden
=
x
.
to
(
device
)
avg_hidden
=
x
.
to
(
self
.
device
)
else
:
else
:
avg_hidden
=
self
.
avg_representation
(
x
.
to
(
device
))
avg_hidden
=
self
.
avg_representation
(
x
.
to
(
self
.
device
))
logits
=
self
.
classifier_head
(
avg_hidden
)
logits
=
self
.
classifier_head
(
avg_hidden
)
probs
=
F
.
log_softmax
(
logits
,
dim
=-
1
)
probs
=
F
.
log_softmax
(
logits
,
dim
=-
1
)
...
@@ -152,7 +153,7 @@ def cached_collate_fn(data):
...
@@ -152,7 +153,7 @@ def cached_collate_fn(data):
def
train_epoch
(
data_loader
,
discriminator
,
optimizer
,
def
train_epoch
(
data_loader
,
discriminator
,
optimizer
,
epoch
=
0
,
log_interval
=
10
):
epoch
=
0
,
log_interval
=
10
,
device
=
'cpu'
):
samples_so_far
=
0
samples_so_far
=
0
discriminator
.
train_custom
()
discriminator
.
train_custom
()
for
batch_idx
,
(
input_t
,
target_t
)
in
enumerate
(
data_loader
):
for
batch_idx
,
(
input_t
,
target_t
)
in
enumerate
(
data_loader
):
...
@@ -177,7 +178,7 @@ def train_epoch(data_loader, discriminator, optimizer,
...
@@ -177,7 +178,7 @@ def train_epoch(data_loader, discriminator, optimizer,
)
)
def
evaluate_performance
(
data_loader
,
discriminator
):
def
evaluate_performance
(
data_loader
,
discriminator
,
device
=
'cpu'
):
discriminator
.
eval
()
discriminator
.
eval
()
test_loss
=
0
test_loss
=
0
correct
=
0
correct
=
0
...
@@ -202,7 +203,7 @@ def evaluate_performance(data_loader, discriminator):
...
@@ -202,7 +203,7 @@ def evaluate_performance(data_loader, discriminator):
)
)
def
predict
(
input_sentence
,
model
,
classes
,
cached
=
False
):
def
predict
(
input_sentence
,
model
,
classes
,
cached
=
False
,
device
=
'cpu'
):
input_t
=
model
.
tokenizer
.
encode
(
input_sentence
)
input_t
=
model
.
tokenizer
.
encode
(
input_sentence
)
input_t
=
torch
.
tensor
([
input_t
],
dtype
=
torch
.
long
,
device
=
device
)
input_t
=
torch
.
tensor
([
input_t
],
dtype
=
torch
.
long
,
device
=
device
)
if
cached
:
if
cached
:
...
@@ -216,7 +217,8 @@ def predict(input_sentence, model, classes, cached=False):
...
@@ -216,7 +217,8 @@ def predict(input_sentence, model, classes, cached=False):
))
))
def
get_cached_data_loader
(
dataset
,
batch_size
,
discriminator
,
shuffle
=
False
):
def
get_cached_data_loader
(
dataset
,
batch_size
,
discriminator
,
shuffle
=
False
,
device
=
'cpu'
):
data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
dataset
,
data_loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
=
dataset
,
batch_size
=
batch_size
,
batch_size
=
batch_size
,
collate_fn
=
collate_fn
)
collate_fn
=
collate_fn
)
...
@@ -244,7 +246,6 @@ def train_discriminator(
...
@@ -244,7 +246,6 @@ def train_discriminator(
dataset
,
dataset_fp
=
None
,
pretrained_model
=
"gpt2-medium"
,
dataset
,
dataset_fp
=
None
,
pretrained_model
=
"gpt2-medium"
,
epochs
=
10
,
batch_size
=
64
,
log_interval
=
10
,
epochs
=
10
,
batch_size
=
64
,
log_interval
=
10
,
save_model
=
False
,
cached
=
False
,
no_cuda
=
False
):
save_model
=
False
,
cached
=
False
,
no_cuda
=
False
):
global
device
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
and
not
no_cuda
else
"cpu"
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
and
not
no_cuda
else
"cpu"
print
(
"Preprocessing {} dataset..."
.
format
(
dataset
))
print
(
"Preprocessing {} dataset..."
.
format
(
dataset
))
...
@@ -258,7 +259,8 @@ def train_discriminator(
...
@@ -258,7 +259,8 @@ def train_discriminator(
discriminator
=
Discriminator
(
discriminator
=
Discriminator
(
class_size
=
len
(
idx2class
),
class_size
=
len
(
idx2class
),
pretrained_model
=
pretrained_model
,
pretrained_model
=
pretrained_model
,
cached_mode
=
cached
cached_mode
=
cached
,
device
=
device
).
to
(
device
)
).
to
(
device
)
text
=
torchtext_data
.
Field
()
text
=
torchtext_data
.
Field
()
...
@@ -309,7 +311,8 @@ def train_discriminator(
...
@@ -309,7 +311,8 @@ def train_discriminator(
discriminator
=
Discriminator
(
discriminator
=
Discriminator
(
class_size
=
len
(
idx2class
),
class_size
=
len
(
idx2class
),
pretrained_model
=
pretrained_model
,
pretrained_model
=
pretrained_model
,
cached_mode
=
cached
cached_mode
=
cached
,
device
=
device
).
to
(
device
)
).
to
(
device
)
with
open
(
"datasets/clickbait/clickbait_train_prefix.txt"
)
as
f
:
with
open
(
"datasets/clickbait/clickbait_train_prefix.txt"
)
as
f
:
...
@@ -368,7 +371,8 @@ def train_discriminator(
...
@@ -368,7 +371,8 @@ def train_discriminator(
discriminator
=
Discriminator
(
discriminator
=
Discriminator
(
class_size
=
len
(
idx2class
),
class_size
=
len
(
idx2class
),
pretrained_model
=
pretrained_model
,
pretrained_model
=
pretrained_model
,
cached_mode
=
cached
cached_mode
=
cached
,
device
=
device
).
to
(
device
)
).
to
(
device
)
x
=
[]
x
=
[]
...
@@ -431,7 +435,8 @@ def train_discriminator(
...
@@ -431,7 +435,8 @@ def train_discriminator(
discriminator
=
Discriminator
(
discriminator
=
Discriminator
(
class_size
=
len
(
idx2class
),
class_size
=
len
(
idx2class
),
pretrained_model
=
pretrained_model
,
pretrained_model
=
pretrained_model
,
cached_mode
=
cached
cached_mode
=
cached
,
device
=
device
).
to
(
device
)
).
to
(
device
)
x
=
[]
x
=
[]
...
@@ -494,11 +499,12 @@ def train_discriminator(
...
@@ -494,11 +499,12 @@ def train_discriminator(
start
=
time
.
time
()
start
=
time
.
time
()
train_loader
=
get_cached_data_loader
(
train_loader
=
get_cached_data_loader
(
train_dataset
,
batch_size
,
discriminator
,
shuffle
=
True
train_dataset
,
batch_size
,
discriminator
,
shuffle
=
True
,
device
=
device
)
)
test_loader
=
get_cached_data_loader
(
test_loader
=
get_cached_data_loader
(
test_dataset
,
batch_size
,
discriminator
test_dataset
,
batch_size
,
discriminator
,
device
=
device
)
)
end
=
time
.
time
()
end
=
time
.
time
()
...
@@ -529,18 +535,21 @@ def train_discriminator(
...
@@ -529,18 +535,21 @@ def train_discriminator(
data_loader
=
train_loader
,
data_loader
=
train_loader
,
optimizer
=
optimizer
,
optimizer
=
optimizer
,
epoch
=
epoch
,
epoch
=
epoch
,
log_interval
=
log_interval
log_interval
=
log_interval
,
device
=
device
)
)
evaluate_performance
(
evaluate_performance
(
data_loader
=
test_loader
,
data_loader
=
test_loader
,
discriminator
=
discriminator
discriminator
=
discriminator
,
device
=
device
)
)
end
=
time
.
time
()
end
=
time
.
time
()
print
(
"Epoch took: {:.3f}s"
.
format
(
end
-
start
))
print
(
"Epoch took: {:.3f}s"
.
format
(
end
-
start
))
print
(
"
\n
Example prediction"
)
print
(
"
\n
Example prediction"
)
predict
(
example_sentence
,
discriminator
,
idx2class
,
cached
)
predict
(
example_sentence
,
discriminator
,
idx2class
,
cached
=
cached
,
device
=
device
)
if
save_model
:
if
save_model
:
# torch.save(discriminator.state_dict(),
# torch.save(discriminator.state_dict(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment