        Args:
            One of ``start_states``, ``start_positions`` should be not None.
            If both are set, ``start_positions`` overrides ``start_states``.

            **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
                hidden states of the first tokens for the labeled span.
            **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the first token for the labeled span.
            **cls_index**: ``torch.LongTensor`` of shape ``(batch_size,)``
                position of the CLS token. If None, take the last token.

            note(Original repo):
                no dependency on end_feature so that we can obtain one single `cls_logits`
                for each sample
        """
        slen, hsz = hidden_states.shape[-2:]
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
...
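(Illustration, not part of the diff: a minimal sketch of the contract documented above, in which ``start_positions`` overrides ``start_states`` by gathering one hidden state per sample. ``gather_start_states`` is a hypothetical standalone helper, not the elided method body.)

import torch

def gather_start_states(hidden_states, start_states=None, start_positions=None):
    # hidden_states: (bsz, slen, hsz)
    hsz = hidden_states.shape[-1]
    assert start_states is not None or start_positions is not None, \
        "One of start_states, start_positions should be not None"
    if start_positions is not None:
        # (bsz,) -> (bsz, 1, hsz) so gather can pick one token per sample
        index = start_positions[:, None, None].expand(-1, -1, hsz)
        start_states = hidden_states.gather(-2, index).squeeze(-2)  # (bsz, hsz)
    return start_states

hidden = torch.randn(2, 5, 8)                              # (bsz, slen, hsz)
spans = gather_start_states(hidden, start_positions=torch.tensor([1, 3]))
print(spans.shape)                                         # torch.Size([2, 8])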
@@ -577,7 +591,35 @@ class PoolerAnswerClass(nn.Module):
class SQuADHead(nn.Module):
    r""" A SQuAD head inspired by XLNet.
    Parameters:
        config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.

    Inputs:
        **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
            hidden states of sequence tokens.
        **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the first token for the labeled span.
        **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the last token for the labeled span.
        **cls_index**: ``torch.LongTensor`` of shape ``(batch_size,)``
            position of the CLS token. If None, take the last token.
        **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
            Whether the question has a possible answer in the paragraph or not.
        **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
            Mask of invalid positions such as query and special symbols (PAD, SEP, CLS).
            1.0 means the token should be masked.
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
        **last_hidden_state**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **mems**:
            list of ``torch.FloatTensor`` (one for each layer) that contains pre-computed hidden-states
            (key and values in the attention blocks) as computed by the model (see `mems` input above).
            Can be used to speed up sequential decoding and attend to longer context.
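(Illustration, not part of the diff: a usage sketch implied by the docstring above. The import path ``pytorch_transformers.modeling_utils`` and the ``XLNetConfig`` defaults are assumptions not confirmed by this hunk; per the **loss** entry, a call with both ``start_positions`` and ``end_positions`` returns the loss first.)

import torch
from pytorch_transformers import XLNetConfig
from pytorch_transformers.modeling_utils import SQuADHead  # import path assumed

config = XLNetConfig()                        # defaults assumed (d_model, etc.)
head = SQuADHead(config)

batch_size, seq_len = 2, 16
hidden_states = torch.randn(batch_size, seq_len, config.d_model)

# Mask of invalid positions: 1.0 means the token cannot be selected.
p_mask = torch.zeros(batch_size, seq_len)
p_mask[:, 0] = 1.0                            # e.g. a special symbol at position 0

# With both positions provided, the first element of the output tuple is the loss.
outputs = head(hidden_states,
               start_positions=torch.tensor([3, 5]),
               end_positions=torch.tensor([6, 9]),
               p_mask=p_mask)
loss = outputs[0]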