Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a725db4f
Commit
a725db4f
authored
Nov 05, 2018
by
thomwolf
Browse files
fixing BertForQuestionAnswering loss computation
parent
bb5ce67a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
29 deletions
+21
-29
modeling.py
modeling.py
+21
-29
No files found.
modeling.py
View file @
a725db4f
...
@@ -384,16 +384,16 @@ class BertForSequenceClassification(nn.Module):
...
@@ -384,16 +384,16 @@ class BertForSequenceClassification(nn.Module):
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
num_labels
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
num_labels
)
def
init_weights
(
m
):
def
init_weights
(
m
odule
):
if
isinstance
(
m
,
(
nn
.
Linear
,
nn
.
Embedding
)):
if
isinstance
(
m
odule
,
(
nn
.
Linear
,
nn
.
Embedding
)):
# Slightly different from the TF version which uses truncated_normal for initialization
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
# cf https://github.com/pytorch/pytorch/pull/5617
m
.
weight
.
data
.
normal_
(
config
.
initializer_range
)
m
odule
.
weight
.
data
.
normal_
(
config
.
initializer_range
)
elif
isinstance
(
m
,
BERTLayerNorm
):
elif
isinstance
(
m
odule
,
BERTLayerNorm
):
m
.
beta
.
data
.
normal_
(
config
.
initializer_range
)
m
odule
.
beta
.
data
.
normal_
(
config
.
initializer_range
)
m
.
gamma
.
data
.
normal_
(
config
.
initializer_range
)
m
odule
.
gamma
.
data
.
normal_
(
config
.
initializer_range
)
if
isinstance
(
m
,
nn
.
Linear
):
if
isinstance
(
m
odule
,
nn
.
Linear
):
m
.
bias
.
data
.
zero_
()
m
odule
.
bias
.
data
.
zero_
()
self
.
apply
(
init_weights
)
self
.
apply
(
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
,
attention_mask
,
labels
=
None
):
def
forward
(
self
,
input_ids
,
token_type_ids
,
attention_mask
,
labels
=
None
):
...
@@ -434,16 +434,16 @@ class BertForQuestionAnswering(nn.Module):
...
@@ -434,16 +434,16 @@ class BertForQuestionAnswering(nn.Module):
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self
.
qa_outputs
=
nn
.
Linear
(
config
.
hidden_size
,
2
)
self
.
qa_outputs
=
nn
.
Linear
(
config
.
hidden_size
,
2
)
def
init_weights
(
m
):
def
init_weights
(
m
odule
):
if
isinstance
(
m
,
(
nn
.
Linear
,
nn
.
Embedding
)):
if
isinstance
(
m
odule
,
(
nn
.
Linear
,
nn
.
Embedding
)):
# Slightly different from the TF version which uses truncated_normal for initialization
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
# cf https://github.com/pytorch/pytorch/pull/5617
m
.
weight
.
data
.
normal_
(
config
.
initializer_range
)
m
odule
.
weight
.
data
.
normal_
(
config
.
initializer_range
)
elif
isinstance
(
m
,
BERTLayerNorm
):
elif
isinstance
(
m
odule
,
BERTLayerNorm
):
m
.
beta
.
data
.
normal_
(
config
.
initializer_range
)
m
odule
.
beta
.
data
.
normal_
(
config
.
initializer_range
)
m
.
gamma
.
data
.
normal_
(
config
.
initializer_range
)
m
odule
.
gamma
.
data
.
normal_
(
config
.
initializer_range
)
if
isinstance
(
m
,
nn
.
Linear
):
if
isinstance
(
m
odule
,
nn
.
Linear
):
m
.
bias
.
data
.
zero_
()
m
odule
.
bias
.
data
.
zero_
()
self
.
apply
(
init_weights
)
self
.
apply
(
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
,
attention_mask
,
start_positions
=
None
,
end_positions
=
None
):
def
forward
(
self
,
input_ids
,
token_type_ids
,
attention_mask
,
start_positions
=
None
,
end_positions
=
None
):
...
@@ -451,21 +451,13 @@ class BertForQuestionAnswering(nn.Module):
...
@@ -451,21 +451,13 @@ class BertForQuestionAnswering(nn.Module):
sequence_output
=
all_encoder_layers
[
-
1
]
sequence_output
=
all_encoder_layers
[
-
1
]
logits
=
self
.
qa_outputs
(
sequence_output
)
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
if
start_positions
is
not
None
and
end_positions
is
not
None
:
if
start_positions
is
not
None
and
end_positions
is
not
None
:
batch_size
,
seq_length
=
input_ids
.
size
()
loss_fct
=
CrossEntropyLoss
()
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
def
compute_loss
(
logits
,
positions
):
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
max_position
=
positions
.
max
().
item
()
one_hot
=
torch
.
FloatTensor
(
batch_size
,
max
(
max_position
,
seq_length
)
+
1
).
zero_
()
one_hot
=
one_hot
.
scatter_
(
1
,
positions
.
cpu
(),
1
)
# Do this on CPU
one_hot
=
one_hot
[:,
:
seq_length
].
to
(
input_ids
.
device
)
log_probs
=
nn
.
functional
.
log_softmax
(
logits
,
dim
=
-
1
).
view
(
batch_size
,
seq_length
)
loss
=
-
torch
.
mean
(
torch
.
sum
(
one_hot
*
log_probs
),
dim
=
-
1
)
return
loss
start_loss
=
compute_loss
(
start_logits
,
start_positions
)
end_loss
=
compute_loss
(
end_logits
,
end_positions
)
total_loss
=
(
start_loss
+
end_loss
)
/
2
total_loss
=
(
start_loss
+
end_loss
)
/
2
return
total_loss
,
(
start_logits
,
end_logits
)
return
total_loss
,
(
start_logits
,
end_logits
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment