chenpangpang / transformers · Commits

Commit 9343a231 authored Nov 02, 2018 by thomwolf

model training loop working – still have to check that everything is exactly same

parent f690f0e1
Showing 2 changed files with 37 additions and 34 deletions
modeling_pytorch.py        +19 -16
run_classifier_pytorch.py  +18 -18
modeling_pytorch.py (view file @ 9343a231)
@@ -18,21 +18,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-import collections
 import copy
 import json
 import math
-import re
 import six
-import tensorflow as tf
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss

 def gelu(x):
-    raise NotImplementedError
-    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+    # TF BERT says: cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
+    # OpenAI GPT gelu version was : 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

 class BertConfig(object):
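Note (not part of the commit): a small self-contained sketch comparing the exact erf-based GELU mentioned in the first comment with the tanh approximation the function now returns; the function names are illustrative.

import math
import torch

def gelu_erf(x):
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))), as in the TF BERT comment
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # OpenAI GPT-style tanh approximation used by this commit
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

x = torch.linspace(-3, 3, steps=13)
print(torch.max(torch.abs(gelu_erf(x) - gelu_tanh(x))))  # the two variants agree closely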
@@ -152,12 +148,11 @@ class BERTEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, input_ids, token_type_ids=None):
-        batch_size = input_ids.size(0)
         seq_length = input_ids.size(1)
-        # TODO finich that
-        position_ids = torch.range().view(batch_size, seq_length)
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(batch_size, seq_length)
+            token_type_ids = torch.zeros_like(input_ids)

         words_embeddings = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
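Note (illustration only, toy shapes): what the new position-id construction produces for a batch of input ids.

import torch

input_ids = torch.randint(0, 30000, (2, 5))    # [batch_size, seq_length], values are arbitrary
seq_length = input_ids.size(1)

position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # shape [5]
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)                       # shape [2, 5]
token_type_ids = torch.zeros_like(input_ids)                                        # shape [2, 5], all zeros

print(position_ids)
# tensor([[0, 1, 2, 3, 4],
#         [0, 1, 2, 3, 4]])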
@@ -218,14 +213,14 @@ class BERTSelfAttention(nn.Module):
         # TODO clean up this (precompute)
         # MY PYTORCH: w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
         # `attention_mask` = [B, 1, F, T]
-        attention_mask = tf.expand_dims(attention_mask, axis=[1])
+        # attention_mask = tf.expand_dims(attention_mask, axis=[1])
         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
         # positions we want to attend and -10000.0 for masked positions.
-        adder = (1.0 - attention_mask) * -10000.0
+        # adder = (1.0 - attention_mask) * -10000.0
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        attention_scores += adder
+        attention_scores += attention_mask
         # Normalize the attention scores to probabilities.
         # `attention_probs` = [B, N, F, T]
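Note (standalone sketch, invented values): adding the precomputed -10000.0 term to the raw scores before the softmax drives the masked positions to essentially zero probability.

import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, 0.5, 3.0])   # raw attention scores for one query over 4 keys
mask = torch.tensor([1.0, 1.0, 0.0, 1.0])     # 1.0 = attend, 0.0 = masked

adder = (1.0 - mask) * -10000.0               # 0.0 where allowed, -10000.0 where masked
probs = F.softmax(scores + adder, dim=-1)
print(probs)                                  # the masked position ends up at ~0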
@@ -289,7 +284,7 @@ class BERTOutput(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, hidden_states, input_tensor):
-        hidden_states = self.dense(input_tensor)
+        hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
         return hidden_states
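Note (illustrative stand-in, not the repository's class): the fix matters because the dense projection should consume the sub-layer's own hidden states, while input_tensor re-enters only through the residual connection inside the layer norm.

import torch
import torch.nn as nn

class OutputSublayer(nn.Module):
    # Dense -> dropout -> residual add -> LayerNorm, mirroring the BERTOutput dataflow
    def __init__(self, hidden_size=8, dropout_prob=0.1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.LayerNorm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)   # project the sub-layer output, not the residual input
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)

layer = OutputSublayer()
x = torch.randn(2, 5, 8)
print(layer(x, x).shape)   # torch.Size([2, 5, 8])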
@@ -390,6 +385,14 @@ class BertModel(nn.Module):
         self.pooler = BERTPooler(config)

     def forward(self, input_ids, token_type_ids, attention_mask):
+        # We create 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, from_seq_length]
+        # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
+        # It's more simple than the triangular masking of causal attention, just need to
+        # prepare the broadcast here
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+        attention_mask = (1.0 - attention_mask) * -10000.0
+
         embedding_output = self.embeddings(input_ids, token_type_ids)
         all_encoder_layers = self.encoder(embedding_output, attention_mask)
         sequence_output = all_encoder_layers[-1]
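Note (toy shapes, not from the commit): how the new [batch_size, 1, 1, seq_length] additive mask broadcasts against per-head attention scores.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.float)   # [batch_size=2, seq_length=4]

extended_mask = attention_mask.unsqueeze(1).unsqueeze(2)           # [2, 1, 1, 4]
extended_mask = (1.0 - extended_mask) * -10000.0                   # 0.0 where attended, -10000.0 where padded

scores = torch.zeros(2, 3, 4, 4)                                   # [batch, num_heads=3, to_seq, from_seq]
print((scores + extended_mask).shape)                              # torch.Size([2, 3, 4, 4])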
@@ -404,11 +407,11 @@ class BertForSequenceClassification(nn.Module):
         self.classifier = nn.Linear(config.hidden_size, num_labels)

         def init_weights(m):
-            if isinstance(m) == nn.Linear or isinstance(m) == nn.Embedding:
+            if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                 print("Initializing {}".format(m))
                 # Slight difference here with the TF version which uses truncated_normal
                 # cf https://github.com/pytorch/pytorch/pull/5617
-                m.weight.normal_(config.initializer_range)
+                m.weight.data.normal_(config.initializer_range)
         self.apply(init_weights)

     def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
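Note (sketch only; the tuple form and the mean/std arguments are illustrative, not the commit's exact call): the corrected isinstance usage, combined with nn.Module.apply, which visits every submodule recursively.

import torch.nn as nn

def init_weights(m):
    # isinstance takes (object, class); isinstance(m) == nn.Linear raises a TypeError
    if isinstance(m, (nn.Linear, nn.Embedding)):
        m.weight.data.normal_(mean=0.0, std=0.02)   # illustrative values

model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))
model.apply(init_weights)   # init_weights is called on every submodule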
run_classifier_pytorch.py (view file @ 9343a231)
@@ -484,7 +484,7 @@ def main():
     num_train_steps = int(
         len(train_examples) / args.train_batch_size * args.num_train_epochs)

-    model = BertForSequenceClassification(bert_config)
+    model = BertForSequenceClassification(bert_config, len(label_list))
     if args.init_checkpoint is not None:
         model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
@@ -504,10 +504,10 @@ def main():
     logger.info(" Batch size = %d", args.train_batch_size)
     logger.info(" Num steps = %d", num_train_steps)

-    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.Long)
-    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.Long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.Long)
-    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.Long)
+    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
+    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)

     train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
     if args.local_rank == -1:
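Note on the dtype fix (runnable sketch): torch.long is the integer dtype object, whereas torch.Long is not an attribute of the torch module (the tensor class is torch.LongTensor).

import torch

ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
print(ids.dtype)   # torch.int64
# torch.tensor([1], dtype=torch.Long)   # AttributeError: module 'torch' has no attribute 'Long'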
@@ -519,12 +519,12 @@ def main():
     model.train()
     global_step = 0

     for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
-        input_ids.to(device)
-        input_mask.to(device)
-        segment_ids.to(device)
-        label_ids.to(device)
-        loss = model(input_ids, segment_ids, input_mask, label_ids)
+        input_ids = input_ids.to(device)
+        input_mask = input_mask.float().to(device)
+        segment_ids = segment_ids.to(device)
+        label_ids = label_ids.to(device)
+        loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
         loss.backward()
         optimizer.step()
         global_step += 1
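Note (minimal sketch; the device choice is illustrative): Tensor.to() is not in-place, so the result has to be assigned back, which is exactly what this hunk fixes.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = torch.ones(2, 3)

batch.to(device)           # returns a tensor on the target device, but the result is discarded
print(batch.device)        # unchanged

batch = batch.to(device)   # reassignment actually moves the batch
print(batch.device)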
@@ -538,10 +538,10 @@ def main():
     logger.info(" Num examples = %d", len(eval_examples))
     logger.info(" Batch size = %d", args.eval_batch_size)

-    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.Long)
-    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.Long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.Long)
-    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.Long)
+    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
+    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

     eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
     if args.local_rank == -1:
@@ -554,10 +554,10 @@ def main():
     eval_loss = 0
     eval_accuracy = 0

     for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
-        input_ids.to(device)
-        input_mask.to(device)
-        segment_ids.to(device)
-        label_ids.to(device)
+        input_ids = input_ids.to(device)
+        input_mask = input_mask.float().to(device)
+        segment_ids = segment_ids.to(device)
+        label_ids = label_ids.to(device)
         tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
         tmp_eval_accuracy = accuracy(logits, label_ids)