Unverified commit 1e05671d, authored by Matt, committed by GitHub

Fix QA example (#30580)

* Handle cases when the CLS token is absent

* Use the BOS token as a fallback

parent 4b4da18f
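
For context, the pre-patch line `cls_index = input_ids.index(tokenizer.cls_token_id)` raises a ValueError for tokenizers that define no CLS token (for example GPT-2, whose `cls_token_id` is `None`). Below is a minimal standalone sketch of the patched fallback logic, assuming the `transformers` library is installed; `get_cls_index` is a hypothetical helper name used for illustration, while the patch itself inlines this logic at each call site.

from transformers import AutoTokenizer

def get_cls_index(input_ids, tokenizer):
    # Prefer the CLS token; fall back to BOS; default to position 0.
    # When `tokenizer.cls_token_id` is None, `None in input_ids` is
    # simply False, so no branch reaches the `.index()` call that the
    # old code could crash on.
    if tokenizer.cls_token_id in input_ids:
        return input_ids.index(tokenizer.cls_token_id)
    elif tokenizer.bos_token_id in input_ids:
        return input_ids.index(tokenizer.bos_token_id)
    return 0

# BERT prepends [CLS], so the first branch fires.
bert = AutoTokenizer.from_pretrained("bert-base-uncased")
ids = bert("What is QA?", "QA is question answering.")["input_ids"]
print(get_cls_index(ids, bert))  # 0: [CLS] is the first token

# GPT-2 defines no CLS token and does not prepend BOS by default,
# so this falls through to the final default of 0.
gpt2 = AutoTokenizer.from_pretrained("gpt2")
print(get_cls_index(gpt2("What is QA?")["input_ids"], gpt2))  # 0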
@@ -434,7 +434,12 @@ def main():
         for i, offsets in enumerate(offset_mapping):
             # We will label impossible answers with the index of the CLS token.
             input_ids = tokenized_examples["input_ids"][i]
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0
             # Grab the sequence corresponding to that example (to know what is the context and what is the question).
             sequence_ids = tokenized_examples.sequence_ids(i)
@@ -417,7 +417,12 @@ def main():
         for i, offsets in enumerate(offset_mapping):
             # We will label impossible answers with the index of the CLS token.
             input_ids = tokenized_examples["input_ids"][i]
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0
             tokenized_examples["cls_index"].append(cls_index)
             # Grab the sequence corresponding to that example (to know what is the context and what is the question).
@@ -534,7 +539,12 @@ def main():
         for i, input_ids in enumerate(tokenized_examples["input_ids"]):
             # Find the CLS token in the input ids.
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0
             tokenized_examples["cls_index"].append(cls_index)
             # Grab the sequence corresponding to that example (to know what is the context and what is the question).
@@ -444,7 +444,12 @@ def main():
         for i, offsets in enumerate(offset_mapping):
             # We will label impossible answers with the index of the CLS token.
             input_ids = tokenized_examples["input_ids"][i]
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0
             tokenized_examples["cls_index"].append(cls_index)
             # Grab the sequence corresponding to that example (to know what is the context and what is the question).
@@ -563,7 +568,12 @@ def main():
         for i, input_ids in enumerate(tokenized_examples["input_ids"]):
             # Find the CLS token in the input ids.
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0
             tokenized_examples["cls_index"].append(cls_index)
             # Grab the sequence corresponding to that example (to know what is the context and what is the question).
@@ -513,7 +513,12 @@ def main():
         for i, offsets in enumerate(offset_mapping):
             # We will label impossible answers with the index of the CLS token.
             input_ids = tokenized_examples["input_ids"][i]
-            cls_index = input_ids.index(tokenizer.cls_token_id)
+            if tokenizer.cls_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.cls_token_id)
+            elif tokenizer.bos_token_id in input_ids:
+                cls_index = input_ids.index(tokenizer.bos_token_id)
+            else:
+                cls_index = 0
             # Grab the sequence corresponding to that example (to know what is the context and what is the question).
             sequence_ids = tokenized_examples.sequence_ids(i)