chenpangpang / transformers · Commits

Commit d787c6be authored Nov 30, 2018 by thomwolf

improve docstrings and fix new token classification model

parent ed302a73
Showing 1 changed file with 11 additions and 11 deletions.
pytorch_pretrained_bert/modeling.py  +11 -11  (view file @ d787c6be)
...
...
@@ -559,7 +559,7 @@ class BertModel(PreTrainedBertModel):
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
-                to the last attention block,
+                to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated to the first character of the
input (`CLF`) to train on the Next-Sentence task (see BERT's paper).
...
...
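As a quick check of the shapes the updated BertModel docstring promises, here is a minimal sketch (not part of this commit). It uses a randomly initialized BERT-base-sized config so it runs without downloads; the constructor and forward signature are assumed from the file's own `Example usage` blocks, and the toy inputs are borrowed from them.

```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

# Toy WordPiece ids, attention mask and segment ids (from the docstring examples).
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

# Randomly initialized BERT-base-sized model so the snippet needs no pretrained weights.
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertModel(config)
model.eval()

# output_all_encoded_layers=False: only the hidden states of the last attention block.
sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask,
                                       output_all_encoded_layers=False)
print(sequence_output.shape)  # [batch_size, sequence_length, hidden_size] -> torch.Size([2, 3, 768])
print(pooled_output.shape)    # [batch_size, hidden_size]                  -> torch.Size([2, 768])
```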
@@ -650,8 +650,8 @@ class BertForPreTraining(PreTrainedBertModel):
sentence classification loss.
if `masked_lm_labels` or `next_sentence_label` is `None`:
Outputs a tuple comprising
-                - the masked language modeling logits, and
-                - the next sentence classification logits.
+                - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
+                - the next sentence classification logits of shape [batch_size, 2].
Example usage:
```python
...
...
@@ -680,7 +680,7 @@ class BertForPreTraining(PreTrainedBertModel):
         if masked_lm_labels is not None and next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1))
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
             return total_loss
...
...
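The removed line treated `masked_lm_labels` as a callable (`masked_lm_labels(-1)`), which would raise a TypeError at runtime; the fix flattens the label tensor with `.view(-1)` so that CrossEntropyLoss receives 2-D logits and 1-D targets. A minimal pure-PyTorch sketch of that flattening with toy shapes (not the library's own code):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 2, 3, 32000

# Logits of shape [batch_size, sequence_length, vocab_size], as the docstring now states.
prediction_scores = torch.randn(batch_size, seq_len, vocab_size)
# -1 marks positions that were not masked; ignore_index=-1 keeps them out of the loss.
masked_lm_labels = torch.LongTensor([[5, -1, 99], [-1, 7, -1]])

loss_fct = CrossEntropyLoss(ignore_index=-1)

# CrossEntropyLoss expects [N, C] logits and [N] targets, hence the .view() calls.
masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size),
                          masked_lm_labels.view(-1))
print(masked_lm_loss.item())

# The old code, masked_lm_labels(-1), would instead have raised:
# TypeError: 'Tensor' object is not callable
```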
@@ -714,7 +714,7 @@ class BertForMaskedLM(PreTrainedBertModel):
if `masked_lm_labels` is not `None`:
Outputs the masked language modeling loss.
if `masked_lm_labels` is `None`:
-            Outputs the masked language modeling logits.
+            Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
Example usage:
```python
...
...
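A small sketch (not from the commit) of the two return modes the BertForMaskedLM docstring describes, again with a randomly initialized config; the forward signature is assumed from the surrounding file.

```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForMaskedLM

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForMaskedLM(config)
model.eval()

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
masked_lm_labels = torch.LongTensor([[-1, 51, -1], [-1, -1, 0]])  # -1 = not masked

# With masked_lm_labels: a scalar masked-LM loss.
loss = model(input_ids, token_type_ids, input_mask, masked_lm_labels)
print(loss.shape)  # torch.Size([])

# Without labels: logits of shape [batch_size, sequence_length, vocab_size].
predictions = model(input_ids, token_type_ids, input_mask)
print(predictions.shape)  # torch.Size([2, 3, 32000])
```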
@@ -776,7 +776,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss.
if `next_sentence_label` is `None`:
-            Outputs the next sentence classification logits.
+            Outputs the next sentence classification logits of shape [batch_size, 2].
Example usage:
```python
...
...
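Likewise, a hedged sketch of the [batch_size, 2] logits the BertForNextSentencePrediction docstring now spells out (randomly initialized weights; signatures assumed from this file):

```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForNextSentencePrediction

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForNextSentencePrediction(config)
model.eval()

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

# Without next_sentence_label: logits over {IsNext, NotNext}, shape [batch_size, 2].
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
print(seq_relationship_logits.shape)  # torch.Size([2, 2])
```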
@@ -838,7 +838,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
-            Outputs the classification logits.
+            Outputs the classification logits of shape [batch_size, num_labels].
Example usage:
```python
...
...
@@ -872,7 +872,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return loss, logits
+            return loss
         else:
             return logits
...
...
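Below the docstring change, the forward pass itself now returns only the loss when `labels` is given (previously `loss, logits`). A small sketch of both call modes, assuming the `num_labels` constructor argument used elsewhere in this file:

```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForSequenceClassification(config, num_labels=2)
model.eval()

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
labels = torch.LongTensor([1, 0])  # one class id per sequence

# With labels: after this commit, only the scalar CrossEntropy loss is returned.
loss = model(input_ids, token_type_ids, input_mask, labels)
print(loss.shape)  # torch.Size([])

# Without labels: classification logits of shape [batch_size, num_labels].
logits = model(input_ids, token_type_ids, input_mask)
print(logits.shape)  # torch.Size([2, 2])
```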
@@ -904,7 +904,7 @@ class BertForTokenClassification(PreTrainedBertModel):
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
-            Outputs the classification logits.
+            Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
Example usage:
```python
...
...
@@ -938,7 +938,7 @@ class BertForTokenClassification(PreTrainedBertModel):
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return loss, logits
+            return loss
         else:
             return logits
...
...
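The same return-value fix is applied to the new token classification head mentioned in the commit message. A hedged sketch of its two call modes (randomly initialized weights; the `num_labels` argument is assumed to match the class constructor):

```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForTokenClassification

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForTokenClassification(config, num_labels=5)
model.eval()

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
labels = torch.LongTensor([[0, 3, 4], [2, 1, 0]])  # one class id per token

# With labels: only the scalar CrossEntropy loss is returned after this commit.
loss = model(input_ids, token_type_ids, input_mask, labels)
print(loss.shape)  # torch.Size([])

# Without labels: per-token logits of shape [batch_size, sequence_length, num_labels].
logits = model(input_ids, token_type_ids, input_mask)
print(logits.shape)  # torch.Size([2, 3, 5])
```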
@@ -982,7 +982,7 @@ class BertForQuestionAnswering(PreTrainedBertModel):
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
-            position tokens.
+            position tokens of shape [batch_size, sequence_length].
Example usage:
```python
...
...
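Finally, a sketch for the question-answering head: without `start_positions`/`end_positions` it returns a tuple of span logits, each of shape [batch_size, sequence_length] (weights randomly initialized; signatures assumed from this file):

```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForQuestionAnswering

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForQuestionAnswering(config)
model.eval()

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

# Without start/end positions: a tuple of span logits, each [batch_size, sequence_length].
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
print(start_logits.shape, end_logits.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```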