"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "9f10306b3fd8168a100e749716e99b75b769e3ef"
Unverified Commit 0a27bb0d authored by Eyal Cohen's avatar Eyal Cohen Committed by GitHub
Browse files

Optimize merge_tokens method (#3615)



Optimizes merge_tokens method as discussed in #3614 
Co-authored-by: default avatarEyal Cohen <eyal308@gmail.com>
parent 0ff6d266
...@@ -115,17 +115,14 @@ def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSp ...@@ -115,17 +115,14 @@ def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSp
if len(tokens) != len(scores): if len(tokens) != len(scores):
raise ValueError("`tokens` and `scores` must be the same length.") raise ValueError("`tokens` and `scores` must be the same length.")
t_prev = blank diff = torch.diff(
i = start = -1 tokens, prepend=torch.tensor([-1], device=tokens.device), append=torch.tensor([-1], device=tokens.device)
spans = [] )
for t, token in enumerate(tokens): changes_wo_blank = torch.nonzero((diff != 0)).squeeze().tolist()
if token != t_prev: tokens = tokens.tolist()
if t_prev != blank: spans = [
spans.append(TokenSpan(t_prev.item(), start, t, scores[start:t].mean().item())) TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
if token != blank: for start, end in zip(changes_wo_blank[:-1], changes_wo_blank[1:])
i += 1 if (token := tokens[start]) != blank
start = t ]
t_prev = token
if t_prev != blank:
spans.append(TokenSpan(t_prev.item(), start, len(tokens), scores[start:].mean().item()))
return spans return spans
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment