chenpangpang / transformers

Commit 9343a231, authored Nov 02, 2018 by thomwolf
Parent: f690f0e1

model training loop working – still have to check that everything is exactly same
Showing 2 changed files with 37 additions and 34 deletions (+37 −34):

modeling_pytorch.py        +19 −16
run_classifier_pytorch.py  +18 −18
modeling_pytorch.py
@@ -18,21 +18,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import copy
 import json
 import math
-import re
 import six
-import tensorflow as tf
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
 
 def gelu(x):
-    raise NotImplementedError
-    # TF BERT says: cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+    # OpenAI GPT gelu version was : 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
 
 class BertConfig(object):
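For reference (not part of the commit), a minimal standalone sketch of the two GELU variants referenced in the hunk above: the exact form based on the Gaussian CDF via `erf`, and the tanh approximation used by OpenAI GPT. Note that the exact GELU multiplies the CDF term by `x`; the names below are illustrative only.

```python
import math
import torch

def gelu_exact(x):
    # GELU(x) = x * Phi(x), with Phi the standard normal CDF expressed via erf
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # tanh approximation of the same function (OpenAI GPT style)
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

x = torch.linspace(-3.0, 3.0, steps=7)
print(torch.max(torch.abs(gelu_exact(x) - gelu_tanh(x))))  # prints a small value: the approximation error
```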
@@ -152,12 +148,11 @@ class BERTEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, input_ids, token_type_ids=None):
-        batch_size = input_ids.size(0)
         seq_length = input_ids.size(1)
-        # TODO finich that
-        position_ids = torch.range().view(batch_size, seq_length)
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(batch_size, seq_length)
+            token_type_ids = torch.zeros_like(input_ids)
 
         words_embeddings = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
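A small illustrative sketch (toy inputs, not from the repository) of what the new position-id and token-type-id construction produces. Using `torch.zeros_like(input_ids)` keeps the integer dtype that embedding lookups require, unlike the removed `torch.zeros(...)`, which would default to float.

```python
import torch

input_ids = torch.tensor([[7, 12, 5, 0],
                          [3, 9, 0, 0]])  # [batch_size=2, seq_length=4]

seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # [2, 4]: 0..3 repeated per row
token_type_ids = torch.zeros_like(input_ids)                   # same shape and dtype as input_ids

print(position_ids)
# tensor([[0, 1, 2, 3],
#         [0, 1, 2, 3]])
```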
@@ -218,14 +213,14 @@ class BERTSelfAttention(nn.Module):
         # TODO clean up this (precompute)
         # MY PYTORCH: w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
         # `attention_mask` = [B, 1, F, T]
-        attention_mask = tf.expand_dims(attention_mask, axis=[1])
+        # attention_mask = tf.expand_dims(attention_mask, axis=[1])
 
         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
         # positions we want to attend and -10000.0 for masked positions.
-        adder = (1.0 - attention_mask) * -10000.0
+        # adder = (1.0 - attention_mask) * -10000.0
 
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        attention_scores += adder
+        attention_scores += attention_mask
 
         # Normalize the attention scores to probabilities.
         # `attention_probs` = [B, N, F, T]
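As an aside (not part of the commit), a tiny sketch of why adding a large negative value to the raw scores before softmax is equivalent to masking: masked positions end up with essentially zero probability. Shapes and names are illustrative only.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.5, 0.1]])   # raw attention scores for one query position
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])     # 1.0 = attend, 0.0 = masked
additive_mask = (1.0 - mask) * -10000.0         # 0.0 where allowed, -10000.0 where masked

probs = F.softmax(scores + additive_mask, dim=-1)
print(probs)  # the last two entries are ~0; the first two renormalize to sum to 1
```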
@@ -289,7 +284,7 @@ class BERTOutput(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, hidden_states, input_tensor):
-        hidden_states = self.dense(input_tensor)
+        hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
         return hidden_states
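For context (a sketch under assumptions, not the repository's class), the pattern being corrected above is the standard Transformer output sub-layer: project the sub-layer activation, apply dropout, then add the residual `input_tensor` and layer-normalize. The change applies the projection to `hidden_states` rather than to the residual input. This sketch uses torch's built-in `nn.LayerNorm` instead of the file's own LayerNorm class.

```python
import torch.nn as nn

class OutputSublayer(nn.Module):
    """Illustrative dense -> dropout -> residual add -> LayerNorm block."""
    def __init__(self, intermediate_size, hidden_size, dropout_prob):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)             # project the sub-layer activation
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)   # residual connection, then normalize
```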
@@ -390,6 +385,14 @@ class BertModel(nn.Module):
         self.pooler = BERTPooler(config)
 
     def forward(self, input_ids, token_type_ids, attention_mask):
+        # We create 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, from_seq_length]
+        # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
+        # It's more simple than the triangular masking of causal attention, just need to
+        # prepare the broadcast here
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+        attention_mask = (1.0 - attention_mask) * -10000.0
+
         embedding_output = self.embeddings(input_ids, token_type_ids)
         all_encoder_layers = self.encoder(embedding_output, attention_mask)
         sequence_output = all_encoder_layers[-1]
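A standalone sketch (names assumed, not from the file) of the mask preparation the added lines perform: a 2D padding mask of shape [batch_size, seq_length] becomes a [batch_size, 1, 1, seq_length] additive mask that broadcasts against attention scores of shape [batch_size, num_heads, seq_length, seq_length].

```python
import torch

batch_size, num_heads, seq_length = 2, 12, 4
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.float)   # [batch, seq]

extended_mask = attention_mask.unsqueeze(1).unsqueeze(2)           # [batch, 1, 1, seq]
extended_mask = (1.0 - extended_mask) * -10000.0                   # additive form

attention_scores = torch.randn(batch_size, num_heads, seq_length, seq_length)
masked_scores = attention_scores + extended_mask                   # broadcasts over heads and query positions
print(masked_scores.shape)  # torch.Size([2, 12, 4, 4])
```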
@@ -404,11 +407,11 @@ class BertForSequenceClassification(nn.Module):
         self.classifier = nn.Linear(config.hidden_size, num_labels)
 
         def init_weights(m):
-            if isinstance(m) == nn.Linear or isinstance(m) == nn.Embedding:
+            if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                 print("Initializing {}".format(m))
                 # Slight difference here with the TF version which uses truncated_normal
                 # cf https://github.com/pytorch/pytorch/pull/5617
-                m.weight.normal_(config.initializer_range)
+                m.weight.data.normal_(config.initializer_range)
         self.apply(init_weights)
 
     def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
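For illustration (a sketch with assumed names, not the repository's exact code): `nn.Module.apply(fn)` calls `fn` on every submodule, so an `isinstance` check selects the `Linear` and `Embedding` layers to reinitialize, and writing through `.data` keeps the in-place fill out of autograd. The sketch uses the conventional `normal_(mean=..., std=...)` call signature.

```python
import torch.nn as nn

def make_init_fn(initializer_range=0.02):
    def init_weights(m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            # fill weights in-place from N(0, initializer_range); .data avoids autograd tracking
            m.weight.data.normal_(mean=0.0, std=initializer_range)
        if isinstance(m, nn.Linear) and m.bias is not None:
            m.bias.data.zero_()
    return init_weights

model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 2))
model.apply(make_init_fn())   # apply() recurses over every submodule
```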
run_classifier_pytorch.py
@@ -484,7 +484,7 @@ def main():
     num_train_steps = int(
         len(train_examples) / args.train_batch_size * args.num_train_epochs)
 
-    model = BertForSequenceClassification(bert_config)
+    model = BertForSequenceClassification(bert_config, len(label_list))
     if args.init_checkpoint is not None:
         model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
@@ -504,10 +504,10 @@ def main():
     logger.info(" Batch size = %d", args.train_batch_size)
     logger.info(" Num steps = %d", num_train_steps)
 
-    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.Long)
-    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.Long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.Long)
-    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.Long)
+    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
+    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
 
     train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
     if args.local_rank == -1:
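As a standalone illustration (not from the script): `torch.Long` is not a valid attribute, so the old lines would fail; the dtype is spelled `torch.long` (an alias of `torch.int64`), which is what embedding lookups and classification targets expect.

```python
import torch
from torch.utils.data import TensorDataset

features = [[1, 2, 3], [4, 5, 6]]                    # e.g. token ids for two examples
labels = [0, 1]

ids = torch.tensor(features, dtype=torch.long)       # torch.long == torch.int64
lbl = torch.tensor(labels, dtype=torch.long)
# torch.tensor(features, dtype=torch.Long)           # AttributeError: module 'torch' has no attribute 'Long'

dataset = TensorDataset(ids, lbl)                    # all tensors must share the first dimension
print(len(dataset), dataset[0])
```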
@@ -519,12 +519,12 @@ def main():
     model.train()
     global_step = 0
     for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
-        input_ids.to(device)
-        input_mask.to(device)
-        segment_ids.to(device)
-        label_ids.to(device)
-        loss = model(input_ids, segment_ids, input_mask, label_ids)
+        input_ids = input_ids.to(device)
+        input_mask = input_mask.float().to(device)
+        segment_ids = segment_ids.to(device)
+        label_ids = label_ids.to(device)
+        loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
         loss.backward()
         optimizer.step()
         global_step += 1
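A minimal sketch (toy model and names assumed) of the point behind this change: `Tensor.to(device)` returns a new tensor rather than moving the original in place, so the result must be assigned back. The added `.float()` cast on `input_mask` matters because the mask is later combined arithmetically with floating-point attention scores, and the loss is now unpacked from the model's tuple output.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(8, 16)
x.to(device)            # no effect: the moved tensor is returned, not assigned
x = x.to(device)        # correct: rebind the name to the tensor on `device`

model = nn.Linear(16, 2).to(device)   # nn.Module.to(), by contrast, moves parameters in place
logits = model(x)
print(logits.device)
```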
@@ -538,10 +538,10 @@ def main():
     logger.info(" Num examples = %d", len(eval_examples))
     logger.info(" Batch size = %d", args.eval_batch_size)
 
-    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.Long)
-    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.Long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.Long)
-    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.Long)
+    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
+    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
 
     eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
     if args.local_rank == -1:
@@ -554,10 +554,10 @@ def main():
     eval_loss = 0
     eval_accuracy = 0
     for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
-        input_ids.to(device)
-        input_mask.to(device)
-        segment_ids.to(device)
-        label_ids.to(device)
+        input_ids = input_ids.to(device)
+        input_mask = input_mask.float().to(device)
+        segment_ids = segment_ids.to(device)
+        label_ids = label_ids.to(device)
 
         tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
         tmp_eval_accuracy = accuracy(logits, label_ids)
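The `accuracy(logits, label_ids)` helper called above is not shown in this diff. A plausible sketch (assumed, not the repository's definition) counts correct argmax predictions over a batch; if the inputs are CUDA tensors they would first need to be brought back with `.detach().cpu().numpy()`.

```python
import numpy as np

def accuracy(out, labels):
    # number of examples whose argmax over the logits matches the label
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)
```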