chenpangpang / transformers · Commits

Commit c0065af6, authored Nov 02, 2018 by thomwolf
implemented BertForQuestionAnswering
parent 5383fca4
Showing 1 changed file with 52 additions and 3 deletions

modeling_pytorch.py  (+52, −3)
...
@@ -404,12 +404,12 @@ class BertForSequenceClassification(nn.Module):
     input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
     token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
 
-    config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
+    config = BertConfig(vocab_size=32000, hidden_size=512,
         num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
 
     num_labels = 2
-    model = modeling.BertModel(config, num_labels)
+    model = BertForSequenceClassification(config, num_labels)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
...
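This hunk fixes a copy-paste error in the class docstring: the example previously constructed `modeling.BertModel` instead of `BertForSequenceClassification`. As a sanity check, the corrected example composes into the runnable sketch below; it is minimal and assumes `BertConfig` and `BertForSequenceClassification` can be imported from this commit's `modeling_pytorch` module, and that `forward` returns logits when no labels are passed (as the hunk at line ~438 suggests):

```python
import torch
from modeling_pytorch import BertConfig, BertForSequenceClassification  # assumed import path

# Toy inputs from the docstring: batch of 2, sequence length 3,
# already converted into WordPiece token ids.
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])

config = BertConfig(vocab_size=32000, hidden_size=512,
                    num_hidden_layers=8, num_attention_heads=6,
                    intermediate_size=1024)

model = BertForSequenceClassification(config, num_labels=2)
logits = model(input_ids, token_type_ids, input_mask)  # shape (2, 2)
```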
@@ -420,7 +420,7 @@ class BertForSequenceClassification(nn.Module):
         self.classifier = nn.Linear(config.hidden_size, num_labels)
 
         def init_weights(m):
-            if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+            if isinstance(m, (nn.Linear, nn.Embedding)):
                 print("Initializing {}".format(m))
                 # Slight difference here with the TF version which uses truncated_normal
                 # cf https://github.com/pytorch/pytorch/pull/5617
...
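The comment kept in this hunk points at a gap with the TF reference implementation, which initializes weights with `tf.truncated_normal_initializer`, while PyTorch had no built-in equivalent at the time (see pytorch/pytorch#5617); this port falls back to a plain normal distribution. A hedged sketch of how one could approximate the TF behavior, using a hypothetical `truncated_normal_` helper that is not part of this commit:

```python
import torch

def truncated_normal_(tensor, mean=0.0, std=0.02):
    # Hypothetical helper, not in this commit: sample from a normal
    # distribution and redraw any value farther than 2 std devs from
    # the mean, roughly matching tf.truncated_normal_initializer.
    with torch.no_grad():
        tensor.normal_(mean, std)
        invalid = (tensor - mean).abs() > 2 * std
        while invalid.any():
            tensor[invalid] = torch.empty_like(tensor[invalid]).normal_(mean, std)
            invalid = (tensor - mean).abs() > 2 * std
    return tensor
```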
@@ -438,3 +438,52 @@ class BertForSequenceClassification(nn.Module):
             return loss, logits
         else:
             return logits
+
+
+class BertForQuestionAnswering(nn.Module):
+    """BERT model for Question Answering (span extraction).
+
+    This module is composed of the BERT model with linear layers on top of
+    the sequence output.
+
+    Example usage:
+    ```python
+    # Already been converted into WordPiece token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
+
+    config = BertConfig(vocab_size=32000, hidden_size=512,
+        num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
+
+    model = BertForQuestionAnswering(config)
+    logits = model(input_ids, token_type_ids, input_mask)
+    ```
+    """
+    def __init__(self, config):
+        super(BertForQuestionAnswering, self).__init__()
+        self.bert = BertModel(config)
+        # TODO check if it's normal there is no dropout on SQuAD in the TF version
+        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+        def init_weights(m):
+            if isinstance(m, (nn.Linear, nn.Embedding)):
+                print("Initializing {}".format(m))
+                # Slight difference here with the TF version which uses truncated_normal
+                # cf https://github.com/pytorch/pytorch/pull/5617
+                m.weight.data.normal_(config.initializer_range)
+        self.apply(init_weights)
+
+    def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
+        all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
+        sequence_output = all_encoder_layers[-1]
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+
+        if start_positions is not None and end_positions is not None:
+            loss_fct = CrossEntropyLoss()
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            return total_loss, (start_logits, end_logits)
+        else:
+            return start_logits, end_logits
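To see the new head end to end, here is a minimal usage sketch built from the toy inputs in the docstring, again assuming the same `modeling_pytorch` import path. One detail specific to this revision: `split(1, dim=-1)` keeps the split dimension, so each logits tensor comes back as `(batch, seq_len, 1)` and is squeezed before taking the argmax:

```python
import torch
from modeling_pytorch import BertConfig, BertForQuestionAnswering  # assumed import path

# Toy inputs from the docstring (batch of 2, sequence length 3).
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])

config = BertConfig(vocab_size=32000, hidden_size=512,
                    num_hidden_layers=8, num_attention_heads=6,
                    intermediate_size=1024)
model = BertForQuestionAnswering(config)

# Without gold positions, forward returns the raw span logits.
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)

# Each tensor is (batch, seq_len, 1) in this revision; squeeze the
# trailing dimension before picking the most likely span boundaries.
pred_start = start_logits.squeeze(-1).argmax(dim=-1)  # shape (2,)
pred_end = end_logits.squeeze(-1).argmax(dim=-1)      # shape (2,)
```

When `start_positions` and `end_positions` are supplied, `forward` instead returns the cross-entropy loss averaged over the two boundaries, `(start_loss + end_loss) / 2`, alongside the logits, as the diff above shows.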