"tests/python/pytorch/sparse/test_matmul.py" did not exist on "0698e91a0e4b40bd4a5a4e59205d098e1bb3d3c9"
Unverified Commit 6bfd83b4 authored by yangarbiter's avatar yangarbiter Committed by GitHub
Browse files

Add edit_distance

parent bac32ec1
...@@ -11,12 +11,12 @@ from torch.optim import SGD, Adadelta, Adam, AdamW ...@@ -11,12 +11,12 @@ from torch.optim import SGD, Adadelta, Adam, AdamW
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator from torchaudio.datasets.utils import bg_iterator
from torchaudio.functional import edit_distance
from torchaudio.models.wav2letter import Wav2Letter from torchaudio.models.wav2letter import Wav2Letter
from ctc_decoders import GreedyDecoder from ctc_decoders import GreedyDecoder
from datasets import collate_factory, split_process_librispeech from datasets import collate_factory, split_process_librispeech
from languagemodels import LanguageModel from languagemodels import LanguageModel
from metrics import levenshtein_distance
from transforms import Normalize, UnsqueezeFirst from transforms import Normalize, UnsqueezeFirst
from utils import MetricLogger, count_parameters, save_checkpoint from utils import MetricLogger, count_parameters, save_checkpoint
...@@ -217,7 +217,7 @@ def compute_error_rates(outputs, targets, decoder, language_model, metric): ...@@ -217,7 +217,7 @@ def compute_error_rates(outputs, targets, decoder, language_model, metric):
target_print = target[i].ljust(print_length)[:print_length] target_print = target[i].ljust(print_length)[:print_length]
logging.info("Target: %s Output: %s", target_print, output_print) logging.info("Target: %s Output: %s", target_print, output_print)
cers = [levenshtein_distance(t, o) for t, o in zip(target, output)] cers = [edit_distance(t, o) for t, o in zip(target, output)]
cers = sum(cers) cers = sum(cers)
n = sum(len(t) for t in target) n = sum(len(t) for t in target)
metric["batch char error"] = cers metric["batch char error"] = cers
...@@ -232,7 +232,7 @@ def compute_error_rates(outputs, targets, decoder, language_model, metric): ...@@ -232,7 +232,7 @@ def compute_error_rates(outputs, targets, decoder, language_model, metric):
output = [o.split(language_model.char_space) for o in output] output = [o.split(language_model.char_space) for o in output]
target = [t.split(language_model.char_space) for t in target] target = [t.split(language_model.char_space) for t in target]
wers = [levenshtein_distance(t, o) for t, o in zip(target, output)] wers = [edit_distance(t, o) for t, o in zip(target, output)]
wers = sum(wers) wers = sum(wers)
n = sum(len(t) for t in target) n = sum(len(t) for t in target)
metric["batch word error"] = wers metric["batch word error"] = wers
......
from typing import List, Union
def levenshtein_distance(r: Union[str, List[str]], h: Union[str, List[str]]):
    """
    Calculate the Levenshtein distance between two lists or strings.

    Uses the classic two-row dynamic-programming formulation: only the
    previous and current rows of the full edit-distance matrix are kept,
    so memory is O(len(h)) instead of O(len(r) * len(h)).
    """
    # prev_row[j] = distance between the empty prefix of ``r`` and h[:j].
    prev_row = list(range(len(h) + 1))
    curr_row = [0] * (len(h) + 1)

    for i in range(1, len(r) + 1):
        # Distance from r[:i] to the empty sequence is i deletions.
        curr_row[0] = i
        for j in range(1, len(h) + 1):
            if r[i - 1] == h[j - 1]:
                # Matching items cost nothing extra.
                curr_row[j] = prev_row[j - 1]
            else:
                # Cheapest of substitution, insertion, deletion.
                curr_row[j] = min(
                    prev_row[j - 1] + 1,
                    curr_row[j - 1] + 1,
                    prev_row[j] + 1,
                )
        # The row just filled becomes the "previous" row for the next i.
        prev_row, curr_row = curr_row, prev_row

    return prev_row[-1]
if __name__ == "__main__":
    # Smoke tests as (reference, hypothesis, expected distance) triples,
    # covering equal inputs, substitution, insertion/deletion, a mixed
    # case, and word-level (list-of-strings) inputs.
    _cases = [
        ("abc", "abc", 0),
        ("aaa", "aba", 1),
        ("aba", "aaa", 1),
        ("aa", "aaa", 1),
        ("aaa", "aa", 1),
        ("abc", "bcd", 2),
        (["hello", "world"], ["hello", "world", "!"], 1),
        (["hello", "world"], ["world", "hello", "!"], 2),
    ]
    for _ref, _hyp, _expected in _cases:
        assert levenshtein_distance(_ref, _hyp) == _expected
