Commit 0a74c88a authored by thomwolf's avatar thomwolf
Browse files

fix #1131

parent 5f297c7b
......@@ -677,8 +677,11 @@ XLNET_INPUTS_DOCSTRING = r"""
``1`` for tokens that are MASKED, ``0`` for tokens that are NOT MASKED.
**mems**: (`optional`)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
that contains pre-computed hidden-states (key and values in the attention blocks) as output by the model
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
To activate mems you need to set up config.mem_len to a positive value which will be the max number of tokens in
the memory output by the model. E.g. `model = XLNetModel.from_pretrained('xlnet-base-case, mem_len=1024)` will
instantiate a model which can use up to 1024 tokens of memory (in addition to the input it self).
**perm_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, sequence_length)``:
Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
If ``perm_mask[k, i, j] = 0``, i attend to j in batch k;
......@@ -705,7 +708,8 @@ class XLNetModel(XLNetPreTrainedModel):
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
......@@ -859,7 +863,7 @@ class XLNetModel(XLNetPreTrainedModel):
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
qlen, bsz = input_ids.shape[0], input_ids.shape[1]
mlen = mems[0].shape[0] if mems is not None else 0
mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
klen = mlen + qlen
dtype_float = next(self.parameters()).dtype
......@@ -1011,7 +1015,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
......@@ -1091,7 +1096,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
......@@ -1189,7 +1195,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment