Unverified Commit 6a50d501 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

add explaining example to XLNet LM modeling (#2997)

* add explaining example to XLNet LM modeling

* improve docstring for xlnet
parent 65d74c49
......@@ -702,8 +702,9 @@ class XLNetModel(XLNetPreTrainedModel):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
`num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` corresponds to `sequence_length`.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
......@@ -728,7 +729,7 @@ class XLNetModel(XLNetPreTrainedModel):
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetModel.from_pretrained('xlnet-large-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=False)).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
......@@ -977,19 +978,21 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_predict)`, `optional`, defaults to :obj:`None`):
Labels for masked language modeling.
`num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` corresponds to `sequence_length`.
The labels should correspond to the masked input words that should be predicted and depends on `target_mapping`. Note in order to perform standard auto-regressive language modeling a `<mask>` token has to be added to the `input_ids` (see `prepare_inputs_for_generation` fn and examples below)
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only
All labels set to ``-100`` are ignored, the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
Language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
`num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` corresponds to `sequence_length`.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
......@@ -1015,7 +1018,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
# We show how to setup inputs to predict a next token using a bi-directional context.
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=True)).unsqueeze(0) # We will predict the masked token
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)).unsqueeze(0) # We will predict the masked token
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float) # Shape [1, 1, seq_length] => let's predict one token
......@@ -1024,6 +1027,18 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
# The same way can the XLNetLMHeadModel be used to be trained by standard auto-regressive language modeling.
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)).unsqueeze(0) # We will predict the masked token
labels = torch.tensor(tokenizer.encode("cute", add_special_tokens=False)).unsqueeze(0)
assert labels.shape[0] == 1, 'only one word will be predicted'
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token as is done in standard auto-regressive lm training
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float) # Shape [1, 1, seq_length] => let's predict one token
target_mapping[0, 0, -1] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping, labels=labels)
loss, next_token_logits = outputs[:2] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
"""
transformer_outputs = self.transformer(
input_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment