Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7e73601f
Unverified
Commit
7e73601f
authored
Jun 01, 2021
by
Fan Zhang
Committed by
GitHub
Jun 01, 2021
Browse files
modify qa-trainer (#11872)
* modify qa-trainer * fix flax model
parent
9ec0f01b
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
47 additions
and
39 deletions
+47
-39
examples/pytorch/question-answering/run_qa_no_trainer.py
examples/pytorch/question-answering/run_qa_no_trainer.py
+9
-1
src/transformers/models/albert/modeling_albert.py
src/transformers/models/albert/modeling_albert.py
+2
-2
src/transformers/models/bart/modeling_bart.py
src/transformers/models/bart/modeling_bart.py
+2
-2
src/transformers/models/bert/modeling_bert.py
src/transformers/models/bert/modeling_bert.py
+2
-2
src/transformers/models/big_bird/modeling_big_bird.py
src/transformers/models/big_bird/modeling_big_bird.py
+2
-2
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
...ormers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+2
-2
src/transformers/models/convbert/modeling_convbert.py
src/transformers/models/convbert/modeling_convbert.py
+2
-2
src/transformers/models/deberta/modeling_deberta.py
src/transformers/models/deberta/modeling_deberta.py
+2
-2
src/transformers/models/deberta_v2/modeling_deberta_v2.py
src/transformers/models/deberta_v2/modeling_deberta_v2.py
+2
-2
src/transformers/models/distilbert/modeling_distilbert.py
src/transformers/models/distilbert/modeling_distilbert.py
+2
-2
src/transformers/models/dpr/modeling_dpr.py
src/transformers/models/dpr/modeling_dpr.py
+2
-2
src/transformers/models/electra/modeling_electra.py
src/transformers/models/electra/modeling_electra.py
+2
-2
src/transformers/models/funnel/modeling_funnel.py
src/transformers/models/funnel/modeling_funnel.py
+2
-2
src/transformers/models/ibert/modeling_ibert.py
src/transformers/models/ibert/modeling_ibert.py
+2
-2
src/transformers/models/led/modeling_led.py
src/transformers/models/led/modeling_led.py
+2
-2
src/transformers/models/longformer/modeling_longformer.py
src/transformers/models/longformer/modeling_longformer.py
+2
-2
src/transformers/models/mbart/modeling_mbart.py
src/transformers/models/mbart/modeling_mbart.py
+2
-2
src/transformers/models/megatron_bert/modeling_megatron_bert.py
...ansformers/models/megatron_bert/modeling_megatron_bert.py
+2
-2
src/transformers/models/mobilebert/modeling_mobilebert.py
src/transformers/models/mobilebert/modeling_mobilebert.py
+2
-2
src/transformers/models/mpnet/modeling_mpnet.py
src/transformers/models/mpnet/modeling_mpnet.py
+2
-2
No files found.
examples/pytorch/question-answering/run_qa_no_trainer.py
View file @
7e73601f
...
...
@@ -692,7 +692,11 @@ def main():
if
completed_steps
>=
args
.
max_train_steps
:
break
# Validation
# Evaluation
logger
.
info
(
"***** Running Evaluation *****"
)
logger
.
info
(
f
" Num examples =
{
len
(
eval_dataset
)
}
"
)
logger
.
info
(
f
" Batch size =
{
args
.
per_device_eval_batch_size
}
"
)
all_start_logits
=
[]
all_end_logits
=
[]
for
step
,
batch
in
enumerate
(
eval_dataloader
):
...
...
@@ -725,6 +729,10 @@ def main():
# Prediction
if
args
.
do_predict
:
logger
.
info
(
"***** Running Prediction *****"
)
logger
.
info
(
f
" Num examples =
{
len
(
predict_dataset
)
}
"
)
logger
.
info
(
f
" Batch size =
{
args
.
per_device_eval_batch_size
}
"
)
all_start_logits
=
[]
all_end_logits
=
[]
for
step
,
batch
in
enumerate
(
predict_dataloader
):
...
...
src/transformers/models/albert/modeling_albert.py
View file @
7e73601f
...
...
@@ -1218,8 +1218,8 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/bart/modeling_bart.py
View file @
7e73601f
...
...
@@ -1556,8 +1556,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/bert/modeling_bert.py
View file @
7e73601f
...
...
@@ -1801,8 +1801,8 @@ class BertForQuestionAnswering(BertPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/big_bird/modeling_big_bird.py
View file @
7e73601f
...
...
@@ -2983,8 +2983,8 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
logits
=
logits
-
logits_mask
*
1e6
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
View file @
7e73601f
...
...
@@ -2761,8 +2761,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/convbert/modeling_convbert.py
View file @
7e73601f
...
...
@@ -1293,8 +1293,8 @@ class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/deberta/modeling_deberta.py
View file @
7e73601f
...
...
@@ -1364,8 +1364,8 @@ class DebertaForQuestionAnswering(DebertaPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/deberta_v2/modeling_deberta_v2.py
View file @
7e73601f
...
...
@@ -1488,8 +1488,8 @@ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/distilbert/modeling_distilbert.py
View file @
7e73601f
...
...
@@ -728,8 +728,8 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
hidden_states
=
self
.
dropout
(
hidden_states
)
# (bs, max_query_len, dim)
logits
=
self
.
qa_outputs
(
hidden_states
)
# (bs, max_query_len, 2)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
# (bs, max_query_len)
end_logits
=
end_logits
.
squeeze
(
-
1
)
# (bs, max_query_len)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
# (bs, max_query_len)
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
# (bs, max_query_len)
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/dpr/modeling_dpr.py
View file @
7e73601f
...
...
@@ -241,8 +241,8 @@ class DPRSpanPredictor(PreTrainedModel):
# compute logits
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
relevance_logits
=
self
.
qa_classifier
(
sequence_output
[:,
0
,
:])
# resize
...
...
src/transformers/models/electra/modeling_electra.py
View file @
7e73601f
...
...
@@ -1318,8 +1318,8 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/funnel/modeling_funnel.py
View file @
7e73601f
...
...
@@ -1549,8 +1549,8 @@ class FunnelForQuestionAnswering(FunnelPreTrainedModel):
logits
=
self
.
qa_outputs
(
last_hidden_state
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/ibert/modeling_ibert.py
View file @
7e73601f
...
...
@@ -1319,8 +1319,8 @@ class IBertForQuestionAnswering(IBertPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/led/modeling_led.py
View file @
7e73601f
...
...
@@ -2585,8 +2585,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/longformer/modeling_longformer.py
View file @
7e73601f
...
...
@@ -2017,8 +2017,8 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/mbart/modeling_mbart.py
View file @
7e73601f
...
...
@@ -1563,8 +1563,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/megatron_bert/modeling_megatron_bert.py
View file @
7e73601f
...
...
@@ -1794,8 +1794,8 @@ class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/mobilebert/modeling_mobilebert.py
View file @
7e73601f
...
...
@@ -1371,8 +1371,8 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/mpnet/modeling_mpnet.py
View file @
7e73601f
...
...
@@ -1023,8 +1023,8 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment