chenpangpang / transformers / Commits

Commit 098a89f3, authored Oct 29, 2019 by Rémi Louf

    update docstrings; rename lm_labels to more explicit ltr_lm_labels

Parent: dfce4096

Showing 2 changed files with 32 additions and 27 deletions:

    examples/run_summarization_finetuning.py    +4   -4
    transformers/modeling_bert.py               +28  -23
examples/run_summarization_finetuning.py
@@ -26,7 +26,7 @@ import numpy as np
 from tqdm import tqdm, trange

 import torch
 from torch.optim import Adam
-from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

 from transformers import (
     AutoTokenizer,
@@ -283,14 +283,14 @@ def evaluate(args, model, tokenizer, prefix=""):
     model.eval()
     for batch in tqdm(eval_dataloader, desc="Evaluating"):
-        source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
+        source, target, encoder_token_type_ids, encoder_mask, decoder_mask, ltr_lm_labels = batch
         source = source.to(args.device)
         target = target.to(args.device)
         encoder_token_type_ids = encoder_token_type_ids.to(args.device)
         encoder_mask = encoder_mask.to(args.device)
         decoder_mask = decoder_mask.to(args.device)
-        lm_labels = lm_labels.to(args.device)
+        ltr_lm_labels = ltr_lm_labels.to(args.device)

         with torch.no_grad():
             outputs = model(
@@ -299,7 +299,7 @@ def evaluate(args, model, tokenizer, prefix=""):
                 encoder_token_type_ids=encoder_token_type_ids,
                 encoder_attention_mask=encoder_mask,
                 decoder_attention_mask=decoder_mask,
-                decoder_lm_labels=lm_labels,
+                decoder_ltr_lm_labels=ltr_lm_labels,
             )
             lm_loss = outputs[0]
             eval_loss += lm_loss.mean().item()
transformers/modeling_bert.py
@@ -548,6 +548,14 @@ BERT_INPUTS_DOCSTRING = r"""
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
+            is configured as a decoder.
+        **encoder_attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
+            is used in the cross-attention if the model is configured as a decoder.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
 """


 @add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
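The added ``encoder_attention_mask`` entry uses the same convention as the standard attention mask: ``1`` for real tokens, ``0`` for padding. As a minimal illustration of building such a mask from padded encoder inputs (the pad token id and variable names here are assumptions for the example, not part of the diff):

    import torch

    # Hypothetical padded encoder batch; 0 is assumed to be the padding token id.
    encoder_input_ids = torch.tensor([[101, 2054, 2003, 102, 0, 0],
                                      [101, 7592, 102, 0, 0, 0]])

    # 1 for tokens that are NOT MASKED, 0 for MASKED (padding) tokens,
    # matching the docstring entry added above.
    encoder_attention_mask = (encoder_input_ids != 0).long()
    # tensor([[1, 1, 1, 1, 0, 0],
    #         [1, 1, 1, 0, 0, 0]])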
@@ -609,26 +617,18 @@ class BertModel(BertPreTrainedModel):
                 head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
         """ Forward pass on the Model.

-        The values of the attention matrix (shape [batch_size, seq_length])
-        should be 1.0 for the position we want to attend to and 0. for the ones
-        we do not want to attend to.
-
         The model can behave as an encoder (with only self-attention) as well
         as a decoder, in which case a layer of cross-attention is added between
-        ever self-attention layer, following the architecture described in [1].
+        the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
+        Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

-        To behave like as a decoder the model needs to be initialized with the
-        `is_decoder` argument of the config set to `True`. An
+        To behave as a decoder the model needs to be initialized with the
+        `is_decoder` argument of the configuration set to `True`; an
         `encoder_hidden_states` is expected as an input to the forward pass.

-        When a decoder, there are two kinds of attention masks to specify:
-        (1) Self-attention masks that need to be causal (only attends to
-        previous tokens);
-        (2) A cross-attention mask that prevents the module
-        from attending to the encoder's padding tokens.
-
-        [1] Vaswani, Ashish, et al. "Attention is all you need." Advances in
-        neural information processing systems. 2017.
+        .. _`Attention is all you need`:
+            https://arxiv.org/abs/1706.03762
+
         """
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
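Based on the updated docstring, a decoder-style forward pass might look like the sketch below. Only the ``encoder_hidden_states`` and ``encoder_attention_mask`` arguments are taken from this diff; the way ``is_decoder`` is set on the configuration and the dummy tensors are assumptions for illustration.

    import torch
    from transformers import BertConfig, BertModel

    # Dummy token ids standing in for tokenized source and target sequences.
    source_ids = torch.tensor([[101, 2023, 2003, 1996, 3120, 102]])
    target_ids = torch.tensor([[101, 2023, 2003, 102]])

    # Encoder: a regular BertModel (self-attention only).
    encoder = BertModel(BertConfig())

    # Decoder: the docstring requires `is_decoder` to be True in the configuration;
    # setting the attribute directly is an assumption about this branch's API.
    decoder_config = BertConfig()
    decoder_config.is_decoder = True
    decoder = BertModel(decoder_config)

    with torch.no_grad():
        encoder_hidden_states = encoder(source_ids)[0]  # last-layer hidden states
        decoder_outputs = decoder(
            target_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=torch.ones_like(source_ids),
        )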
@@ -791,11 +791,16 @@ class BertForMaskedLM(BertPreTrainedModel):
             Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
             Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
+        **ltr_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Labels for computing the left-to-right language modeling loss (next word prediction).
+            Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
+            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+            in ``[0, ..., config.vocab_size]``

     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
         **masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
             Masked language modeling loss.
-        **next_token_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+        **ltr_lm_loss**: (`optional`, returned when ``ltr_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
             Next token prediction loss.
         **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
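Per the added ``ltr_lm_labels`` description, positions set to ``-1`` are ignored by the loss. A small sketch of how such labels could be derived from padded ``input_ids`` (the pad token id and variable names are illustrative assumptions):

    import torch

    pad_token_id = 0  # assumed padding id, for illustration only
    input_ids = torch.tensor([[101, 2023, 2003, 1996, 102, 0, 0]])

    # Copy the inputs and mark padding with -1 so those positions are ignored
    # by the left-to-right language modeling loss, as the docstring specifies.
    ltr_lm_labels = input_ids.clone()
    ltr_lm_labels[input_ids == pad_token_id] = -1
    # tensor([[ 101, 2023, 2003, 1996,  102,   -1,   -1]])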
@@ -833,7 +838,7 @@ class BertForMaskedLM(BertPreTrainedModel):
                                    self.bert.embeddings.word_embeddings)

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
-                masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
+                masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, ltr_lm_labels=None, ):
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
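With the renamed keyword argument, a call to this forward pass might look like the following sketch. The ``ltr_lm_labels`` argument exists only on this branch as shown above; the model instantiation and dummy tensors are assumptions, and the output unpacking follows the return comment documented in the next hunk ((masked_lm_loss), (ltr_lm_loss), prediction_scores, ...).

    import torch
    from transformers import BertConfig, BertForMaskedLM

    model = BertForMaskedLM(BertConfig())
    model.eval()

    input_ids = torch.tensor([[101, 2023, 2003, 1996, 3120, 102]])
    ltr_lm_labels = input_ids.clone()  # next-token targets; -1 would mark ignored positions

    with torch.no_grad():
        outputs = model(input_ids, ltr_lm_labels=ltr_lm_labels)

    ltr_lm_loss, prediction_scores = outputs[0], outputs[1]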
@@ -852,22 +857,22 @@ class BertForMaskedLM(BertPreTrainedModel):
         # 1. If a tensor that contains the indices of masked labels is provided,
         #    the cross-entropy is the MLM cross-entropy that measures the likelihood
         #    of predictions for masked words.
-        # 2. If `lm_label` is provided we are in a causal scenario where we
-        #    try to predict the next word for each input in the encoder.
+        # 2. If `ltr_lm_labels` is provided we are in a causal scenario where we
+        #    try to predict the next token for each input in the decoder.
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)  # -1 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             outputs = (masked_lm_loss,) + outputs

-        if lm_labels is not None:
+        if ltr_lm_labels is not None:
             # we are doing next-token prediction; shift prediction scores and input ids by one
             prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            lm_labels = lm_labels[:, 1:].contiguous()
+            ltr_lm_labels = ltr_lm_labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            next_token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
-            outputs = (next_token_loss,) + outputs
+            ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), ltr_lm_labels.view(-1))
+            outputs = (ltr_lm_loss,) + outputs

-        return outputs  # (masked_lm_loss), (next_token_loss), prediction_scores, (hidden_states), (attentions)
+        return outputs  # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)


 @add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
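The left-to-right loss above shifts predictions and labels by one position, so that the score at position i is compared with the token at position i + 1. A tiny standalone illustration of that alignment, using dummy shapes rather than the model itself:

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 10
    # Dummy scores for a batch of 1 and a sequence of 5 tokens.
    prediction_scores = torch.randn(1, 5, vocab_size)
    ltr_lm_labels = torch.tensor([[4, 7, 2, 9, 1]])

    # Same shift as in the diff: drop the last prediction and the first label,
    # so position i is scored against the token at position i + 1.
    shifted_scores = prediction_scores[:, :-1, :].contiguous()  # (1, 4, vocab_size)
    shifted_labels = ltr_lm_labels[:, 1:].contiguous()          # (1, 4)

    loss_fct = CrossEntropyLoss(ignore_index=-1)
    ltr_lm_loss = loss_fct(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))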