Commit d787c6be authored by thomwolf

improve docstrings and fix new token classification model

parent ed302a73
@@ -559,7 +559,7 @@ class BertModel(PreTrainedBertModel):
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
- to the last attention block,
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated with the first token of the
input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
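As an editor's note on the shapes documented above, here is a minimal sketch of how the two return values line up in practice. It assumes the pip-packaged `pytorch_pretrained_bert` import path (in older checkouts the class lives in `modeling.py`) and the forward signature with `output_all_encoded_layers`; dummy token ids stand in for a real tokenizer, and `from_pretrained` downloads the weights on first use.

```python
import torch
from pytorch_pretrained_bert import BertModel

# Dummy batch: 2 sequences of 7 token ids (only the shapes matter here).
input_ids = torch.randint(low=1, high=1000, size=(2, 7), dtype=torch.long)
token_type_ids = torch.zeros_like(input_ids)   # single-segment input
attention_mask = torch.ones_like(input_ids)    # no padding

model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

with torch.no_grad():
    # output_all_encoded_layers=False -> only the last layer's hidden states
    encoded_layers, pooled_output = model(
        input_ids, token_type_ids, attention_mask,
        output_all_encoded_layers=False)

print(encoded_layers.shape)  # torch.Size([2, 7, 768]) for bert-base
print(pooled_output.shape)   # torch.Size([2, 768])
```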
@@ -650,8 +650,8 @@ class BertForPreTraining(PreTrainedBertModel):
sentence classification loss.
if `masked_lm_labels` or `next_sentence_label` is `None`:
Outputs a tuple comprising
- - the masked language modeling logits, and
- - the next sentence classification logits.
+ - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
+ - the next sentence classification logits of shape [batch_size, 2].
Example usage:
```python
@@ -680,7 +680,7 @@ class BertForPreTraining(PreTrainedBertModel):
if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
- masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1))
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
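The fix above matters: the old line tried to call the labels tensor (`masked_lm_labels(-1)`), which raises a TypeError, while the new line reshapes it with `.view(-1)`. Below is a small, self-contained sketch of the same loss computation on dummy tensors, assuming the shapes documented earlier in this diff ([batch_size, sequence_length, vocab_size] for the MLM logits, [batch_size, 2] for the next-sentence logits) and labels of -1 at positions that should not contribute to the MLM loss; the specific label values are made up.

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 2, 7, 30522

prediction_scores = torch.randn(batch_size, seq_len, vocab_size)  # MLM logits
seq_relationship_score = torch.randn(batch_size, 2)               # NSP logits

# MLM labels: -1 everywhere except the (few) masked positions.
masked_lm_labels = torch.full((batch_size, seq_len), -1, dtype=torch.long)
masked_lm_labels[0, 3] = 1037                # pretend this token id was masked out
next_sentence_label = torch.tensor([0, 1])   # 0 = is-next, 1 = random sentence

loss_fct = CrossEntropyLoss(ignore_index=-1)  # positions labeled -1 are skipped
masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size),
                          masked_lm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2),
                              next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
print(total_loss.item())
```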
@@ -714,7 +714,7 @@ class BertForMaskedLM(PreTrainedBertModel):
if `masked_lm_labels` is not `None`:
Outputs the masked language modeling loss.
if `masked_lm_labels` is `None`:
- Outputs the masked language modeling logits.
+ Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
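As a hedged illustration of what logits of shape [batch_size, sequence_length, vocab_size] give you in practice, the sketch below takes the argmax over the vocabulary dimension at a masked position; the tensors are dummies and the masked position index is made up for the example.

```python
import torch

batch_size, seq_len, vocab_size = 1, 10, 30522
logits = torch.randn(batch_size, seq_len, vocab_size)  # stand-in for the model output

masked_index = 4  # position of the [MASK] token in this made-up example
predicted_id = logits[0, masked_index].argmax().item()
print(predicted_id)  # token id to look up in the tokenizer's vocabulary
```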
Example usage:
```python
@@ -776,7 +776,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
Outputs the next sentence classification loss.
if `next_sentence_label` is `None`:
- Outputs the next sentence classification logits.
+ Outputs the next sentence classification logits of shape [batch_size, 2].
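Since the next-sentence logits are just a [batch_size, 2] tensor, turning them into an "is next sentence" probability is one softmax. The sketch below uses dummy logits and assumes index 0 is the "is the continuation" class, as in the labeling convention used elsewhere in this file.

```python
import torch
import torch.nn.functional as F

seq_relationship_logits = torch.randn(3, 2)         # dummy [batch_size, 2] logits
probs = F.softmax(seq_relationship_logits, dim=-1)  # per-example class probabilities
is_next_prob = probs[:, 0]                          # index 0 = next sentence is the continuation
print(is_next_prob)
```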
Example usage:
```python
@@ -838,7 +838,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
- Outputs the classification logits.
+ Outputs the classification logits of shape [batch_size, num_labels].
Example usage:
```python
@@ -872,7 +872,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
- return loss, logits
+ return loss
else:
return logits
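The return-value change above means callers now get a single scalar loss when labels are supplied and raw logits otherwise. Here is a dummy-tensor sketch of both paths, mirroring the forward code with the [batch_size, num_labels] shape from the docstring; the label values are made up.

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, num_labels = 4, 3
logits = torch.randn(batch_size, num_labels)  # stand-in for the [batch_size, num_labels] output
labels = torch.tensor([0, 2, 1, 2])           # gold class per example

# Training path: same flattening as in forward(); view(-1, num_labels) is a no-op
# reshape here, but it keeps the call identical to the token-level case.
loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))

# Inference path (labels is None): consume the logits directly.
predictions = logits.argmax(dim=-1)
print(loss.item(), predictions.tolist())
```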
@@ -904,7 +904,7 @@ class BertForTokenClassification(PreTrainedBertModel):
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
- Outputs the classification logits.
+ Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
Example usage:
```python
@@ -938,7 +938,7 @@ class BertForTokenClassification(PreTrainedBertModel):
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
- return loss, logits
+ return loss
else:
return logits
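BertForTokenClassification uses the same CrossEntropyLoss call, but here `view(-1, num_labels)` actually folds the batch and sequence dimensions together so every token position becomes one classification example. A dummy-tensor sketch of that flattening, with shapes as documented above and made-up label values:

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, num_labels = 2, 5, 9                     # e.g. 9 NER tags
logits = torch.randn(batch_size, seq_len, num_labels)         # [batch, seq, num_labels]
labels = torch.randint(0, num_labels, (batch_size, seq_len))  # one tag id per token

flat_logits = logits.view(-1, num_labels)  # [batch*seq, num_labels]
flat_labels = labels.view(-1)              # [batch*seq]
loss = CrossEntropyLoss()(flat_logits, flat_labels)
print(flat_logits.shape, loss.item())
```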
@@ -982,7 +982,7 @@ class BertForQuestionAnswering(PreTrainedBertModel):
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
- position tokens.
+ position tokens of shape [batch_size, sequence_length].
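A hedged sketch of how the two [batch_size, sequence_length] logit tensors are typically turned into an answer span: take the argmax of each independently. Real SQuAD decoding also constrains end >= start and limits span length, which is omitted here; the tensors are dummies.

```python
import torch

batch_size, seq_len = 2, 12
start_logits = torch.randn(batch_size, seq_len)  # [batch_size, sequence_length]
end_logits = torch.randn(batch_size, seq_len)

start_idx = start_logits.argmax(dim=-1)  # predicted start position per example
end_idx = end_logits.argmax(dim=-1)      # predicted end position per example
print(start_idx.tolist(), end_idx.tolist())
```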
Example usage:
```python