Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
213981d8
Commit
213981d8
authored
Jun 28, 2019
by
thomwolf
Browse files
updating bert API
parent
2b56e988
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
72 additions
and
78 deletions
+72
-78
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+72
-78
No files found.
pytorch_pretrained_bert/modeling.py
View file @
213981d8
...
...
@@ -814,31 +814,28 @@ class BertForMaskedLM(BertPreTrainedModel):
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
output_hidden_states
=
False
):
super
(
BertForMaskedLM
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
output_hidden_states
=
output_hidden_states
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
cls
=
BertOnlyMLMHead
(
config
,
self
.
bert
.
embeddings
.
word_embeddings
.
weight
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
masked_lm_labels
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
token_type_ids
,
attention_mask
,
output_all_encoded_layers
=
False
,
head_mask
=
head_mask
)
if
self
.
output_attentions
:
all_attentions
,
sequence_output
,
_
=
outputs
else
:
sequence_output
,
_
=
outputs
outputs
=
self
.
bert
(
input_ids
,
token_type_ids
,
attention_mask
,
head_mask
=
head_mask
)
sequence_output
=
outputs
[
0
]
prediction_scores
=
self
.
cls
(
sequence_output
)
outputs
=
[
prediction_scores
]
+
outputs
[
2
:]
# Add hidden states and attention is they are here
if
masked_lm_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
masked_lm_loss
=
loss_fct
(
prediction_scores
.
view
(
-
1
,
self
.
config
.
vocab_size
),
masked_lm_labels
.
view
(
-
1
))
return
masked_lm_loss
elif
self
.
output_attentions
:
return
all_attentions
,
prediction_scores
return
prediction_scores
outputs
=
[
masked_lm_loss
]
+
outputs
return
outputs
# (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
class
BertForNextSentencePrediction
(
BertPreTrainedModel
):
...
...
@@ -889,31 +886,29 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
output_hidden_states
=
False
):
super
(
BertForNextSentencePrediction
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
output_hidden_states
=
output_hidden_states
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
cls
=
BertOnlyNSPHead
(
config
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
next_sentence_label
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
token_type_ids
,
attention_mask
,
output_all_encoded_layers
=
False
,
head_mask
=
head_mask
)
if
self
.
output_attentions
:
all_attentions
,
_
,
pooled_output
=
outputs
else
:
_
,
pooled_output
=
outputs
outputs
=
self
.
bert
(
input_ids
,
token_type_ids
,
attention_mask
,
head_mask
=
head_mask
)
pooled_output
=
outputs
[
1
]
seq_relationship_score
=
self
.
cls
(
pooled_output
)
outputs
=
[
seq_relationship_score
]
+
outputs
[
2
:]
# add hidden states and attention if they are here
if
next_sentence_label
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
next_sentence_loss
=
loss_fct
(
seq_relationship_score
.
view
(
-
1
,
2
),
next_sentence_label
.
view
(
-
1
))
return
next_sentence_loss
elif
self
.
output_attentions
:
return
all_attentions
,
seq_relationship_score
return
seq_relationship_score
outputs
=
[
next_sentence_loss
]
+
outputs
return
outputs
# (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
class
BertForSequenceClassification
(
BertPreTrainedModel
):
...
...
@@ -966,25 +961,27 @@ class BertForSequenceClassification(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
num_labels
=
2
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
,
num_labels
=
2
,
output_attentions
=
False
,
output_hidden_states
=
False
):
super
(
BertForSequenceClassification
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
output_hidden_states
=
output_hidden_states
self
.
num_labels
=
num_labels
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
num_labels
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
labels
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
token_type_ids
,
attention_mask
,
output_all_encoded_layers
=
False
,
head_mask
=
head_mask
)
if
self
.
output_attentions
:
all_attentions
,
_
,
pooled_output
=
outputs
else
:
_
,
pooled_output
=
outputs
outputs
=
self
.
bert
(
input_ids
,
token_type_ids
,
attention_mask
,
head_mask
=
head_mask
)
pooled_output
=
outputs
[
1
]
pooled_output
=
self
.
dropout
(
pooled_output
)
logits
=
self
.
classifier
(
pooled_output
)
outputs
=
[
logits
]
+
outputs
[
2
:]
# add hidden states and attention if they are here
if
labels
is
not
None
:
if
self
.
num_labels
==
1
:
# We are doing regression
...
...
@@ -993,10 +990,9 @@ class BertForSequenceClassification(BertPreTrainedModel):
else
:
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
return
loss
elif
self
.
output_attentions
:
return
all_attentions
,
logits
return
logits
outputs
=
[
loss
]
+
outputs
return
outputs
# (loss), logits, (hidden_states), (attentions)
class
BertForMultipleChoice
(
BertPreTrainedModel
):
...
...
@@ -1048,36 +1044,37 @@ class BertForMultipleChoice(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
num_choices
=
2
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
,
num_choices
=
2
,
output_attentions
=
False
,
output_hidden_states
=
False
):
super
(
BertForMultipleChoice
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
output_hidden_states
=
output_hidden_states
self
.
num_choices
=
num_choices
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
1
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
labels
=
None
,
head_mask
=
None
):
flat_input_ids
=
input_ids
.
view
(
-
1
,
input_ids
.
size
(
-
1
))
flat_token_type_ids
=
token_type_ids
.
view
(
-
1
,
token_type_ids
.
size
(
-
1
))
if
token_type_ids
is
not
None
else
None
flat_attention_mask
=
attention_mask
.
view
(
-
1
,
attention_mask
.
size
(
-
1
))
if
attention_mask
is
not
None
else
None
outputs
=
self
.
bert
(
flat_input_ids
,
flat_token_type_ids
,
flat_attention_mask
,
output_all_encoded_layers
=
False
,
head_mask
=
head_mask
)
if
self
.
output_attentions
:
all_attentions
,
_
,
pooled_output
=
outputs
else
:
_
,
pooled_output
=
outputs
outputs
=
self
.
bert
(
flat_input_ids
,
flat_token_type_ids
,
flat_attention_mask
,
head_mask
=
head_mask
)
pooled_output
=
outputs
[
1
]
pooled_output
=
self
.
dropout
(
pooled_output
)
logits
=
self
.
classifier
(
pooled_output
)
reshaped_logits
=
logits
.
view
(
-
1
,
self
.
num_choices
)
outputs
=
[
reshaped_logits
]
+
outputs
[
2
:]
# add hidden states and attention if they are here
if
labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
reshaped_logits
,
labels
)
return
loss
elif
self
.
output_attentions
:
return
all_attentions
,
reshaped_logits
return
reshaped_logits
outputs
=
[
loss
]
+
outputs
return
outputs
# (loss), reshaped_logits, (hidden_states), (attentions)
class
BertForTokenClassification
(
BertPreTrainedModel
):
...
...
@@ -1130,25 +1127,26 @@ class BertForTokenClassification(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
num_labels
=
2
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
,
num_labels
=
2
,
output_attentions
=
False
,
output_hidden_states
=
False
):
super
(
BertForTokenClassification
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
output_hidden_states
=
output_hidden_states
self
.
num_labels
=
num_labels
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
num_labels
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
labels
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
token_type_ids
,
attention_mask
,
output_all_encoded_layers
=
False
,
head_mask
=
head_mask
)
if
self
.
output_attentions
:
all_attentions
,
sequence_output
,
_
=
outputs
else
:
sequence_output
,
_
=
outputs
outputs
=
self
.
bert
(
input_ids
,
token_type_ids
,
attention_mask
,
head_mask
=
head_mask
)
sequence_output
=
outputs
[
0
]
sequence_output
=
self
.
dropout
(
sequence_output
)
logits
=
self
.
classifier
(
sequence_output
)
outputs
=
[
logits
]
+
outputs
[
2
:]
# add hidden states and attention if they are here
if
labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
# Only keep active parts of the loss
...
...
@@ -1159,10 +1157,9 @@ class BertForTokenClassification(BertPreTrainedModel):
loss
=
loss_fct
(
active_logits
,
active_labels
)
else
:
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
return
loss
elif
self
.
output_attentions
:
return
all_attentions
,
logits
return
logits
outputs
=
[
loss
]
+
outputs
return
outputs
# (loss), logits, (hidden_states), (attentions)
class
BertForQuestionAnswering
(
BertPreTrainedModel
):
...
...
@@ -1217,28 +1214,26 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
output_hidden_states
=
False
):
super
(
BertForQuestionAnswering
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
output_hidden_states
=
output_hidden_states
self
.
bert
=
BertModel
(
config
,
output_attentions
=
output_attentions
)
self
.
qa_outputs
=
nn
.
Linear
(
config
.
hidden_size
,
2
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
start_positions
=
None
,
end_positions
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
token_type_ids
,
attention_mask
,
output_all_encoded_layers
=
False
,
head_mask
=
head_mask
)
if
self
.
output_attentions
:
all_attentions
,
sequence_output
,
_
=
outputs
else
:
sequence_output
,
_
=
outputs
outputs
=
self
.
bert
(
input_ids
,
token_type_ids
,
attention_mask
,
head_mask
=
head_mask
)
sequence_output
=
outputs
[
0
]
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
outputs
=
[
start_logits
,
end_logits
]
+
outputs
[
2
:]
if
start_positions
is
not
None
and
end_positions
is
not
None
:
# If we are on multi-GPU, split add a dimension
if
len
(
start_positions
.
size
())
>
1
:
...
...
@@ -1254,7 +1249,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
total_loss
=
(
start_loss
+
end_loss
)
/
2
return
total_loss
elif
self
.
output_attentions
:
return
all_attentions
,
start_logits
,
end_logits
return
start_logits
,
end_logits
outputs
=
[
total_loss
]
+
outputs
return
outputs
# (loss), start_logits, end_logits, (hidden_states), (attentions)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment