transformers · Commit 1e05671d (Unverified)
Authored May 01, 2024 by Matt; committed May 01, 2024 by GitHub.
Fix QA example (#30580)
* Handle cases when CLS token is absent
* Use BOS token as a fallback
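The change replaces the unconditional cls_index = input_ids.index(tokenizer.cls_token_id) with a fallback chain: use the CLS token's position if it is present, otherwise the BOS token's position, otherwise index 0. A minimal standalone sketch of that logic follows; the helper name find_cls_index and the token ids are illustrative only, not part of the diff.

def find_cls_index(input_ids, cls_token_id, bos_token_id):
    # Mirror of the fallback added in this commit: prefer CLS, then BOS, else 0.
    if cls_token_id in input_ids:
        return input_ids.index(cls_token_id)
    if bos_token_id in input_ids:
        return input_ids.index(bos_token_id)
    return 0

# Encoder-style input: a CLS id (101 here) is present, so its position is used.
assert find_cls_index([101, 2023, 2003, 102], cls_token_id=101, bos_token_id=None) == 0
# No CLS token (cls_token_id is None), but a BOS id (1 here) is found instead.
assert find_cls_index([1, 15, 27, 2], cls_token_id=None, bos_token_id=1) == 0
# Neither token present: default to index 0 instead of raising.
assert find_cls_index([15, 27, 2], cls_token_id=None, bos_token_id=None) == 0

Before this fix, a tokenizer with no CLS token (cls_token_id is None) made input_ids.index(tokenizer.cls_token_id) raise ValueError in these example scripts.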
Parent: 4b4da18f
Changes: 4 changed files with 36 additions and 6 deletions (+36, −6).
examples/pytorch/question-answering/run_qa.py (+6, −1)
examples/pytorch/question-answering/run_qa_beam_search.py (+12, −2)
examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py (+12, −2)
examples/pytorch/question-answering/run_qa_no_trainer.py (+6, −1)
examples/pytorch/question-answering/run_qa.py

@@ -434,7 +434,12 @@ def main():
         for i, offsets in enumerate(offset_mapping):
             # We will label impossible answers with the index of the CLS token.
             input_ids = tokenized_examples["input_ids"][i]
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0

             # Grab the sequence corresponding to that example (to know what is the context and what is the question).
             sequence_ids = tokenized_examples.sequence_ids(i)
examples/pytorch/question-answering/run_qa_beam_search.py

@@ -417,7 +417,12 @@ def main():
         for i, offsets in enumerate(offset_mapping):
             # We will label impossible answers with the index of the CLS token.
             input_ids = tokenized_examples["input_ids"][i]
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0
             tokenized_examples["cls_index"].append(cls_index)

             # Grab the sequence corresponding to that example (to know what is the context and what is the question).

@@ -534,7 +539,12 @@ def main():
         for i, input_ids in enumerate(tokenized_examples["input_ids"]):
             # Find the CLS token in the input ids.
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0
             tokenized_examples["cls_index"].append(cls_index)

             # Grab the sequence corresponding to that example (to know what is the context and what is the question).
examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py

@@ -444,7 +444,12 @@ def main():
         for i, offsets in enumerate(offset_mapping):
             # We will label impossible answers with the index of the CLS token.
             input_ids = tokenized_examples["input_ids"][i]
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0
             tokenized_examples["cls_index"].append(cls_index)

             # Grab the sequence corresponding to that example (to know what is the context and what is the question).

@@ -563,7 +568,12 @@ def main():
         for i, input_ids in enumerate(tokenized_examples["input_ids"]):
             # Find the CLS token in the input ids.
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0
             tokenized_examples["cls_index"].append(cls_index)

             # Grab the sequence corresponding to that example (to know what is the context and what is the question).
examples/pytorch/question-answering/run_qa_no_trainer.py

@@ -513,7 +513,12 @@ def main():
         for i, offsets in enumerate(offset_mapping):
             # We will label impossible answers with the index of the CLS token.
             input_ids = tokenized_examples["input_ids"][i]
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0

             # Grab the sequence corresponding to that example (to know what is the context and what is the question).
             sequence_ids = tokenized_examples.sequence_ids(i)