Unverified commit 0a27bb0d authored by Eyal Cohen, committed by GitHub
Browse files

Optimize merge_tokens method (#3615)



Optimizes merge_tokens method as discussed in #3614 
Co-authored-by: Eyal Cohen <eyal308@gmail.com>
parent 0ff6d266
......@@ -115,17 +115,14 @@ def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSp
if len(tokens) != len(scores):
raise ValueError("`tokens` and `scores` must be the same length.")
t_prev = blank
i = start = -1
spans = []
for t, token in enumerate(tokens):
if token != t_prev:
if t_prev != blank:
spans.append(TokenSpan(t_prev.item(), start, t, scores[start:t].mean().item()))
if token != blank:
i += 1
start = t
t_prev = token
if t_prev != blank:
spans.append(TokenSpan(t_prev.item(), start, len(tokens), scores[start:].mean().item()))
diff = torch.diff(
tokens, prepend=torch.tensor([-1], device=tokens.device), append=torch.tensor([-1], device=tokens.device)
)
changes_wo_blank = torch.nonzero((diff != 0)).squeeze().tolist()
tokens = tokens.tolist()
spans = [
TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
for start, end in zip(changes_wo_blank[:-1], changes_wo_blank[1:])
if (token := tokens[start]) != blank
]
return spans
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.