Commit 00df3d4d, authored Jan 15, 2020 by Lysandre and committed by Lysandre Debut on Jan 23, 2020.

ALBERT Modeling + required changes to utilities
Parent: f81b6c95

Showing 4 changed files with 259 additions and 166 deletions (+259 -166):
docs/source/model_doc/albert.rst      +38   -9
src/transformers/file_utils.py        +19   -1
src/transformers/modeling_albert.py   +180  -150
src/transformers/modeling_utils.py    +22   -6
docs/source/model_doc/albert.rst

The doc page gains an Overview section and drops the literal backticks from the per-class headings (``AlbertConfig``, ``AlbertTokenizer``, and so on). The resulting page reads as follows; a short configuration sketch after the Tips illustrates the two parameter-reduction techniques.
ALBERT
----------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~

The ALBERT model was proposed in `ALBERT: A Lite BERT for Self-supervised Learning of Language Representations <https://arxiv.org/abs/1909.11942>`_
by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma and Radu Soricut. It presents
two parameter-reduction techniques to lower memory consumption and increase the training speed of BERT:

- Splitting the embedding matrix into two smaller matrices
- Using repeating layers split among groups

The abstract from the paper is the following:

*Increasing model size when pretraining natural language representations often results in improved performance on
downstream tasks. However, at some point further model increases become harder due to GPU/TPU memory limitations,
longer training times, and unexpected model degradation. To address these problems, we present two parameter-reduction
techniques to lower memory consumption and increase the training speed of BERT. Comprehensive empirical evidence shows
that our proposed methods lead to models that scale much better compared to the original BERT. We also use a
self-supervised loss that focuses on modeling inter-sentence coherence, and show it consistently helps downstream
tasks with multi-sentence inputs. As a result, our best model establishes new state-of-the-art results on the GLUE,
RACE, and SQuAD benchmarks while having fewer parameters compared to BERT-large.*

Tips:

- ALBERT is a model with absolute position embeddings, so it's usually advised to pad the inputs on
  the right rather than the left.
- ALBERT uses repeating layers, which results in a small memory footprint; however, the computational cost remains
  similar to a BERT-like architecture with the same number of hidden layers, as it has to iterate through the same
  number of (repeating) layers.
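Both parameter-reduction techniques are visible directly in the model configuration. A minimal sketch, assuming a transformers install with ALBERT support; the attribute names embedding_size, hidden_size, num_hidden_layers and num_hidden_groups come from AlbertConfig, and the printed defaults depend on the library version:

from transformers import AlbertConfig

config = AlbertConfig()

# 1) Factorized embedding parameterization: tokens are embedded into a small
#    space of size embedding_size and then projected up to hidden_size, instead
#    of a single vocab_size x hidden_size embedding matrix.
print(config.embedding_size, config.hidden_size)        # embedding_size << hidden_size

# 2) Cross-layer parameter sharing: num_hidden_layers forward passes reuse the
#    parameters of num_hidden_groups groups of layers.
print(config.num_hidden_layers, config.num_hidden_groups)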
AlbertConfig
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AlbertConfig
    :members:


AlbertTokenizer
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AlbertTokenizer
    :members:


AlbertModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AlbertModel
    :members:


AlbertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AlbertForMaskedLM
    :members:


AlbertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AlbertForSequenceClassification
    :members:


AlbertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AlbertForQuestionAnswering
    :members:


TFAlbertModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFAlbertModel
    :members:


TFAlbertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFAlbertForMaskedLM
    :members:


TFAlbertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFAlbertForSequenceClassification
...
src/transformers/file_utils.py

add_start_docstrings now tolerates functions that have no docstring of their own, and a new decorator, add_start_docstrings_to_callable, is added for forward methods:
...
@@ -105,7 +105,25 @@ def is_tf_available():

 def add_start_docstrings(*docstr):
     def docstring_decorator(fn):
-        fn.__doc__ = "".join(docstr) + fn.__doc__
+        fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
+        return fn
+
+    return docstring_decorator
+
+
+def add_start_docstrings_to_callable(*docstr):
+    def docstring_decorator(fn):
+        class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
+        intro = "   The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
+        note = r"""
+
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`Module` instance afterwards
+        instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
+        """
+        fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
         return fn

     return docstring_decorator
...
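A quick sketch of how these decorators compose docstrings; the toy function below is hypothetical, and the import path assumes the helpers stay in transformers.file_utils as in this commit:

from transformers.file_utils import add_start_docstrings

SHARED_DOC = r"""
    Parameter documentation shared by several model classes.
"""

@add_start_docstrings("Toy model intro. ", SHARED_DOC)
def toy_forward(x):
    # Deliberately no docstring: before this change, "".join(docstr) + fn.__doc__
    # raised a TypeError because fn.__doc__ was None; the None is now replaced
    # with an empty string.
    return x

print(toy_forward.__doc__)  # "Toy model intro. " followed by SHARED_DOC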
src/transformers/modeling_albert.py

The input documentation moves from the class-level docstrings to the forward methods via the new add_start_docstrings_to_callable decorator, and the docstrings are converted to the :obj:-style format used elsewhere in the library:
...
@@ -26,7 +26,7 @@ from transformers.configuration_albert import AlbertConfig
 from transformers.modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
 from transformers.modeling_utils import PreTrainedModel
-from .file_utils import add_start_docstrings
+from .file_utils import add_start_docstrings, add_start_docstrings_to_callable

 logger = logging.getLogger(__name__)
...
@@ -376,11 +376,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
             module.weight.data.fill_(1.0)


-ALBERT_START_DOCSTRING = r""" The ALBERT model was proposed in
-    `ALBERT: A Lite BERT for Self-supervised Learning of Language Representations`_
-    by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. It presents
-    two parameter-reduction techniques to lower memory consumption and increase the training speed of BERT.
+ALBERT_START_DOCSTRING = r"""

     This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
     refer to the PyTorch documentation for all matter related to general usage and behavior.
...
@@ -390,80 +386,51 @@ ALBERT_START_DOCSTRING = r""" The ALBERT model was proposed in
     .. _`torch.nn.Module`:
         https://pytorch.org/docs/stable/nn.html#module

-    Parameters:
+    Args:
         config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the configuration.
             Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
 """

 ALBERT_INPUTS_DOCSTRING = r"""
-    Inputs:
-        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+    Args:
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary.
-            To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
-            (a) For sequence pairs:
-                ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
-                ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
-            (b) For single sequences:
-                ``tokens: [CLS] the dog is hairy . [SEP]``
-                ``token_type_ids: 0 0 0 0 0 0 0``
-            Albert is a model with absolute position embeddings so it's usually advised to pad the inputs on
-            the right rather than the left.
             Indices can be obtained using :class:`transformers.AlbertTokenizer`.
             See :func:`transformers.PreTrainedTokenizer.encode` and
-            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
-        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
+            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
+
+            `What are input IDs? <../glossary.html#input-ids>`__
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
-        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+
+            `What are attention masks? <../glossary.html#attention-mask>`__
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
             Segment token indices to indicate first and second portions of the inputs.
             Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
             corresponds to a `sentence B` token
-            (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
-        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+
+            `What are token type IDs? <../glossary.html#token-type-ids>`_
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
             Indices of positions of each input sequence tokens in the position embeddings.
             Selected in the range ``[0, config.max_position_embeddings - 1]``.
-        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
+
+            `What are position IDs? <../glossary.html#position-ids>`_
+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
-            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
 """

 @add_start_docstrings(
     "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
     ALBERT_START_DOCSTRING,
-    ALBERT_INPUTS_DOCSTRING,
 )
 class AlbertModel(AlbertPreTrainedModel):
-    r"""
-    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
-            Sequence of hidden-states at the output of the last layer of the model.
-        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
-            Last layer hidden-state of the first token of the sequence (classification token)
-            further processed by a Linear layer and a Tanh activation function. The Linear
-            layer weights are trained from the next sentence prediction (classification)
-            objective during Bert pretraining. This output is usually *not* a good summary
-            of the semantic content of the input, you're often better with averaging or pooling
-            the sequence of hidden-states for the whole input sequence.
-        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
-            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
-            of shape ``(batch_size, sequence_length, hidden_size)``:
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
-            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
-    """
+
     config_class = AlbertConfig
     pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
...
@@ -520,6 +487,44 @@ class AlbertModel(AlbertPreTrainedModel):
         head_mask=None,
         inputs_embeds=None,
     ):
+        r"""
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
+        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
+            Last layer hidden-state of the first token of the sequence (classification token)
+            further processed by a Linear layer and a Tanh activation function. The Linear
+            layer weights are trained from the next sentence prediction (classification)
+            objective during pre-training.
+
+            This output is usually *not* a good summary
+            of the semantic content of the input, you're often better with averaging or pooling
+            the sequence of hidden-states for the whole input sequence.
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Example::
+
+        from transformers import AlbertModel, AlbertTokenizer
+        import torch
+
+        tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+        model = AlbertModel.from_pretrained('albert-base-v2')
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids)
+        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
+
+        """
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
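Beyond the single-sentence encode() call in the example above, the new input docstring points to encode_plus, which builds the documented inputs in one call. A minimal sketch; the exact set of returned keys (for example whether attention_mask is included) depends on the tokenizer version:

from transformers import AlbertModel, AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = AlbertModel.from_pretrained('albert-base-v2')

# For a sentence pair, encode_plus adds the special tokens and builds input_ids
# plus token_type_ids (0 for sentence A, 1 for sentence B) directly.
inputs = tokenizer.encode_plus("Hello, my dog is cute", "Is it cute?", return_tensors='pt')
outputs = model(**inputs)
last_hidden_state, pooler_output = outputs[:2]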
...
@@ -594,29 +599,9 @@ class AlbertMLMHead(nn.Module):

 @add_start_docstrings(
-    "Bert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING
+    "Albert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING,
 )
 class AlbertForMaskedLM(AlbertPreTrainedModel):
-    r"""
-        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
-            Labels for computing the masked language modeling loss.
-            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
-            in ``[0, ..., config.vocab_size]``
-
-    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
-            Masked language modeling loss.
-        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
-            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
-            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
-            of shape ``(batch_size, sequence_length, hidden_size)``:
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
-            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
-    """
+
     def __init__(self, config):
         super().__init__(config)
...
@@ -628,14 +613,12 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
         self.tie_weights()

     def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
         self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)

     def get_output_embeddings(self):
         return self.predictions.decoder

+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids=None,
...
@@ -646,6 +629,43 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
         inputs_embeds=None,
         masked_lm_labels=None,
     ):
+        r"""
+        masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the masked language modeling loss.
+            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
+            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with
+            labels in ``[0, ..., config.vocab_size]``
+
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
+        loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Masked language modeling loss.
+        prediction_scores ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Example::
+
+        from transformers import AlbertTokenizer, AlbertForMaskedLM
+        import torch
+
+        tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+        model = AlbertForMaskedLM.from_pretrained('albert-base-v2')
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids, masked_lm_labels=input_ids)
+        loss, prediction_scores = outputs[:2]
+
+        """
         outputs = self.albert(
             input_ids=input_ids,
             attention_mask=attention_mask,
...
@@ -671,39 +691,8 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
     """Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
     the pooled output) e.g. for GLUE tasks. """,
     ALBERT_START_DOCSTRING,
-    ALBERT_INPUTS_DOCSTRING,
 )
 class AlbertForSequenceClassification(AlbertPreTrainedModel):
-    r"""
-        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
-            Labels for computing the sequence classification/regression loss.
-            Indices should be in ``[0, ..., config.num_labels - 1]``.
-            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
-            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
-
-    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
-            Classification (or regression if config.num_labels==1) loss.
-        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
-            Classification (or regression if config.num_labels==1) scores (before SoftMax).
-        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
-            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
-            of shape ``(batch_size, sequence_length, hidden_size)``:
-            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
-            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
-
-    Examples::
-
-        tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
-        model = AlbertForSequenceClassification.from_pretrained('albert-base-v2')
-        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
-        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
-        outputs = model(input_ids, labels=labels)
-        loss, logits = outputs[:2]
-    """
+
     def __init__(self, config):
         super().__init__(config)
...
@@ -715,6 +704,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
         self.init_weights()

+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids=None,
...
@@ -725,6 +715,44 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
         inputs_embeds=None,
         labels=None,
     ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
+            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
+
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (config) and inputs:
+        loss: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Classification (or regression if config.num_labels==1) loss.
+        logits ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        from transformers import AlbertTokenizer, AlbertForSequenceClassification
+        import torch
+
+        tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+        model = AlbertForSequenceClassification.from_pretrained('albert-base-v2')
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
+        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, logits = outputs[:2]
+
+        """
         outputs = self.albert(
             input_ids=input_ids,
@@ -759,49 +787,8 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
...
@@ -759,49 +787,8 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
"""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
"""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """
,
the hidden-states output to compute `span start logits` and `span end logits`). """
,
ALBERT_START_DOCSTRING
,
ALBERT_START_DOCSTRING
,
ALBERT_INPUTS_DOCSTRING
,
)
)
class
AlbertForQuestionAnswering
(
AlbertPreTrainedModel
):
class
AlbertForQuestionAnswering
(
AlbertPreTrainedModel
):
r
"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = AlbertForQuestionAnswering.from_pretrained('albert-base-v2')
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_text = "[CLS] " + question + " [SEP] " + text + " [SEP]"
input_ids = tokenizer.encode(input_text)
token_type_ids = [0 if i <= input_ids.index(102) else 1 for i in range(len(input_ids))]
start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
print(' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))
# a nice puppet
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
...
@@ -812,6 +799,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
         self.init_weights()

+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids=None,
...
@@ -823,6 +811,48 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
         start_positions=None,
         end_positions=None,
     ):
+        r"""
+        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (config) and inputs:
+        loss: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+        start_scores ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-start scores (before SoftMax).
+        end_scores: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
+            Span-end scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        # The checkpoint albert-base-v2 is not fine-tuned for question answering. Please see the
+        # examples/run_squad.py example to see how to fine-tune a model to a question answering task.
+
+        tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+        model = AlbertForQuestionAnswering.from_pretrained('albert-base-v2')
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        input_dict = tokenizer.encode_plus(question, text, return_tensors='pt')
+        start_scores, end_scores = model(**input_dict)
+
+        """
         outputs = self.albert(
             input_ids=input_ids,
...
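To turn start_scores and end_scores into an answer string, the span can be decoded the same way the previous (now removed) example did. A hedged continuation of the example above, reusing its tokenizer, model and input_dict:

import torch

start_scores, end_scores = model(**input_dict)
# Take the most likely start and end token positions and join the tokens in between.
all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"][0].tolist())
answer = ' '.join(all_tokens[torch.argmax(start_scores): torch.argmax(end_scores) + 1])
print(answer)  # not meaningful here, since albert-base-v2 is not fine-tuned for QA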
src/transformers/modeling_utils.py

The embedding accessor docstrings on PreTrainedModel are expanded into the Returns:/Args: format, and tie_weights now documents the torchscript cloning behaviour:
...
@@ -114,7 +114,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         return getattr(self, self.base_model_prefix, self)

     def get_input_embeddings(self):
-        """ Get model's input embeddings
+        """
+        Returns the model's input embeddings.
+
+        Returns:
+            :obj:`nn.Module`:
+                A torch module mapping vocabulary to hidden states.
         """
         base_model = getattr(self, self.base_model_prefix, self)
         if base_model is not self:
...
@@ -123,7 +128,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         raise NotImplementedError

     def set_input_embeddings(self, value):
-        """ Set model's input embeddings
+        """
+        Set model's input embeddings
+
+        Args:
+            value (:obj:`nn.Module`):
+                A module mapping vocabulary to hidden states.
         """
         base_model = getattr(self, self.base_model_prefix, self)
         if base_model is not self:
...
@@ -132,14 +142,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         raise NotImplementedError

     def get_output_embeddings(self):
-        """ Get model's output embeddings
-            Return None if the model doesn't have output embeddings
+        """
+        Returns the model's output embeddings.
+
+        Returns:
+            :obj:`nn.Module`:
+                A torch module mapping hidden states to vocabulary.
         """
         return None  # Overwrite for models with output embeddings

     def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
+        """
+        Tie the weights between the input embeddings and the output embeddings.
+        If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
+        the weights instead.
         """
         output_embeddings = self.get_output_embeddings()
         if output_embeddings is not None:
...
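A small sketch of how these hooks fit together for the ALBERT MLM head added in this commit. It assumes the albert-base-v2 checkpoint used in the examples above, and that AlbertModel exposes its word embeddings through get_input_embeddings the way other models in the library do:

from transformers import AlbertForMaskedLM

model = AlbertForMaskedLM.from_pretrained('albert-base-v2')

inp = model.get_input_embeddings()   # embedding module mapping vocabulary ids to embedding vectors
out = model.get_output_embeddings()  # the MLM head's decoder (predictions.decoder above)

# tie_weights() ran during initialization, so both modules are expected to share
# one weight matrix, unless config.torchscript forced a clone instead.
print(out.weight is inp.weight)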