"...ldm/git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "f50b1fec695cecc8f7c87ce1f39db3f6b49bb3a1"
Unverified commit ba5ea66e, authored by Malte Pietsch, committed by GitHub
Browse files

Fix tokenization in SQuAD for RoBERTa, Longformer, BART (#7387)

* fix squad tokenization for roberta & co

* change to pure type based check

* sort imports
parent 0270256b
@@ -7,7 +7,10 @@ import numpy as np
 from tqdm import tqdm

 from ...file_utils import is_tf_available, is_torch_available
+from ...tokenization_bart import BartTokenizer
 from ...tokenization_bert import whitespace_tokenize
+from ...tokenization_longformer import LongformerTokenizer
+from ...tokenization_roberta import RobertaTokenizer
 from ...tokenization_utils_base import TruncationStrategy
 from ...utils import logging
 from .utils import DataProcessor
@@ -109,6 +112,9 @@ def squad_convert_example_to_features(
         all_doc_tokens = []
         for (i, token) in enumerate(example.doc_tokens):
             orig_to_tok_index.append(len(all_doc_tokens))
+            if isinstance(tokenizer, (RobertaTokenizer, LongformerTokenizer, BartTokenizer)):
+                sub_tokens = tokenizer.tokenize(token, add_prefix_space=True)
+            else:
+                sub_tokens = tokenizer.tokenize(token)
-            sub_tokens = tokenizer.tokenize(token)
             for sub_token in sub_tokens:
                 tok_to_orig_index.append(i)
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment