# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of
# the hidden-states output to compute `span start logits` and `span end logits`). """,
# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
# class XLNetForQuestionAnswering(XLNetPreTrainedModel):
# r"""
# **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
# Labels for position (index) of the start of the labelled span for computing the token classification loss.
# Positions are clamped to the length of the sequence (`sequence_length`).
# Positions outside of the sequence are not taken into account for computing the loss.
# **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
# Labels for position (index) of the end of the labelled span for computing the token classification loss.
# Positions are clamped to the length of the sequence (`sequence_length`).
# Positions outside of the sequence are not taken into account for computing the loss.
# **is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
# Labels whether a question has an answer or no answer (SQuAD 2.0).
# **cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
# Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
# **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
# Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...).
# 1.0 means the token should be masked, 0.0 means the token is not masked.
# Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
# **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
# Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
# **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
# ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
# Log probabilities for the top config.start_n_top start token possibilities (beam-search).
# **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
# ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
# Indices for the top config.start_n_top start token possibilities (beam-search).
# **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
# ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
# Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
# **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
# ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
# Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
# **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
# ``torch.FloatTensor`` of shape ``(batch_size,)``
# Log probabilities for the ``is_impossible`` label of the answers.
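The shapes of the beam-search outputs above can be illustrated with a small sketch. This is not the model's actual computation: NumPy stands in for torch, the log probabilities are random placeholders, and ``argsort`` is used as a stand-in for a top-k selection. The values of ``start_n_top`` and ``end_n_top`` (5 each) are assumed for illustration.

```python
import numpy as np

batch, seq_len = 2, 8
start_n_top, end_n_top = 5, 5  # illustrative values, not taken from a real config

rng = np.random.default_rng(0)

# placeholder per-position start log probabilities, shape (batch, seq_len)
start_log_probs = rng.standard_normal((batch, seq_len))

# top start_n_top start positions per sample -> (batch, start_n_top)
start_top_index = np.argsort(-start_log_probs, axis=-1)[:, :start_n_top]
start_top_log_probs = np.take_along_axis(start_log_probs, start_top_index, axis=-1)

# one end distribution over the sequence per candidate start -> (batch, start_n_top, seq_len)
end_log_probs = rng.standard_normal((batch, start_n_top, seq_len))

# top end_n_top end positions per candidate start -> (batch, start_n_top, end_n_top)
end_top_index = np.argsort(-end_log_probs, axis=-1)[:, :, :end_n_top]
end_top_log_probs = np.take_along_axis(end_log_probs, end_top_index, axis=-1)

# flattened to the documented shape (batch, start_n_top * end_n_top)
end_top_log_probs = end_top_log_probs.reshape(batch, start_n_top * end_n_top)
end_top_index = end_top_index.reshape(batch, start_n_top * end_n_top)
```

Flattening the last two axes is what produces the ``(batch_size, config.start_n_top * config.end_n_top)`` shape documented for ``end_top_log_probs`` and ``end_top_index``.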
# **mems**:
# list of ``torch.FloatTensor`` (one for each layer):
# Contains pre-computed hidden-states (keys and values in the attention blocks) as computed by the model
# if ``config.mem_len > 0``, else a tuple of ``None``. Can be used to speed up sequential decoding and to attend to a longer context.
# See details in the docstring of the `mems` input above.
# **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
# list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
# of shape ``(batch_size, sequence_length, hidden_size)``:
# Hidden-states of the model at the output of each layer plus the initial embedding outputs.
# **attentions**: (`optional`, returned when ``config.output_attentions=True``)
# list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
# Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
# start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) # get the representation of START as weighted sum of hidden states
# cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)  # Shape (batch_size,): one single `cls_logits` for each sample