Commit e93ccb32 (unverified) — BartForQuestionAnswering (#4908)
Authored Jun 13, 2020 by Suraj Patil; committed by GitHub Jun 12, 2020
Parent: 538531cd
Showing 5 changed files with 146 additions and 1 deletion (+146 / -1):
docs/source/model_doc/bart.rst       +7    -0
src/transformers/__init__.py         +1    -0
src/transformers/modeling_auto.py    +7    -1
src/transformers/modeling_bart.py    +117  -0
tests/test_modeling_bart.py          +14   -0
docs/source/model_doc/bart.rst
@@ -55,6 +55,13 @@ BartForSequenceClassification
     :members: forward
 
 
+BartForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BartForQuestionAnswering
+    :members: forward
+
+
 BartForConditionalGeneration
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
src/transformers/__init__.py
@@ -250,6 +250,7 @@ if is_torch_available():
         BartForSequenceClassification,
         BartModel,
         BartForConditionalGeneration,
+        BartForQuestionAnswering,
         BART_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_marian import MarianMTModel
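After this change the new head is exported from the package root (gated, like the other torch models, behind `is_torch_available()`), so it can be imported the same way as the existing BART classes. A quick import check, only to illustrate the public path added here:

```python
# Requires a PyTorch install, since the export lives inside the is_torch_available() branch.
from transformers import BartConfig, BartForQuestionAnswering

print(BartForQuestionAnswering.__name__)  # the class is now reachable from the top-level package
```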
src/transformers/modeling_auto.py
@@ -52,7 +52,12 @@ from .modeling_albert import (
     AlbertForTokenClassification,
     AlbertModel,
 )
-from .modeling_bart import BartForConditionalGeneration, BartForSequenceClassification, BartModel
+from .modeling_bart import (
+    BartForConditionalGeneration,
+    BartForQuestionAnswering,
+    BartForSequenceClassification,
+    BartModel,
+)
 from .modeling_bert import (
     BertForMaskedLM,
     BertForMultipleChoice,

@@ -274,6 +279,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
     [
         (DistilBertConfig, DistilBertForQuestionAnswering),
         (AlbertConfig, AlbertForQuestionAnswering),
+        (BartConfig, BartForQuestionAnswering),
         (LongformerConfig, LongformerForQuestionAnswering),
         (XLMRobertaConfig, XLMRobertaForQuestionAnswering),
         (RobertaConfig, RobertaForQuestionAnswering),
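Registering `(BartConfig, BartForQuestionAnswering)` in `MODEL_FOR_QUESTION_ANSWERING_MAPPING` is what lets the generic auto class pick the new head for BART checkpoints. A small sketch of that dispatch (the checkpoint name is only illustrative; as the docstring below notes, `bart-large` is not fine-tuned for QA, so the span head would be freshly initialized):

```python
from transformers import AutoModelForQuestionAnswering

# AutoModelForQuestionAnswering loads the checkpoint's config, finds BartConfig in
# MODEL_FOR_QUESTION_ANSWERING_MAPPING, and instantiates the mapped class.
model = AutoModelForQuestionAnswering.from_pretrained("facebook/bart-large")
print(type(model).__name__)  # expected: BartForQuestionAnswering
```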
src/transformers/modeling_bart.py
@@ -23,6 +23,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
+from torch.nn import CrossEntropyLoss
 
 from .activations import ACT2FN
 from .configuration_bart import BartConfig

@@ -1123,6 +1124,122 @@ class BartForSequenceClassification(PretrainedBartModel):
         return outputs
 
 
+@add_start_docstrings(
+    """BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
+    the hidden-states output to compute `span start logits` and `span end logits`). """,
+    BART_START_DOCSTRING,
+)
+class BartForQuestionAnswering(PretrainedBartModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        config.num_labels = 2
+        self.num_labels = config.num_labels
+
+        self.model = BartModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.model._init_weights(self.qa_outputs)
+
+    @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids,
+        attention_mask=None,
+        encoder_outputs=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        start_positions=None,
+        end_positions=None,
+        output_attentions=None,
+    ):
+        r"""
+        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+
+    Returns:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+        start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-start scores (before SoftMax).
+        end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-end scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+
+    Examples::
+
+        # The checkpoint bart-large is not fine-tuned for question answering. Please see the
+        # examples/question-answering/run_squad.py example to see how to fine-tune a model to a question answering task.
+
+        from transformers import BartTokenizer, BartForQuestionAnswering
+        import torch
+
+        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+        model = BartForQuestionAnswering.from_pretrained('facebook/bart-large')
+
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        input_ids = tokenizer.encode(question, text)
+        start_scores, end_scores = model(torch.tensor([input_ids]))
+
+        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+        answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
+        """
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            output_attentions=output_attentions,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        outputs = (start_logits, end_logits,) + outputs[1:]
+
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions.clamp_(0, ignored_index)
+            end_positions.clamp_(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
+
+
 class SinusoidalPositionalEmbedding(nn.Embedding):
     """This module produces sinusoidal positional embeddings of any length."""
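The head itself is small: `qa_outputs` projects each decoder hidden state to two scores, which are split into start and end logits, and the loss is the mean of two cross-entropies over the gold span boundaries. A self-contained sketch of just that arithmetic with dummy tensors (shapes and label values are made up for illustration; no BART encoder/decoder is involved):

```python
import torch
from torch import nn
from torch.nn import CrossEntropyLoss

batch_size, seq_len, hidden_size = 2, 7, 16

# Stand-ins for the decoder's last hidden states and the qa_outputs projection above.
sequence_output = torch.randn(batch_size, seq_len, hidden_size)
qa_outputs = nn.Linear(hidden_size, 2)

logits = qa_outputs(sequence_output)                # (batch, seq_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)  # two (batch, seq_len, 1) tensors
start_logits = start_logits.squeeze(-1)             # (batch, seq_len)
end_logits = end_logits.squeeze(-1)

# Gold span boundaries; clamping pushes out-of-range positions onto the ignored index.
start_positions = torch.tensor([1, 3])
end_positions = torch.tensor([4, 5])
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
print(total_loss.item())
```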
tests/test_modeling_bart.py
@@ -35,6 +35,7 @@ if is_torch_available():
         BartModel,
         BartForConditionalGeneration,
         BartForSequenceClassification,
+        BartForQuestionAnswering,
         BartConfig,
         BartTokenizer,
         MBartTokenizer,

@@ -375,6 +376,19 @@ class BartHeadTests(unittest.TestCase):
         loss = outputs[0]
         self.assertIsInstance(loss.item(), float)
 
+    def test_question_answering_forward(self):
+        config, input_ids, batch_size = self._get_config_and_data()
+        sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
+        model = BartForQuestionAnswering(config)
+        model.to(torch_device)
+        loss, start_logits, end_logits, _ = model(
+            input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,
+        )
+
+        self.assertEqual(start_logits.shape, input_ids.shape)
+        self.assertEqual(end_logits.shape, input_ids.shape)
+        self.assertIsInstance(loss.item(), float)
+
     @timeout_decorator.timeout(1)
     def test_lm_forward(self):
         config, input_ids, batch_size = self._get_config_and_data()
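The new test runs the head end-to-end on the tiny BART that `_get_config_and_data` builds and asserts the output shapes. A rough standalone equivalent outside the test harness (the config values below are assumptions chosen only to keep the randomly initialized model small; they are not the ones the test helper uses):

```python
import torch
from transformers import BartConfig, BartForQuestionAnswering

# Deliberately tiny, randomly initialized model; hyperparameters are illustrative only.
config = BartConfig(
    vocab_size=99,
    d_model=24,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_position_embeddings=48,
)
model = BartForQuestionAnswering(config).eval()

input_ids = torch.tensor([[0, 31, 7, 76, 18, 2]])  # arbitrary ids ending in eos (2)
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])

loss, start_logits, end_logits = model(
    input_ids=input_ids,
    start_positions=start_positions,
    end_positions=end_positions,
)[:3]

assert start_logits.shape == input_ids.shape  # (batch_size, sequence_length)
assert end_logits.shape == input_ids.shape
assert isinstance(loss.item(), float)
```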