Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7e73601f
Unverified
Commit
7e73601f
authored
Jun 01, 2021
by
Fan Zhang
Committed by
GitHub
Jun 01, 2021
Browse files
modify qa-trainer (#11872)
* modify qa-trainer * fix flax model
parent
9ec0f01b
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
10 additions
and
10 deletions
+10
-10
src/transformers/models/reformer/modeling_reformer.py
src/transformers/models/reformer/modeling_reformer.py
+2
-2
src/transformers/models/roberta/modeling_roberta.py
src/transformers/models/roberta/modeling_roberta.py
+2
-2
src/transformers/models/squeezebert/modeling_squeezebert.py
src/transformers/models/squeezebert/modeling_squeezebert.py
+2
-2
src/transformers/models/xlm/modeling_xlm.py
src/transformers/models/xlm/modeling_xlm.py
+2
-2
src/transformers/models/xlnet/modeling_xlnet.py
src/transformers/models/xlnet/modeling_xlnet.py
+2
-2
No files found.
src/transformers/models/reformer/modeling_reformer.py
View file @
7e73601f
...
@@ -2555,8 +2555,8 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
...
@@ -2555,8 +2555,8 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/roberta/modeling_roberta.py
View file @
7e73601f
...
@@ -1472,8 +1472,8 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
...
@@ -1472,8 +1472,8 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/squeezebert/modeling_squeezebert.py
View file @
7e73601f
...
@@ -1068,8 +1068,8 @@ class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
...
@@ -1068,8 +1068,8 @@ class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/xlm/modeling_xlm.py
View file @
7e73601f
...
@@ -941,8 +941,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
...
@@ -941,8 +941,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
src/transformers/models/xlnet/modeling_xlnet.py
View file @
7e73601f
...
@@ -1862,8 +1862,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
...
@@ -1862,8 +1862,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
logits
=
self
.
qa_outputs
(
sequence_output
)
logits
=
self
.
qa_outputs
(
sequence_output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
.
contiguous
()
end_logits
=
end_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
.
contiguous
()
total_loss
=
None
total_loss
=
None
if
start_positions
is
not
None
and
end_positions
is
not
None
:
if
start_positions
is
not
None
and
end_positions
is
not
None
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment