"docs/git@developer.sourcefind.cn:lacacy/qwen_lmdeploy.git" did not exist on "fbd9770a9b7c5bc2eb2dbf66b43e1c54ef165185"
Unverified commit 2a85345a, authored by ikkvix, committed by GitHub
Browse files

Optimize the speed of the truncate_sequences function. (#28263)



* change truncate_sequences

* Update tokenization_utils_base.py

* change format

* fix when ids_to_move=0

* fix

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 66964c00
......@@ -3557,21 +3557,26 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"truncation strategy. So the returned list will always be empty even if some "
"tokens have been removed."
)
for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids):
if self.truncation_side == "right":
ids = ids[:-1]
elif self.truncation_side == "left":
ids = ids[1:]
else:
raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
else:
if self.truncation_side == "right":
pair_ids = pair_ids[:-1]
elif self.truncation_side == "left":
pair_ids = pair_ids[1:]
else:
raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
len_pair_ids = len(pair_ids) if pair_ids is not None else 0
len_ids = len(ids)
first_remove = min(abs(len_pair_ids - len_ids), num_tokens_to_remove)
second_remove = num_tokens_to_remove - first_remove
if len_ids > len_pair_ids:
ids_to_move = first_remove + second_remove // 2
pair_ids_to_move = second_remove - second_remove // 2
else:
ids_to_move = second_remove // 2
pair_ids_to_move = first_remove + second_remove - (second_remove // 2)
if self.truncation_side == "right":
ids = ids[:-ids_to_move] if ids_to_move > 0 else ids
pair_ids = pair_ids[:-pair_ids_to_move] if pair_ids is not None and pair_ids_to_move > 0 else pair_ids
elif self.truncation_side == "left":
ids = ids[ids_to_move:]
pair_ids = pair_ids[pair_ids_to_move:] if pair_ids is not None else None
else:
raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
if len(pair_ids) > num_tokens_to_remove:
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment