chenpangpang / transformers
Commit 0b51fba2
Authored Nov 26, 2019 by piero; committed by Julien Chaumond on Dec 03, 2019

Added script for training a discriminator for pplm to use

Parent: 34a83faa
Changes: 2
Showing 2 changed files with 583 additions and 18 deletions

  examples/run_pplm.py                 +1    -18
  examples/run_pplm_discrim_train.py   +582  -0
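For orientation, here is a minimal sketch of driving the new training script programmatically; it mirrors the train_discriminator signature and the argparse defaults introduced in this commit. The dataset file name is a hypothetical placeholder, not something shipped with the commit.

# Minimal usage sketch (assumes examples/ is importable as a package and that
# "my_dataset.tsv" is a hypothetical "class<TAB>text" file you provide).
from examples.run_pplm_discrim_train import train_discriminator

train_discriminator(
    dataset='generic',            # one of 'SST', 'clickbait', 'toxic', 'generic'
    dataset_fp='my_dataset.tsv',  # hypothetical path; only used by 'generic'
    pretrained_model='gpt2-medium',
    epochs=10,
    batch_size=64,
    log_interval=10,
    save_model=True,              # writes {dataset}_classifier_head_meta.json and per-epoch .pt files
    cached=True,                  # precompute average GPT-2 representations once, then train only the head
    use_cuda=False,
)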
examples/run_pplm.py
...
@@ -34,6 +34,7 @@ import torch.nn.functional as F
 from torch.autograd import Variable
 from tqdm import trange
+from examples.run_pplm_discrim_train import ClassificationHead
 from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
 from transformers.modeling_gpt2 import GPT2LMHeadModel
...
@@ -108,24 +109,6 @@ def top_k_filter(logits, k, probs=False):
             logits
         )
-
-class ClassificationHead(torch.nn.Module):
-    """ Classification Head for the transformer """
-
-    def __init__(self, class_size=5, embed_size=2048):
-        super(ClassificationHead, self).__init__()
-        self.class_size = class_size
-        self.embed_size = embed_size
-        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
-        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
-        self.mlp = torch.nn.Linear(embed_size, class_size)
-
-    def forward(self, hidden_state):
-        # hidden_state = F.relu(self.mlp1(hidden_state))
-        # hidden_state = self.mlp2(hidden_state)
-        logits = self.mlp(hidden_state)
-        return logits
-
 def perturb_past(past, model, prev, args, classifier, good_index=None,
                  stepsize=0.01, vocab_size=50257,
                  original_probs=None, accumulated_hidden=None,
                  true_past=None,
...
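Since run_pplm.py now imports ClassificationHead from the new module instead of defining it locally, a consumer can rebuild a head trained by the script below from the files it writes when --save_model is passed. A minimal sketch, assuming the 'SST' dataset name and a placeholder epoch number in the file names (both follow the script's own format strings):

import json

import torch

from examples.run_pplm_discrim_train import ClassificationHead

# "{dataset}_classifier_head_meta.json", written once before training starts
with open("SST_classifier_head_meta.json") as meta_file:
    meta = json.load(meta_file)

classifier_head = ClassificationHead(
    class_size=meta["class_size"],
    embed_size=meta["embed_size"],
)
# "{dataset}_classifier_head_epoch_{epoch}.pt"; epoch 9 is the last one written
# with the default of 10 epochs (placeholder, pick whichever checkpoint you kept)
classifier_head.load_state_dict(
    torch.load("SST_classifier_head_epoch_9.pt", map_location='cpu')
)
classifier_head.eval()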
examples/run_pplm_discrim_train.py (new file, 0 → 100644)
#! /usr/bin/env python3
# coding=utf-8

# This code is licensed under a non-commercial license.

import argparse
import csv
import json
import math
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim
import torch.optim as optim
import torch.utils.data as data
from nltk.tokenize.treebank import TreebankWordDetokenizer
from torchtext import data as torchtext_data
from torchtext import datasets

from transformers import GPT2Tokenizer, GPT2LMHeadModel

torch.manual_seed(0)
np.random.seed(0)

EPSILON = 1e-10
device = 'cpu'
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
max_length_seq = 100
class ClassificationHead(torch.nn.Module):
    """Classification Head for transformer encoders"""

    def __init__(self, class_size, embed_size):
        super(ClassificationHead, self).__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        # hidden_state = F.relu(self.mlp1(hidden_state))
        # hidden_state = self.mlp2(hidden_state)
        logits = self.mlp(hidden_state)
        return logits
class Discriminator(torch.nn.Module):
    """Transformer encoder followed by a Classification Head"""

    def __init__(self, class_size, pretrained_model="gpt2-medium",
                 cached_mode=False):
        super(Discriminator, self).__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
        self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
        self.embed_size = self.encoder.transformer.config.hidden_size
        self.classifier_head = ClassificationHead(
            class_size=class_size,
            embed_size=self.embed_size
        )
        self.cached_mode = cached_mode

    def get_classifier(self):
        return self.classifier_head

    def train_custom(self):
        # Freeze the GPT-2 encoder; only the classification head is trained.
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier_head.train()

    def avg_representation(self, x):
        mask = x.ne(0).unsqueeze(2).repeat(
            1, 1, self.embed_size
        ).float().to(device).detach()
        hidden, _ = self.encoder.transformer(x)
        masked_hidden = hidden * mask
        avg_hidden = torch.sum(masked_hidden, dim=1) / (
                torch.sum(mask, dim=1).detach() + EPSILON
        )
        return avg_hidden

    def forward(self, x):
        if self.cached_mode:
            avg_hidden = x.to(device)
        else:
            avg_hidden = self.avg_representation(x)

        logits = self.classifier_head(avg_hidden)
        probs = F.log_softmax(logits, dim=-1)

        return probs
class Dataset(data.Dataset):
    def __init__(self, X, y):
        """Reads source and target sequences from txt files."""
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        data = {}
        data['X'] = self.X[index]
        data['y'] = self.y[index]
        return data
def collate_fn(data):
    def pad_sequences(sequences):
        lengths = [len(seq) for seq in sequences]

        padded_sequences = torch.zeros(
            len(sequences),
            max(lengths)
        ).long()  # padding index 0

        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_sequences[i, :end] = seq[:end]

        return padded_sequences, lengths

    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch, _ = pad_sequences(item_info['X'])
    y_batch = torch.tensor(item_info['y'], dtype=torch.long)

    return x_batch, y_batch


def cached_collate_fn(data):
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    x_batch = torch.cat(item_info['X'], 0)
    y_batch = torch.tensor(item_info['y'], dtype=torch.long)

    return x_batch, y_batch
def train_epoch(data_loader, discriminator, optimizer,
                epoch=0, log_interval=10):
    samples_so_far = 0
    discriminator.train_custom()
    for batch_idx, (input_t, target_t) in enumerate(data_loader):
        input_t, target_t = input_t.to(device), target_t.to(device)

        optimizer.zero_grad()

        output_t = discriminator(input_t)
        loss = F.nll_loss(output_t, target_t)
        loss.backward(retain_graph=True)
        optimizer.step()

        samples_so_far += len(input_t)

        if batch_idx % log_interval == 0:
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch + 1,
                    samples_so_far, len(data_loader.dataset),
                    100 * samples_so_far / len(data_loader.dataset),
                    loss.item()
                )
            )


def evaluate_performance(data_loader, discriminator):
    discriminator.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for input_t, target_t in data_loader:
            input_t, target_t = input_t.to(device), target_t.to(device)
            output_t = discriminator(input_t)
            # sum up batch loss
            test_loss += F.nll_loss(output_t, target_t, reduction='sum').item()
            # get the index of the max log-probability
            pred_t = output_t.argmax(dim=1, keepdim=True)
            correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()

    test_loss /= len(data_loader.dataset)

    print(
        'Performance on test set: '
        'Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            test_loss, correct, len(data_loader.dataset),
            100. * correct / len(data_loader.dataset)
        )
    )
def predict(input_sentence, model, classes, cached=False):
    input_t = model.tokenizer.encode(input_sentence)
    input_t = torch.tensor([input_t], dtype=torch.long)
    if cached:
        input_t = model.avg_representation(input_t)

    log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
    print('Input sentence:', input_sentence)
    print('Predictions:', ", ".join(
        "{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
        zip(classes, log_probs)
    ))


def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              collate_fn=collate_fn)

    xs = []
    ys = []
    for batch_idx, (x, y) in enumerate(data_loader):
        with torch.no_grad():
            x = x.to(device)
            avg_rep = discriminator.avg_representation(x).cpu().detach()
            avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
            xs += avg_rep_list
            ys += y.cpu().numpy().tolist()

    data_loader = torch.utils.data.DataLoader(
        dataset=Dataset(xs, ys),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=cached_collate_fn
    )

    return data_loader
def train_discriminator(dataset, dataset_fp=None, pretrained_model='gpt2-medium',
                        epochs=10, batch_size=64, log_interval=10,
                        save_model=False, cached=False, use_cuda=False):
    if use_cuda:
        global device
        device = 'cuda'

    print('Preprocessing {} dataset...'.format(dataset))
    start = time.time()

    if dataset == 'SST':
        idx2class = ["positive", "negative", "very positive", "very negative",
                     "neutral"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class),
            pretrained_model=pretrained_model,
            cached_mode=cached
        ).to(device)

        text = torchtext_data.Field()
        label = torchtext_data.Field(sequential=False)
        train_data, val_data, test_data = datasets.SST.splits(
            text,
            label,
            fine_grained=True,
            train_subtrees=True,
        )

        x = []
        y = []
        for i in range(len(train_data)):
            seq = TreebankWordDetokenizer().detokenize(
                vars(train_data[i])["text"]
            )
            seq = discriminator.tokenizer.encode(seq)
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            x.append(seq)
            y.append(class2idx[vars(train_data[i])["label"]])
        train_dataset = Dataset(x, y)

        test_x = []
        test_y = []
        for i in range(len(test_data)):
            seq = TreebankWordDetokenizer().detokenize(
                vars(test_data[i])["text"]
            )
            seq = discriminator.tokenizer.encode(seq)
            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
            test_x.append(seq)
            test_y.append(class2idx[vars(test_data[i])["label"]])
        test_dataset = Dataset(test_x, test_y)

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 2,
        }
    elif dataset == 'clickbait':
        idx2class = ["non_clickbait", "clickbait"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class),
            pretrained_model=pretrained_model,
            cached_mode=cached
        ).to(device)

        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
            data = []
            for i, line in enumerate(f):
                try:
                    data.append(eval(line))
                except:
                    print('Error evaluating line {}: {}'.format(i, line))
                    continue

        x = []
        y = []
        for i, d in enumerate(data):
            try:
                seq = discriminator.tokenizer.encode(d["text"])

                if len(seq) < max_length_seq:
                    seq = torch.tensor(
                        [50256] + seq, device=device, dtype=torch.long
                    )
                else:
                    print("Line {} is longer than maximum length {}".format(
                        i, max_length_seq
                    ))
                    continue

                x.append(seq)
                y.append(d['label'])
            except:
                print("Error tokenizing line {}, skipping it".format(i))
                pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, test_size]
        )

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 1,
        }
    elif dataset == 'toxic':
        idx2class = ["non_toxic", "toxic"]
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class),
            pretrained_model=pretrained_model,
            cached_mode=cached
        ).to(device)

        with open("datasets/toxic/toxic_train.txt") as f:
            data = []
            for i, line in enumerate(f):
                try:
                    data.append(eval(line))
                except:
                    print('Error evaluating line {}: {}'.format(i, line))
                    continue

        x = []
        y = []
        for i, d in enumerate(data):
            try:
                seq = discriminator.tokenizer.encode(d["text"])

                if len(seq) < max_length_seq:
                    seq = torch.tensor(
                        [50256] + seq, device=device, dtype=torch.long
                    )
                else:
                    print("Line {} is longer than maximum length {}".format(
                        i, max_length_seq
                    ))
                    continue

                x.append(seq)
                y.append(int(np.sum(d['label']) > 0))
            except:
                print("Error tokenizing line {}, skipping it".format(i))
                pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, test_size]
        )

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 0,
        }
    else:  # if dataset == 'generic':
        # This assumes the input dataset is a TSV with the following structure:
        # class \t text

        if dataset_fp is None:
            raise ValueError('When generic dataset is selected, '
                             'dataset_fp needs to be specified as well.')

        classes = set()
        with open(dataset_fp) as f:
            csv_reader = csv.reader(f, delimiter='\t')
            for row in csv_reader:
                classes.add(row[0])

        idx2class = sorted(classes)
        class2idx = {c: i for i, c in enumerate(idx2class)}

        discriminator = Discriminator(
            class_size=len(idx2class),
            pretrained_model=pretrained_model,
            cached_mode=cached
        ).to(device)

        x = []
        y = []
        with open(dataset_fp) as f:
            csv_reader = csv.reader(f, delimiter='\t')
            for i, row in enumerate(csv_reader):
                label = row[0]
                text = row[1]

                try:
                    seq = discriminator.tokenizer.encode(text)
                    if len(seq) < max_length_seq:
                        seq = torch.tensor(
                            [50256] + seq, device=device, dtype=torch.long
                        )
                    else:
                        print("Line {} is longer than maximum length {}".format(
                            i, max_length_seq
                        ))
                        continue

                    x.append(seq)
                    y.append(class2idx[label])
                except:
                    print("Error tokenizing line {}, skipping it".format(i))
                    pass

        full_dataset = Dataset(x, y)
        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, test_size]
        )

        discriminator_meta = {
            "class_size": len(idx2class),
            "embed_size": discriminator.embed_size,
            "pretrained_model": pretrained_model,
            "class_vocab": class2idx,
            "default_class": 0,
        }
    end = time.time()
    print('Preprocessed {} data points'.format(
        len(train_dataset) + len(test_dataset))
    )
    print("Data preprocessing took: {:.3f}s".format(end - start))

    if cached:
        start = time.time()

        train_loader = get_cached_data_loader(
            train_dataset, batch_size, discriminator, shuffle=True
        )
        test_loader = get_cached_data_loader(
            test_dataset, batch_size, discriminator
        )

        end = time.time()
        print("Building representation cache took: {:.3f}s".format(end - start))

    else:
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   collate_fn=collate_fn)
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                  batch_size=batch_size,
                                                  collate_fn=collate_fn)

    if save_model:
        with open("{}_classifier_head_meta.json".format(dataset),
                  "w") as meta_file:
            json.dump(discriminator_meta, meta_file)

    optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)

    for epoch in range(epochs):
        start = time.time()
        print('\nEpoch', epoch + 1)

        train_epoch(
            discriminator=discriminator,
            data_loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            log_interval=log_interval
        )
        evaluate_performance(
            data_loader=test_loader,
            discriminator=discriminator
        )

        end = time.time()
        print("Epoch took: {:.3f}s".format(end - start))

        print("\nExample prediction")
        predict(example_sentence, discriminator, idx2class, cached)

        if save_model:
            # torch.save(discriminator.state_dict(),
            #            "{}_discriminator_{}.pt".format(
            #                args.dataset, epoch
            #            ))
            torch.save(discriminator.get_classifier().state_dict(),
                       "{}_classifier_head_epoch_{}.pt".format(dataset, epoch))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Train a discriminator on top of GPT-2 representations')
    parser.add_argument('--dataset', type=str, default='SST',
                        choices=('SST', 'clickbait', 'toxic', 'generic'),
                        help='dataset to train the discriminator on. '
                             'In case of generic, the dataset is expected '
                             'to be a TSV file with structure: class \\t text')
    parser.add_argument('--dataset_fp', type=str, default='',
                        help='File path of the dataset to use. '
                             'Needed only in case of generic dataset')
    parser.add_argument('--pretrained_model', type=str, default='gpt2-medium',
                        help='Pretrained model to use as encoder')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save_model', action='store_true',
                        help='whether to save the model')
    parser.add_argument('--cached', action='store_true',
                        help='whether to cache the input representations')
    parser.add_argument('--use_cuda', action='store_true',
                        help='use to turn on cuda')
    args = parser.parse_args()

    train_discriminator(**(vars(args)))
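The generic branch above reads a tab-separated file with the class label in the first column and the text in the second. A minimal sketch of producing such a file, with hypothetical labels, texts, and file name:

import csv

rows = [
    ("positive", "This is the best chicken I have ever had."),    # hypothetical examples
    ("negative", "The service was slow and the food was cold."),
]

with open("my_dataset.tsv", "w", newline="") as f:  # hypothetical path
    writer = csv.writer(f, delimiter="\t")
    writer.writerows(rows)  # one "class<TAB>text" row per line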