Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1e05671d
Unverified
Commit
1e05671d
authored
May 01, 2024
by
Matt
Committed by
GitHub
May 01, 2024
Browse files
Fix QA example (#30580)
* Handle cases when CLS token is absent
* Use BOS token as a fallback
parent
4b4da18f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
6 deletions
+36
-6
examples/pytorch/question-answering/run_qa.py
examples/pytorch/question-answering/run_qa.py
+6
-1
examples/pytorch/question-answering/run_qa_beam_search.py
examples/pytorch/question-answering/run_qa_beam_search.py
+12
-2
examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
+12
-2
examples/pytorch/question-answering/run_qa_no_trainer.py
examples/pytorch/question-answering/run_qa_no_trainer.py
+6
-1
No files found.
examples/pytorch/question-answering/run_qa.py
View file @
1e05671d
...
@@ -434,7 +434,12 @@ def main():
...
@@ -434,7 +434,12 @@ def main():
for
i
,
offsets
in
enumerate
(
offset_mapping
):
for
i
,
offsets
in
enumerate
(
offset_mapping
):
# We will label impossible answers with the index of the CLS token.
# We will label impossible answers with the index of the CLS token.
input_ids
=
tokenized_examples
[
"input_ids"
][
i
]
input_ids
=
tokenized_examples
[
"input_ids"
][
i
]
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
if
tokenizer
.
cls_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
elif
tokenizer
.
bos_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
bos_token_id
)
else
:
cls_index
=
0
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids
=
tokenized_examples
.
sequence_ids
(
i
)
sequence_ids
=
tokenized_examples
.
sequence_ids
(
i
)
...
...
examples/pytorch/question-answering/run_qa_beam_search.py
View file @
1e05671d
...
@@ -417,7 +417,12 @@ def main():
...
@@ -417,7 +417,12 @@ def main():
for
i
,
offsets
in
enumerate
(
offset_mapping
):
for
i
,
offsets
in
enumerate
(
offset_mapping
):
# We will label impossible answers with the index of the CLS token.
# We will label impossible answers with the index of the CLS token.
input_ids
=
tokenized_examples
[
"input_ids"
][
i
]
input_ids
=
tokenized_examples
[
"input_ids"
][
i
]
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
if
tokenizer
.
cls_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
elif
tokenizer
.
bos_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
bos_token_id
)
else
:
cls_index
=
0
tokenized_examples
[
"cls_index"
].
append
(
cls_index
)
tokenized_examples
[
"cls_index"
].
append
(
cls_index
)
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
...
@@ -534,7 +539,12 @@ def main():
...
@@ -534,7 +539,12 @@ def main():
for
i
,
input_ids
in
enumerate
(
tokenized_examples
[
"input_ids"
]):
for
i
,
input_ids
in
enumerate
(
tokenized_examples
[
"input_ids"
]):
# Find the CLS token in the input ids.
# Find the CLS token in the input ids.
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
if
tokenizer
.
cls_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
elif
tokenizer
.
bos_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
bos_token_id
)
else
:
cls_index
=
0
tokenized_examples
[
"cls_index"
].
append
(
cls_index
)
tokenized_examples
[
"cls_index"
].
append
(
cls_index
)
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
...
...
examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
View file @
1e05671d
...
@@ -444,7 +444,12 @@ def main():
...
@@ -444,7 +444,12 @@ def main():
for
i
,
offsets
in
enumerate
(
offset_mapping
):
for
i
,
offsets
in
enumerate
(
offset_mapping
):
# We will label impossible answers with the index of the CLS token.
# We will label impossible answers with the index of the CLS token.
input_ids
=
tokenized_examples
[
"input_ids"
][
i
]
input_ids
=
tokenized_examples
[
"input_ids"
][
i
]
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
if
tokenizer
.
cls_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
elif
tokenizer
.
bos_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
bos_token_id
)
else
:
cls_index
=
0
tokenized_examples
[
"cls_index"
].
append
(
cls_index
)
tokenized_examples
[
"cls_index"
].
append
(
cls_index
)
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
...
@@ -563,7 +568,12 @@ def main():
...
@@ -563,7 +568,12 @@ def main():
for
i
,
input_ids
in
enumerate
(
tokenized_examples
[
"input_ids"
]):
for
i
,
input_ids
in
enumerate
(
tokenized_examples
[
"input_ids"
]):
# Find the CLS token in the input ids.
# Find the CLS token in the input ids.
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
if
tokenizer
.
cls_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
elif
tokenizer
.
bos_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
bos_token_id
)
else
:
cls_index
=
0
tokenized_examples
[
"cls_index"
].
append
(
cls_index
)
tokenized_examples
[
"cls_index"
].
append
(
cls_index
)
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
...
...
examples/pytorch/question-answering/run_qa_no_trainer.py
View file @
1e05671d
...
@@ -513,7 +513,12 @@ def main():
...
@@ -513,7 +513,12 @@ def main():
for
i
,
offsets
in
enumerate
(
offset_mapping
):
for
i
,
offsets
in
enumerate
(
offset_mapping
):
# We will label impossible answers with the index of the CLS token.
# We will label impossible answers with the index of the CLS token.
input_ids
=
tokenized_examples
[
"input_ids"
][
i
]
input_ids
=
tokenized_examples
[
"input_ids"
][
i
]
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
if
tokenizer
.
cls_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
cls_token_id
)
elif
tokenizer
.
bos_token_id
in
input_ids
:
cls_index
=
input_ids
.
index
(
tokenizer
.
bos_token_id
)
else
:
cls_index
=
0
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids
=
tokenized_examples
.
sequence_ids
(
i
)
sequence_ids
=
tokenized_examples
.
sequence_ids
(
i
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment