Unverified Commit 2a85345a authored by ikkvix's avatar ikkvix Committed by GitHub
Browse files

Optimize the speed of the truncate_sequences function. (#28263)



* change truncate_sequences

* Update tokenization_utils_base.py

* change format

* fix when ids_to_move=0

* fix

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 66964c00
...@@ -3557,21 +3557,26 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -3557,21 +3557,26 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"truncation strategy. So the returned list will always be empty even if some " "truncation strategy. So the returned list will always be empty even if some "
"tokens have been removed." "tokens have been removed."
) )
for _ in range(num_tokens_to_remove): len_pair_ids = len(pair_ids) if pair_ids is not None else 0
if pair_ids is None or len(ids) > len(pair_ids): len_ids = len(ids)
if self.truncation_side == "right": first_remove = min(abs(len_pair_ids - len_ids), num_tokens_to_remove)
ids = ids[:-1] second_remove = num_tokens_to_remove - first_remove
elif self.truncation_side == "left": if len_ids > len_pair_ids:
ids = ids[1:] ids_to_move = first_remove + second_remove // 2
else: pair_ids_to_move = second_remove - second_remove // 2
raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
else: else:
ids_to_move = second_remove // 2
pair_ids_to_move = first_remove + second_remove - (second_remove // 2)
if self.truncation_side == "right": if self.truncation_side == "right":
pair_ids = pair_ids[:-1] ids = ids[:-ids_to_move] if ids_to_move > 0 else ids
pair_ids = pair_ids[:-pair_ids_to_move] if pair_ids is not None and pair_ids_to_move > 0 else pair_ids
elif self.truncation_side == "left": elif self.truncation_side == "left":
pair_ids = pair_ids[1:] ids = ids[ids_to_move:]
pair_ids = pair_ids[pair_ids_to_move:] if pair_ids is not None else None
else: else:
raise ValueError("invalid truncation strategy:" + str(self.truncation_side)) raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
if len(pair_ids) > num_tokens_to_remove: if len(pair_ids) > num_tokens_to_remove:
window_len = min(len(pair_ids), stride + num_tokens_to_remove) window_len = min(len(pair_ids), stride + num_tokens_to_remove)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment