Commit d787c6be authored by thomwolf

improve docstrings and fix new token classification model

parent ed302a73
@@ -559,7 +559,7 @@ class BertModel(PreTrainedBertModel):
 of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
 encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
 - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
-to the last attention block,
+to the last attention block of shape [batch_size, sequence_length, hidden_size],
 `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
 classifier pretrained on top of the hidden state associated to the first character of the
 input (`CLF`) to train on the Next-Sentence task (see BERT's paper).
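For context, a minimal sketch of how a caller consumes the two return values described in this docstring (this code is not part of the commit; the `pytorch_pretrained_bert` import path and the `bert-base-uncased` weights name are assumptions that may differ in your checkout):

```python
import torch
from pytorch_pretrained_bert import BertModel  # assumption: adjust to wherever modeling.py lives

model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

input_ids = torch.tensor([[31, 51, 99], [15, 5, 0]])   # [batch_size, sequence_length]
token_type_ids = torch.tensor([[0, 0, 1], [0, 1, 0]])
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

# With output_all_encoded_layers=False, encoded_layers is a single tensor of shape
# [batch_size, sequence_length, hidden_size] rather than a list of 12 (or 24) such tensors.
encoded_layers, pooled_output = model(input_ids, token_type_ids, attention_mask,
                                      output_all_encoded_layers=False)
print(encoded_layers.shape)   # e.g. torch.Size([2, 3, 768]) for BERT-base
print(pooled_output.shape)    # e.g. torch.Size([2, 768])
```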
@@ -650,8 +650,8 @@ class BertForPreTraining(PreTrainedBertModel):
 sentence classification loss.
 if `masked_lm_labels` or `next_sentence_label` is `None`:
 Outputs a tuple comprising
-- the masked language modeling logits, and
-- the next sentence classification logits.
+- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
+- the next sentence classification logits of shape [batch_size, 2].
 Example usage:
 ```python
@@ -680,7 +680,7 @@ class BertForPreTraining(PreTrainedBertModel):
 if masked_lm_labels is not None and next_sentence_label is not None:
 loss_fct = CrossEntropyLoss(ignore_index=-1)
-masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1))
+masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
 next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
 total_loss = masked_lm_loss + next_sentence_loss
 return total_loss
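The fixed line matters because CrossEntropyLoss expects [N, C] scores and [N] targets; `masked_lm_labels(-1)` would try to call the tensor, while `.view(-1)` flattens it. A standalone sketch of the same computation with made-up tensors (illustrative only, not code from the commit):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 2, 5, 30522
prediction_scores = torch.randn(batch_size, seq_len, vocab_size)        # masked-LM logits
masked_lm_labels = torch.full((batch_size, seq_len), -1, dtype=torch.long)
masked_lm_labels[0, 2] = 1037   # only the actually masked positions carry a real label
masked_lm_labels[1, 4] = 2003   # everything left at -1 is skipped via ignore_index

loss_fct = CrossEntropyLoss(ignore_index=-1)
# Flatten scores to [batch*seq, vocab_size] and labels to [batch*seq], as in the fixed line.
masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size),
                          masked_lm_labels.view(-1))
print(masked_lm_loss)
```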
@@ -714,7 +714,7 @@ class BertForMaskedLM(PreTrainedBertModel):
 if `masked_lm_labels` is not `None`:
 Outputs the masked language modeling loss.
 if `masked_lm_labels` is `None`:
-Outputs the masked language modeling logits.
+Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
 Example usage:
 ```python
@@ -776,7 +776,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
 Outputs the total_loss which is the
 next sentence classification loss.
 if `next_sentence_label` is `None`:
-Outputs the next sentence classification logits.
+Outputs the next sentence classification logits of shape [batch_size, 2].
 Example usage:
 ```python
@@ -838,7 +838,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
 if `labels` is not `None`:
 Outputs the CrossEntropy classification loss of the output with the labels.
 if `labels` is `None`:
-Outputs the classification logits.
+Outputs the classification logits of shape [batch_size, num_labels].
 Example usage:
 ```python
@@ -872,7 +872,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
 if labels is not None:
 loss_fct = CrossEntropyLoss()
 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-return loss, logits
+return loss
 else:
 return logits
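A hedged usage sketch of the contract after this change: passing `labels` now yields only the scalar loss, while omitting them yields the [batch_size, num_labels] logits. This is not code from the commit; the weights name is a placeholder, and passing `num_labels` through `from_pretrained` is an assumption (constructing the model from a `BertConfig`, as the file's own docstring examples do, works the same way):

```python
import torch
from pytorch_pretrained_bert import BertForSequenceClassification  # assumed import path

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

input_ids = torch.tensor([[31, 51, 99], [15, 5, 0]])
token_type_ids = torch.tensor([[0, 0, 1], [0, 1, 0]])
labels = torch.tensor([1, 0])

# Training step: with labels the forward pass now returns only the scalar loss...
loss = model(input_ids, token_type_ids, labels=labels)
loss.backward()

# ...so to inspect predictions, call again without labels to get the logits.
logits = model(input_ids, token_type_ids)      # shape [batch_size, num_labels]
predictions = logits.argmax(dim=-1)
```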
@@ -904,7 +904,7 @@ class BertForTokenClassification(PreTrainedBertModel):
 if `labels` is not `None`:
 Outputs the CrossEntropy classification loss of the output with the labels.
 if `labels` is `None`:
-Outputs the classification logits.
+Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
 Example usage:
 ```python
@@ -938,7 +938,7 @@ class BertForTokenClassification(PreTrainedBertModel):
 if labels is not None:
 loss_fct = CrossEntropyLoss()
 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-return loss, logits
+return loss
 else:
 return logits
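The token classification head applies the same loss pattern per token. The standalone sketch below (illustrative shapes only, not code from the commit) shows how the [batch_size, sequence_length, num_labels] logits are flattened for CrossEntropyLoss and how per-token predictions are recovered at inference:

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, num_labels = 2, 6, 9                 # e.g. 9 NER tags (assumed)
logits = torch.randn(batch_size, seq_len, num_labels)     # [batch, seq, num_labels]
labels = torch.randint(0, num_labels, (batch_size, seq_len))

# Training: flatten to [batch*seq, num_labels] scores vs. [batch*seq] targets,
# exactly as in the forward pass shown above.
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))

# Inference: one predicted label per token.
predictions = logits.argmax(dim=-1)                       # [batch, seq]
```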
@@ -982,7 +982,7 @@ class BertForQuestionAnswering(PreTrainedBertModel):
 Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
 if `start_positions` or `end_positions` is `None`:
 Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
-position tokens.
+position tokens of shape [batch_size, sequence_length].
 Example usage:
 ```python
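To make the start/end logits shape concrete, here is a hedged inference sketch (not from the commit, tensors made up for illustration) that turns the two [batch_size, sequence_length] logit tensors into a predicted answer span:

```python
import torch

batch_size, seq_len = 1, 8
start_logits = torch.randn(batch_size, seq_len)    # [batch_size, sequence_length]
end_logits = torch.randn(batch_size, seq_len)      # [batch_size, sequence_length]

# Naive span decoding: pick the highest-scoring start and end independently.
# (A real decoder would also enforce start <= end and a maximum span length.)
start_index = start_logits.argmax(dim=-1)          # [batch_size]
end_index = end_logits.argmax(dim=-1)              # [batch_size]
print(start_index.item(), end_index.item())        # token positions of the predicted span
```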