Unverified commit 9de9ceb6, authored by Anirudh Srinivasan, committed by GitHub

Correct output shape for Bert NSP models in docs (#3482)

parent b815edf6
@@ -845,7 +845,7 @@ class BertForPreTraining(BertPreTrainedModel):
             Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
         prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`):
+        seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
             Prediction scores of the next sequence prediction (classification) head (scores of True/False
             continuation before SoftMax).
         hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
@@ -1048,7 +1048,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
         loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
             Next sequence prediction (classification) loss.
-        seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`):
+        seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
         hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
             Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
...
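
For reference, a minimal sketch that checks the corrected shape (the "bert-base-uncased" checkpoint, the example sentence pair, and the tuple-return API of transformers at the time of this change are assumptions, not part of the patch):

import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")

# Encode a sentence pair; the NSP head scores whether the second
# sentence is a plausible continuation of the first.
inputs = tokenizer.encode_plus(
    "The cat sat on the mat.",
    "It then fell asleep.",
    return_tensors="pt",
)

with torch.no_grad():
    outputs = model(**inputs)

# First element of the returned tuple is seq_relationship_scores:
# one pair of True/False continuation scores per example, not per token.
seq_relationship_scores = outputs[0]
print(seq_relationship_scores.shape)  # torch.Size([1, 2]), i.e. (batch_size, 2)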