Commit 119610b5 authored by sshleifer

Merge branch 'master' into delete-n-special-doc

parents 08e4ad5e 0d1dad6d
@@ -546,7 +546,7 @@ XLNET_INPUTS_DOCSTRING = r"""
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
 """

-@add_start_docstrings("The bare XLNet Model transformer outputing raw hidden-states without any specific head on top.",
+@add_start_docstrings("The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
                       XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
 class XLNetModel(XLNetPreTrainedModel):
     r"""
@@ -743,8 +743,9 @@ class XLNetModel(XLNetPreTrainedModel):
         if data_mask is not None:
             # all mems can be attended to
-            mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
-            data_mask = torch.cat([mems_mask, data_mask], dim=1)
+            if mlen > 0:
+                mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
+                data_mask = torch.cat([mems_mask, data_mask], dim=1)
             if attn_mask is None:
                 attn_mask = data_mask[:, :, :, None]
             else:
@@ -755,7 +756,8 @@ class XLNetModel(XLNetPreTrainedModel):
         if attn_mask is not None:
             non_tgt_mask = -torch.eye(qlen).to(attn_mask)
-            non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
+            if mlen > 0:
+                non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
             non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
         else:
             non_tgt_mask = None
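
The ``-torch.eye(qlen)`` line in this hunk carves the diagonal out of the combined mask so that, in the content stream, every token can still attend to itself even when all other positions are masked. A minimal standalone sketch of the mechanism with toy sizes (illustrative only, not the model's actual tensors)::

    import torch

    qlen, mlen = 3, 2
    attn_mask = torch.ones([qlen, qlen + mlen, 1, 1])  # hypothetical mask: everything masked
    non_tgt_mask = -torch.eye(qlen)                    # -1 on the diagonal
    non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]), non_tgt_mask], dim=-1)
    combined = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).float()
    # combined[i, mlen + i] == 0: query i can see itself; every other entry stays 1 (masked)
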
@@ -775,8 +777,11 @@ class XLNetModel(XLNetPreTrainedModel):
         ##### Segment embedding
         if token_type_ids is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
-            mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
-            cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
+            if mlen > 0:
+                mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
+                cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
+            else:
+                cat_ids = token_type_ids
             # `1` indicates not in the same segment [qlen x klen x bsz]
             seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
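
The three hunks above share one fix: every memory-dependent concatenation is now wrapped in an ``if mlen > 0:`` guard, so nothing gets prepended when no cached states exist. A standalone sketch of the pattern (toy shapes, hypothetical values)::

    import torch

    qlen, bsz, mlen = 4, 2, 0                 # no cached memory: mlen == 0
    data_mask = torch.ones([qlen, qlen, bsz])
    if mlen > 0:
        # only prepend a memory mask when cached states actually exist
        mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
        data_mask = torch.cat([mems_mask, data_mask], dim=1)
    print(data_mask.shape)                    # torch.Size([4, 4, 2]): unchanged when mlen == 0
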
@@ -1006,6 +1011,97 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         return outputs  # return (loss), logits, mems, (hidden states), (attentions)
+
+
@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
class XLNetForMultipleChoice(XLNetPreTrainedModel):
r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to scores.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Mask to avoid performing attention on padding token indices.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetForMultipleChoice.from_pretrained('xlnet-base-cased')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
+    def __init__(self, config):
+        super(XLNetForMultipleChoice, self).__init__(config)
+
+        self.transformer = XLNetModel(config)
+        self.sequence_summary = SequenceSummary(config)
+        self.logits_proj = nn.Linear(config.d_model, 1)
+
+        self.init_weights()
+
+    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
+                mems=None, perm_mask=None, target_mapping=None,
+                labels=None, head_mask=None):
+        num_choices = input_ids.shape[1]
+
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None
+
+        transformer_outputs = self.transformer(flat_input_ids, token_type_ids=flat_token_type_ids,
+                                               input_mask=flat_input_mask, attention_mask=flat_attention_mask,
+                                               mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
+                                               head_mask=head_mask)
+
+        output = transformer_outputs[0]
+
+        output = self.sequence_summary(output)
+        logits = self.logits_proj(output)
+        reshaped_logits = logits.view(-1, num_choices)
+
+        outputs = (reshaped_logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if they are in it
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels.view(-1))
+            outputs = (loss,) + outputs
+
+        return outputs  # return (loss), logits, mems, (hidden states), (attentions)
+
+
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
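
The new ``XLNetForMultipleChoice.forward`` above scores each choice independently by folding the choice dimension into the batch and unfolding it again after the projection. A shape-only sketch of that reshaping (toy sizes and random tensors standing in for the real model)::

    import torch

    batch_size, num_choices, seq_len, hidden = 2, 4, 8, 16
    input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))

    flat_input_ids = input_ids.view(-1, input_ids.size(-1))  # (8, 8): one row per (example, choice)
    summary = torch.randn(flat_input_ids.size(0), hidden)    # stand-in for the SequenceSummary output
    logits = torch.nn.Linear(hidden, 1)(summary)             # (8, 1): one score per choice
    reshaped_logits = logits.view(-1, num_choices)           # (2, 4): scores regrouped per example
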
@@ -1061,7 +1157,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
     Examples::

-        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
+        tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
         model = XLNetForQuestionAnswering.from_pretrained('xlnet-large-cased')
         input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
         start_positions = torch.tensor([1])
......