Commit 98edad41 — authored Jan 17, 2020 by Lysandre; committed by Lysandre Debut on Jan 23, 2020

PyTorch Transformer-XL

Parent: 96d21ad0
Changes: 2 changed files with 116 additions and 91 deletions (+116 −91)

docs/source/model_doc/transformerxl.rst    +25 −0
src/transformers/modeling_transfo_xl.py    +91 −91
docs/source/model_doc/transformerxl.rst  (view file @ 98edad41)

 Transformer XL
 ----------------------------------------------------

+The Transformer-XL model was proposed in
+`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
+by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
+It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
+previously computed hidden-states to attend to longer context (memory).
+This model also uses adaptive softmax inputs and outputs (tied).
+
+The abstract from the paper is the following:
+
+*Transformers have a potential of learning longer-term dependency, but are limited by a fixed-length context in the
+setting of language modeling. We propose a novel neural architecture Transformer-XL that enables learning dependency
+beyond a fixed length without disrupting temporal coherence. It consists of a segment-level recurrence mechanism and
+a novel positional encoding scheme. Our method not only enables capturing longer-term dependency, but also resolves
+the context fragmentation problem. As a result, Transformer-XL learns dependency that is 80% longer than RNNs and
+450% longer than vanilla Transformers, achieves better performance on both short and long sequences, and is up
+to 1,800+ times faster than vanilla Transformers during evaluation. Notably, we improve the state-of-the-art results
+of bpc/perplexity to 0.99 on enwiki8, 1.08 on text8, 18.3 on WikiText-103, 21.8 on One Billion Word, and 54.5 on
+Penn Treebank (without finetuning). When trained only on WikiText-103, Transformer-XL manages to generate reasonably
+coherent, novel text articles with thousands of tokens.*
+
+Tips:
+
+- Transformer-XL uses relative sinusoidal positional embeddings so it's usually advised to pad the inputs on
+  the left rather than the right.
+
 ``TransfoXLConfig``
 ~~~~~~~~~~~~~~~~~~~~~
 ...
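The memory ("mems") mechanism described in the documentation added above lets a segment attend to hidden states computed for earlier segments. A minimal sketch of carrying mems across two segments, assuming the checkpoint name and output ordering shown in the Examples blocks later in this diff:

    import torch
    from transformers import TransfoXLModel, TransfoXLTokenizer

    # Sketch only: reuse the memory returned for one segment when encoding the next,
    # so the second segment can attend to the first without recomputing it.
    tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
    model = TransfoXLModel.from_pretrained("transfo-xl-wt103")
    model.eval()

    segment_1 = torch.tensor([tokenizer.encode("The Transformer-XL model was proposed in 2019.")])
    segment_2 = torch.tensor([tokenizer.encode("It reuses hidden states across segments.")])

    with torch.no_grad():
        hidden_1, mems = model(segment_1)[:2]             # first segment: no memory yet
        hidden_2, mems = model(segment_2, mems=mems)[:2]  # second segment attends to the first via `mems`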
src/transformers/modeling_transfo_xl.py  (view file @ 98edad41)

@@ -26,7 +26,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from .configuration_transfo_xl import TransfoXLConfig
-from .file_utils import add_start_docstrings
+from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_transfo_xl_utilities import LogUniformSampler, ProjectedAdaptiveLogSoftmax, sample_logits
 from .modeling_utils import PreTrainedModel
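The only code change in this hunk is the extra import: add_start_docstrings_to_callable is applied to the forward methods later in the diff so that the shared TRANSFO_XL_INPUTS_DOCSTRING is attached to each method's docstring. As a rough, hypothetical sketch of what such a decorator could do (not the actual transformers.file_utils implementation):

    # Hypothetical illustration of a docstring-prepending decorator; the real
    # add_start_docstrings_to_callable in transformers.file_utils may differ.
    def add_start_docstrings_to_callable(*docstr):
        def decorator(fn):
            fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
            return fn
        return decorator

    @add_start_docstrings_to_callable("Shared inputs documentation (e.g. TRANSFO_XL_INPUTS_DOCSTRING).\n\n")
    def forward(input_ids=None, mems=None, head_mask=None, inputs_embeds=None):
        """Method-specific Return/Examples documentation."""
        return None

    print(forward.__doc__)  # shared inputs section followed by the method's own docstring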
@@ -508,21 +508,11 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
             self._init_bias(m.r_bias)


-TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
-    `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
-    by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
-    It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
-    previously computed hidden-states to attend to longer context (memory).
-    This model also uses adaptive softmax inputs and outputs (tied).
-
-    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
-    refer to the PyTorch documentation for all matter related to general usage and behavior.
-
-    .. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
-        https://arxiv.org/abs/1901.02860
-
-    .. _`torch.nn.Module`:
-        https://pytorch.org/docs/stable/nn.html#module
+TRANSFO_XL_START_DOCSTRING = r"""
+
+    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
+    usage and behavior.

     Parameters:
         config (:class:`~transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
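The rewritten TRANSFO_XL_START_DOCSTRING stresses that the model is a plain torch.nn.Module sub-class. A small illustration of treating it as one (nothing here is specific to this commit):

    import torch
    from transformers import TransfoXLModel

    model = TransfoXLModel.from_pretrained("transfo-xl-wt103")

    # Regular nn.Module behaviour: move between devices, toggle eval mode, inspect parameters.
    assert isinstance(model, torch.nn.Module)
    model = model.to("cpu").eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{n_params / 1e6:.1f}M parameters")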
@@ -531,24 +521,25 @@ TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
 """

 TRANSFO_XL_INPUTS_DOCSTRING = r"""
-    Inputs:
-        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+    Args:
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary.
+            Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
+            the right or on the left.
             Indices can be obtained using :class:`transformers.TransfoXLTokenizer`.
             See :func:`transformers.PreTrainedTokenizer.encode` and
-            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
-        **mems**: (`optional`)
-            list of ``torch.FloatTensor`` (one for each layer):
-            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
-            (see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
-        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
+            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
+
+            `What are input IDs? <../glossary.html#input-ids>`__
+        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
+            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
+            (see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
+            given to this model should not be passed as input ids as they have already been computed.
+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
-            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
-        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
-            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
+        input_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
             This is useful if you want more control over how to convert `input_ids` indices into associated vectors
             than the model's internal embedding lookup matrix.
 """
@@ -557,34 +548,8 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""

 @add_start_docstrings(
     "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
     TRANSFO_XL_START_DOCSTRING,
-    TRANSFO_XL_INPUTS_DOCSTRING,
 )
 class TransfoXLModel(TransfoXLPreTrainedModel):
-    r"""
-    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
-            Sequence of hidden-states at the last layer of the model.
-        **mems**:
-            list of ``torch.FloatTensor`` (one for each layer):
-            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
-            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
-        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
-            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
-            of shape ``(batch_size, sequence_length, hidden_size)``:
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
-            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
-
-    Examples::
-
-        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
-        model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
-        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
-        outputs = model(input_ids)
-        last_hidden_states, mems = outputs[:2]
-
-    """
     def __init__(self, config):
         super().__init__(config)
@@ -705,7 +670,38 @@ class TransfoXLModel(TransfoXLPreTrainedModel):

         return new_mems

+    @add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
     def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None):
+        r"""
+    Return:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (config) and inputs:
+        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the last layer of the model.
+        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
+            Contains pre-computed hidden-states (key and values in the attention blocks).
+            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
+            should not be passed as input ids as they have already been computed.
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
+        model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids)
+        last_hidden_states, mems = outputs[:2]
+
+        """
         # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
         # so we transpose here from shape [bsz, len] to shape [len, bsz]
         if input_ids is not None and inputs_embeds is not None:
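The Return section added above lists hidden_states and attentions as optional outputs gated by config.output_hidden_states and config.output_attentions. A minimal sketch of requesting them, assuming (as with other transformers 2.x models) that the flags can be overridden through from_pretrained keyword arguments and that the output ordering follows the docstring:

    import torch
    from transformers import TransfoXLModel, TransfoXLTokenizer

    tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
    # Assumption: config flags can be set via from_pretrained kwargs.
    model = TransfoXLModel.from_pretrained(
        "transfo-xl-wt103", output_hidden_states=True, output_attentions=True
    )

    input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
    last_hidden_state, mems, hidden_states, attentions = model(input_ids)

    print(len(hidden_states), hidden_states[-1].shape)  # embeddings output + one entry per layer
    print(len(attentions), attentions[0].shape)         # (batch, heads, seq_len, seq_len) per layer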
@@ -805,44 +801,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
     """The Transformer-XL Model with a language modeling head on top
     (adaptive softmax with weights tied to the adaptive input embeddings)""",
     TRANSFO_XL_START_DOCSTRING,
-    TRANSFO_XL_INPUTS_DOCSTRING,
 )
 class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
-    r"""
-        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
-            Labels for language modeling.
-            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
-            Indices are selected in ``[-1, 0, ..., config.vocab_size]``
-            All labels set to ``-100`` are ignored (masked), the loss is only
-            computed for labels in ``[0, ..., config.vocab_size]``
-
-    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
-            Language modeling loss.
-        **prediction_scores**: ``None`` if ``labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
-            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-            We don't output them when the loss is computed to speedup adaptive softmax decoding.
-        **mems**:
-            list of ``torch.FloatTensor`` (one for each layer):
-            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
-            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
-        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
-            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
-            of shape ``(batch_size, sequence_length, hidden_size)``:
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
-            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
-
-    Examples::
-
-        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
-        model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
-        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
-        outputs = model(input_ids)
-        prediction_scores, mems = outputs[:2]
-
-    """
     def __init__(self, config):
         super().__init__(config)
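The class docstring removed here (and re-added on forward in the next hunk) documents prediction_scores as pre-softmax scores over the vocabulary. A hedged sketch of using them for a greedy next-token guess, following that documentation:

    import torch
    from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer

    tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
    model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
    model.eval()

    input_ids = torch.tensor([tokenizer.encode("Hello, my dog is")])
    with torch.no_grad():
        prediction_scores, mems = model(input_ids)[:2]  # no labels, so scores come first (per the docstring)

    next_token_id = prediction_scores[0, -1].argmax().item()  # greedy pick for the next position
    print(tokenizer.decode([next_token_id]))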
@@ -891,7 +851,47 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
     def init_mems(self, bsz):
         return self.transformer.init_mems(bsz)

+    @add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
     def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for language modeling.
+            Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
+            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
+            All labels set to ``-100`` are ignored (masked), the loss is only
+            computed for labels in ``[0, ..., config.vocab_size]``
+
+    Return:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:obj:`~transformers.GPT2Config`) and inputs:
+        loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
+            Language modeling loss.
+        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
+            Contains pre-computed hidden-states (key and values in the attention blocks).
+            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
+            should not be passed as input ids as they have already been computed.
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
+        model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids)
+        prediction_scores, mems = outputs[:2]
+
+        """
         if input_ids is not None:
             bsz, tgt_len = input_ids.size(0), input_ids.size(1)
         elif inputs_embeds is not None:
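The labels argument documented in this hunk is shifted inside the model, so the input ids themselves can be used as labels, and positions set to -100 are ignored in the loss. A short sketch based on that documentation (the loss-first output ordering when labels are given is taken from the docstrings, not a verified run):

    import torch
    from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer

    tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
    model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")

    input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])

    # Labels are shifted inside the model, so the inputs can double as labels;
    # any position set to -100 would be excluded from the loss (per the docstring above).
    outputs = model(input_ids, labels=input_ids)
    loss = outputs[0]  # documented as the language modeling loss when `labels` is provided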