Project: chenpangpang/transformers
Commit: f690f0e1
Authored: Nov 02, 2018 by thomwolf

run_classifier WIP + added classifier head and initialization to the model

Parent: 4a0b59e9
Showing 2 changed files with 129 additions and 104 deletions:

  modeling_pytorch.py        +28   -0
  run_classifier_pytorch.py  +101  -104
modeling_pytorch.py
@@ -27,6 +27,7 @@ import six
 import tensorflow as tf
 import torch
 import torch.nn as nn
+from torch.nn import CrossEntropyLoss

 def gelu(x):
     raise NotImplementedError
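
Note: at this commit the PyTorch gelu is still a stub that raises NotImplementedError. For reference only (not part of the commit), a minimal sketch of the tanh approximation of the Gaussian Error Linear Unit used by the TF BERT code:

    import math
    import torch

    def gelu(x):
        # Tanh approximation of the Gaussian Error Linear Unit (GELU).
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) *
                                           (x + 0.044715 * torch.pow(x, 3.0))))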
@@ -394,3 +395,30 @@ class BertModel(nn.Module):
         sequence_output = all_encoder_layers[-1]
         pooled_output = self.pooler(sequence_output)
         return all_encoder_layers, pooled_output
+
+
+class BertForSequenceClassification(nn.Module):
+    def __init__(self, config, num_labels):
+        super(BertForSequenceClassification, self).__init__()
+        self.bert = BertModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, num_labels)
+
+        def init_weights(m):
+            if isinstance(m) == nn.Linear or isinstance(m) == nn.Embedding:
+                print("Initializing {}".format(m))
+                # Slight difference here with the TF version which uses truncated_normal
+                # cf https://github.com/pytorch/pytorch/pull/5617
+                m.weight.normal_(config.initializer_range)
+        self.apply(init_weights)
+
+    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
+        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits, labels)
+            return loss, logits
+        else:
+            return logits
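
Note: two details of the new init_weights are worth flagging. isinstance() takes two arguments, so isinstance(m) == nn.Linear raises a TypeError, and Tensor.normal_(config.initializer_range) passes the range as the mean rather than the standard deviation. A minimal sketch of what the initializer appears to intend, assuming config.initializer_range is the desired std (not part of the commit, which keeps the truncated_normal difference only as a comment):

    def init_weights(m):
        # Initialize Linear and Embedding weights from N(0, initializer_range^2);
        # the TF version uses a truncated normal instead.
        if isinstance(m, (nn.Linear, nn.Embedding)):
            m.weight.data.normal_(mean=0.0, std=config.initializer_range)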
run_classifier_pytorch.py
@@ -20,20 +20,23 @@ from __future__ import print_function
 import csv
 import os
-from modeling_pytorch import BertConfig, BertModel
-from optimization_pytorch import BERTAdam
-# import optimization
-import tokenization_pytorch
+import logging
+import argparse
+
+import numpy as np
 import torch
+from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
+
+import tokenization_pytorch
+from modeling_pytorch import BertConfig, BertForSequenceClassification
+from optimization_pytorch import BERTAdam

-import logging
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
                     level = logging.INFO)
 logger = logging.getLogger(__name__)

-import argparse
 parser = argparse.ArgumentParser()

 ## Required parameters
@@ -116,7 +119,7 @@ parser.add_argument("--iterations_per_loop",
                     default=1000,
                     type=int,
                     help="How many steps to make in each estimator call.")
 parser.add_argument("--no_cuda",
                     default=False,
                     type=bool,
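
Note: a side remark on the --no_cuda argument shown in the context above (it is not changed by this commit): argparse's type=bool calls bool() on the raw string, so any non-empty value, including "False", parses as True. Boolean switches are normally declared with action="store_true"; a minimal sketch:

    import argparse

    parser = argparse.ArgumentParser()
    # False unless --no_cuda is passed on the command line.
    parser.add_argument("--no_cuda", action="store_true",
                        help="Do not use CUDA even when it is available.")
    args = parser.parse_args(["--no_cuda"])
    assert args.no_cuda is True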
@@ -127,39 +130,6 @@ parser.add_argument("--local_rank",
                     default=-1,
                     help="local_rank for distributed training on gpus")
-
-### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
-parser.add_argument("--use_tpu",
-                    default=False,
-                    type=bool,
-                    help="Whether to use TPU or GPU/CPU.")
-parser.add_argument("--tpu_name",
-                    default=None,
-                    type=str,
-                    help="The Cloud TPU to use for training. This should be either the name "
-                         "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
-                         "url.")
-parser.add_argument("--tpu_zone",
-                    default=None,
-                    type=str,
-                    help="[Optional] GCE zone where the Cloud TPU is located in. If not "
-                         "specified, we will attempt to automatically detect the GCE project from "
-                         "metadata.")
-parser.add_argument("--gcp_project",
-                    default=None,
-                    type=str,
-                    help="[Optional] Project name for the Cloud TPU-enabled project. If not "
-                         "specified, we will attempt to automatically detect the GCE project from "
-                         "metadata.")
-parser.add_argument("--master",
-                    default=None,
-                    type=str,
-                    help="[Optional] TensorFlow master URL.")
-parser.add_argument("--num_tpu_cores",
-                    default=8,
-                    type=int,
-                    help="Only used if `use_tpu` is True. Total number of TPU cores to use.")
-### END - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###

 args = parser.parse_args()

 class InputExample(object):
@@ -429,44 +399,41 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
       tokens_b.pop()


-def input_fn_builder(features, seq_length, is_training, drop_remainder):
-  """Creates an `input_fn` closure to be passed to TPUEstimator."""
-
-  all_input_ids = []
-  all_input_mask = []
-  all_segment_ids = []
-  all_label_ids = []
-
-  for feature in features:
-    all_input_ids.append(feature.input_ids)
-    all_input_mask.append(feature.input_mask)
-    all_segment_ids.append(feature.segment_ids)
-    all_label_ids.append(feature.label_id)
-
-  def input_fn(params):
-    """The actual input function."""
-    batch_size = params["batch_size"]
-
-    num_examples = len(features)
-
-    device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
-    d = torch.utils.data.TensorDataset({    ## BUG THIS IS NOT WORKING.... ###
-        "input_ids": torch.IntTensor(all_input_ids, device=device), #Requires_grad=False by default
-        "input_mask": torch.IntTensor(all_input_mask, device=device),
-        "segment_ids": torch.IntTensor(all_segment_ids, device=device),
-        "label_ids": torch.IntTensor(all_label_ids, device=device)
-    })
-    shuffle = True if is_training else False
-    d = torch.utils.data.DataLoader(dataset=d, batch_size=batch_size, shuffle=shuffle, drop_last=drop_remainder)
-    # Cf https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
-    return d
-
-  return input_fn
+def input_fn_builder(features, seq_length, train_batch_size): # TODO: delete
+  ### ATTENTION - To rewrite ###
+  """Creates an `input_fn` closure to be passed to TPUEstimator."""
+  ### ATTENTION - To rewrite ###
+  all_input_ids = [f.input_ids for feature in features]
+  all_input_mask = [f.input_mask for feature in features]
+  all_segment_ids = [f.segment_ids for feature in features]
+  all_label_ids = [f.label_id for feature in features]
+  # for feature in features:
+  #   all_input_ids.append(feature.input_ids)
+  #   all_input_mask.append(feature.input_mask)
+  #   all_segment_ids.append(feature.segment_ids)
+  #   all_label_ids.append(feature.label_id)
+
+  input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.Long)
+  input_mask_tensor = torch.tensor(all_input_mask, dtype=torch.Long)
+  segment_tensor = torch.tensor(all_segment, dtype=torch.Long)
+  label_tensor = torch.tensor(all_label, dtype=torch.Long)
+
+  train_data = TensorDataset(input_ids_tensor, input_mask_tensor,
+                             segment_tensor, label_tensor)
+  if args.local_rank == -1:
+    train_sampler = RandomSampler(train_data)
+  else:
+    train_sampler = DistributedSampler(train_data)
+  train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)
+
+  return train_dataloader
+
+
+def accuracy(out, labels):
+  outputs = np.argmax(out, axis=1)
+  return np.sum(outputs == labels) / float(labels.size)


-def main(_):
+def main():
   processors = {
       "cola": ColaProcessor,
       "mnli": MnliProcessor,
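
Note: the rewritten input_fn_builder is still marked "# TODO: delete" and carries a few WIP issues: the list comprehensions bind feature but read f, all_segment and all_label are undefined names, and torch.Long is not a dtype (the dtype is torch.long). A minimal sketch of the tensor and DataLoader construction the rewrite appears to be aiming for, reusing the names from the diff:

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    # Random sampling for single-process training, DistributedSampler otherwise.
    train_sampler = RandomSampler(train_data) if args.local_rank == -1 else DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)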
@@ -492,7 +459,7 @@ def main(_):
         "Cannot use sequence length %d because the BERT model "
         "was only trained up to sequence length %d" %
         (args.max_seq_length, bert_config.max_position_embeddings))

   if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
     raise ValueError(f"Output directory ({args.output_dir}) already exists and is "
                      f"not empty.")
@@ -517,13 +484,13 @@ def main(_):
     num_train_steps = int(
         len(train_examples) / args.train_batch_size * args.num_train_epochs)

-  model = BertModel(bert_config)
+  model = BertForSequenceClassification(bert_config)
   if args.init_checkpoint is not None:
-    model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
+    model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
   model.to(device)

   optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
-                        {'params': [p for n, p in model.named_parameters() if n != 'bias']}
+                        {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
                        ],
                        lr=args.learning_rate, schedule='warmup_linear',
                        warmup=args.warmup_proportion,
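
Note on the optimizer grouping above: model.named_parameters() yields dotted names such as encoder.layer.0.output.dense.bias, so the tests n != 'bias' and n == 'bias' compare against a name that never occurs; the first group collects every parameter and the second stays empty. A minimal sketch of grouping by name suffix, keeping the 'params'/'l2' keys that BERTAdam is given in this diff:

    # Apply L2 of 0.01 to everything except bias parameters.
    no_decay = [p for n, p in model.named_parameters() if n.endswith('bias')]
    decay = [p for n, p in model.named_parameters() if not n.endswith('bias')]
    param_groups = [{'params': decay, 'l2': 0.01},
                    {'params': no_decay, 'l2': 0.}]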
@@ -536,18 +503,31 @@ def main(_):
     logger.info(" Num examples = %d", len(train_examples))
     logger.info(" Batch size = %d", args.train_batch_size)
     logger.info(" Num steps = %d", num_train_steps)
-    train_input = input_fn_builder(
-        features=train_features,
-        seq_length=args.max_seq_length,
-        is_training=True,
-        drop_remainder=True)
-    # estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
-    for batch_ix, batch in train_input:
-      output = model_fn(batch)
-      loss = output["loss"]
-      loss.backward()
+    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.Long)
+    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.Long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.Long)
+    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.Long)
+
+    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
+    if args.local_rank == -1:
+      train_sampler = RandomSampler(train_data)
+    else:
+      train_sampler = DistributedSampler(train_data)
+    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
+
+    model.train()
+    global_step = 0
+    for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
+      input_ids.to(device)
+      input_mask.to(device)
+      segment_ids.to(device)
+      label_ids.to(device)
+
+      loss = model(input_ids, segment_ids, input_mask, label_ids)
+      loss.backward()
+      optimizer.step()
+      global_step += 1

   if args.do_eval:
     eval_examples = processor.get_dev_examples(args.data_dir)
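
Note: in the new training loop, Tensor.to(device) is not in-place, so the returned tensors must be reassigned; the gradients are never cleared between steps; and the model returns a (loss, logits) tuple when labels are passed. A minimal sketch of one corrected step, reusing the names from the diff (not part of the commit):

    for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
        # .to(device) returns a new tensor; reassign it.
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)

        loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
        optimizer.zero_grad()   # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()
        global_step += 1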
@@ -558,23 +538,40 @@ def main(_):
     logger.info(" Num examples = %d", len(eval_examples))
     logger.info(" Batch size = %d", args.eval_batch_size)

-    # This tells the estimator to run through the entire set.
-    eval_steps = None
-    # However, if running eval on the TPU, you will need to specify the
-    # number of steps.
-    if args.use_tpu:
-      # Eval will be slightly WRONG on the TPU because it will truncate
-      # the last batch.
-      eval_steps = int(len(eval_examples) / args.eval_batch_size)
-
-    eval_drop_remainder = True if args.use_tpu else False
-    eval_input_fn = input_fn_builder(
-        features=eval_features,
-        seq_length=args.max_seq_length,
-        is_training=False,
-        drop_remainder=eval_drop_remainder)
-
-    result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
+    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.Long)
+    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.Long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.Long)
+    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.Long)
+
+    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
+    if args.local_rank == -1:
+      eval_sampler = SequentialSampler(eval_data)
+    else:
+      eval_sampler = DistributedSampler(eval_data)
+    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
+
+    model.eval()
+    eval_loss = 0
+    eval_accuracy = 0
+    for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
+      input_ids.to(device)
+      input_mask.to(device)
+      segment_ids.to(device)
+      label_ids.to(device)
+
+      tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
+      tmp_eval_accuracy = accuracy(logits, label_ids)
+
+      eval_loss += tmp_eval_loss.item()
+      eval_accuracy += tmp_eval_accuracy
+
+    eval_loss = eval_loss / len(eval_dataloader)
+    eval_accuracy = eval_accuracy / len(eval_dataloader)
+
+    result = {'eval_loss': eval_loss,
+              'eval_accuracy': eval_accuracy,
+              'global_step': global_step,
+              'loss': loss.item()}

     output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
     with open(output_eval_file, "w") as writer:
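
Note: the accuracy helper added above works on NumPy arrays (np.argmax, labels.size), while the eval loop passes it torch tensors straight from the model, and evaluation still tracks gradients. A minimal sketch of the per-batch computation under torch.no_grad(), converting to NumPy before calling accuracy (an assumption about the intended usage, not part of the commit):

    with torch.no_grad():
        tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)

    # accuracy() expects NumPy inputs, so detach and move the tensors to the CPU first.
    tmp_eval_accuracy = accuracy(logits.detach().cpu().numpy(),
                                 label_ids.cpu().numpy())
    eval_loss += tmp_eval_loss.item()
    eval_accuracy += tmp_eval_accuracy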
@@ -582,6 +579,6 @@ def main(_):
     for key in sorted(result.keys()):
       logger.info(" %s = %s", key, str(result[key]))
       writer.write("%s = %s\n" % (key, str(result[key])))


 if __name__ == "__main__":
   main()