...@@ -382,6 +382,46 @@ class Functional(TestBaseMixin): ...@@ -382,6 +382,46 @@ class Functional(TestBaseMixin):
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
assert output_shape == expected_shape assert output_shape == expected_shape
# Parameterized triples of (seq1, seq2, expected distance).  String inputs
# exercise character-level edit distance; list-of-strings inputs exercise
# word-level (sentence) edit distance.  Cases include Unicode text and
# empty sequences.
@parameterized.expand(
    [
        # words
        ["", "", 0],  # equal
        ["abc", "abc", 0],
        ["ᑌᑎIᑕO", "ᑌᑎIᑕO", 0],
        ["abc", "", 3],  # deletion
        ["aa", "aaa", 1],
        ["aaa", "aa", 1],
        ["ᑌᑎI", "ᑌᑎIᑕO", 2],
        ["aaa", "aba", 1],  # substitution
        ["aba", "aaa", 1],
        ["aba", " ", 3],
        ["abc", "bcd", 2],  # mix deletion and substitution
        ["0ᑌᑎI", "ᑌᑎIᑕO", 3],
        # sentences
        [["hello", "", "Tᕮ᙭T"], ["hello", "", "Tᕮ᙭T"], 0],  # equal
        [[], [], 0],
        [["hello", "world"], ["hello", "world", "!"], 1],  # deletion
        [["hello", "world"], ["world"], 1],
        [["hello", "world"], [], 2],
        [["Tᕮ᙭T", ], ["world"], 1],  # substitution
        [["Tᕮ᙭T", "XD"], ["world", "hello"], 2],
        [["", "XD"], ["world", ""], 2],
        # NOTE(review): duplicate of the word-level ["aba", " ", 3] case
        # above — possibly meant to be a sentence case; confirm intent.
        ["aba", " ", 3],
        [["hello", "world"], ["world", "hello", "!"], 2],  # mix deletion and substitution
        [["Tᕮ᙭T", "world", "LOL", "XD"], ["world", "hello", "ʕ•́ᴥ•̀ʔっ"], 3],
    ]
)
def test_simple_case_edit_distance(self, seq1, seq2, distance):
    # Edit distance is symmetric, so each case is checked in both
    # argument orders.
    assert F.edit_distance(seq1, seq2) == distance
    assert F.edit_distance(seq2, seq1) == distance
class FunctionalCPUOnly(TestBaseMixin): class FunctionalCPUOnly(TestBaseMixin):
def test_create_fb_matrix_no_warning_high_n_freq(self): def test_create_fb_matrix_no_warning_high_n_freq(self):
......
...@@ -20,6 +20,7 @@ from .functional import ( ...@@ -20,6 +20,7 @@ from .functional import (
spectral_centroid, spectral_centroid,
apply_codec, apply_codec,
resample, resample,
edit_distance,
) )
from .filtering import ( from .filtering import (
allpass_biquad, allpass_biquad,
...@@ -88,4 +89,5 @@ __all__ = [ ...@@ -88,4 +89,5 @@ __all__ = [
'vad', 'vad',
'apply_codec', 'apply_codec',
'resample', 'resample',
'edit_distance',
] ]
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from collections.abc import Sequence
import io import io
import math import math
import warnings import warnings
...@@ -34,6 +35,7 @@ __all__ = [ ...@@ -34,6 +35,7 @@ __all__ = [
"spectral_centroid", "spectral_centroid",
"apply_codec", "apply_codec",
"resample", "resample",
"edit_distance",
] ]
...@@ -1444,3 +1446,45 @@ def resample( ...@@ -1444,3 +1446,45 @@ def resample(
resampling_method, beta, waveform.device, waveform.dtype) resampling_method, beta, waveform.device, waveform.dtype)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width) resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled return resampled
@torch.jit.unused
def edit_distance(seq1: Sequence, seq2: Sequence) -> int:
    """
    Calculate the word level edit (Levenshtein) distance between two sequences.

    The function computes an edit distance allowing deletion, insertion and
    substitution. The result is an integer.

    For most applications, the two input sequences should be the same type. If
    two strings are given, the output is the edit distance between the two
    strings (character edit distance). If two lists of strings are given, the
    output is the edit distance between sentences (word edit distance). Users
    may want to normalize the output by the length of the reference sequence.

    TorchScript is not supported for this function.

    Args:
        seq1 (Sequence): the first sequence to compare.
        seq2 (Sequence): the second sequence to compare.

    Returns:
        int: The distance between the first and second sequences.
    """
    len_sent2 = len(seq2)
    # Two-row dynamic programming: ``dold`` holds the previous row of the
    # edit-distance matrix, ``dnew`` the row currently being filled in.
    dold = list(range(len_sent2 + 1))
    dnew = [0] * (len_sent2 + 1)

    for i in range(1, len(seq1) + 1):
        # Distance from seq1[:i] to the empty sequence is i deletions.
        dnew[0] = i
        for j in range(1, len_sent2 + 1):
            if seq1[i - 1] == seq2[j - 1]:
                # Matching items carry the diagonal cost unchanged.
                dnew[j] = dold[j - 1]
            else:
                substitution = dold[j - 1] + 1
                insertion = dnew[j - 1] + 1
                deletion = dold[j] + 1
                dnew[j] = min(substitution, insertion, deletion)

        # Swap rows: the freshly computed row becomes the previous row.
        dnew, dold = dold, dnew

    return int(dold[-1])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment