Unverified commit 6bfd83b4, authored by yangarbiter and committed by GitHub.
Browse files

Add edit_distance

parent bac32ec1
......@@ -11,12 +11,12 @@ from torch.optim import SGD, Adadelta, Adam, AdamW
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.functional import edit_distance
from torchaudio.models.wav2letter import Wav2Letter
from ctc_decoders import GreedyDecoder
from datasets import collate_factory, split_process_librispeech
from languagemodels import LanguageModel
from metrics import levenshtein_distance
from transforms import Normalize, UnsqueezeFirst
from utils import MetricLogger, count_parameters, save_checkpoint
......@@ -217,7 +217,7 @@ def compute_error_rates(outputs, targets, decoder, language_model, metric):
target_print = target[i].ljust(print_length)[:print_length]
logging.info("Target: %s Output: %s", target_print, output_print)
cers = [levenshtein_distance(t, o) for t, o in zip(target, output)]
cers = [edit_distance(t, o) for t, o in zip(target, output)]
cers = sum(cers)
n = sum(len(t) for t in target)
metric["batch char error"] = cers
......@@ -232,7 +232,7 @@ def compute_error_rates(outputs, targets, decoder, language_model, metric):
output = [o.split(language_model.char_space) for o in output]
target = [t.split(language_model.char_space) for t in target]
wers = [levenshtein_distance(t, o) for t, o in zip(target, output)]
wers = [edit_distance(t, o) for t, o in zip(target, output)]
wers = sum(wers)
n = sum(len(t) for t in target)
metric["batch word error"] = wers
......
from typing import List, Union
def levenshtein_distance(r: Union[str, List[str]], h: Union[str, List[str]]):
    """
    Calculate the Levenshtein distance between two lists or strings.

    Uses the classic dynamic-programming recurrence, keeping only two rows
    of the DP table (O(len(h)) memory instead of the full matrix).
    """
    # prev_row holds distances for the previous prefix of r; curr_row is
    # the row being filled in. Row 0 is the cost of building h from "".
    prev_row = list(range(len(h) + 1))
    curr_row = [0] * (len(h) + 1)

    for i in range(1, len(r) + 1):
        curr_row[0] = i  # cost of deleting the first i elements of r
        for j in range(1, len(h) + 1):
            if r[i - 1] == h[j - 1]:
                # Matching elements cost nothing extra.
                curr_row[j] = prev_row[j - 1]
            else:
                # Cheapest of substitution, insertion, and deletion.
                curr_row[j] = 1 + min(prev_row[j - 1], curr_row[j - 1], prev_row[j])
        # Reuse the old buffer for the next iteration.
        prev_row, curr_row = curr_row, prev_row

    return prev_row[-1]
if __name__ == "__main__":
    # Smoke tests: identical inputs, single substitution, single
    # insertion/deletion, a mixed edit, and word-level (list) inputs.
    cases = [
        ("abc", "abc", 0),
        ("aaa", "aba", 1),
        ("aba", "aaa", 1),
        ("aa", "aaa", 1),
        ("aaa", "aa", 1),
        ("abc", "bcd", 2),
        (["hello", "world"], ["hello", "world", "!"], 1),
        (["hello", "world"], ["world", "hello", "!"], 2),
    ]
    for first, second, expected in cases:
        assert levenshtein_distance(first, second) == expected
......@@ -382,6 +382,46 @@ class Functional(TestBaseMixin):
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
assert output_shape == expected_shape
@parameterized.expand(
    [
        # words: character-level edit distance on plain strings
        # (non-ASCII strings included to exercise Unicode handling)
        ["", "", 0],  # equal
        ["abc", "abc", 0],
        ["ᑌᑎIᑕO", "ᑌᑎIᑕO", 0],
        ["abc", "", 3],  # deletion
        ["aa", "aaa", 1],
        ["aaa", "aa", 1],
        ["ᑌᑎI", "ᑌᑎIᑕO", 2],
        ["aaa", "aba", 1],  # substitution
        ["aba", "aaa", 1],
        ["aba", " ", 3],
        ["abc", "bcd", 2],  # mix deletion and substitution
        ["0ᑌᑎI", "ᑌᑎIᑕO", 3],
        # sentences: word-level edit distance on lists of strings
        [["hello", "", "Tᕮ᙭T"], ["hello", "", "Tᕮ᙭T"], 0],  # equal
        [[], [], 0],
        [["hello", "world"], ["hello", "world", "!"], 1],  # deletion
        [["hello", "world"], ["world"], 1],
        [["hello", "world"], [], 2],
        [["Tᕮ᙭T", ], ["world"], 1],  # substitution
        [["Tᕮ᙭T", "XD"], ["world", "hello"], 2],
        [["", "XD"], ["world", ""], 2],
        ["aba", " ", 3],
        [["hello", "world"], ["world", "hello", "!"], 2],  # mix deletion and substitution
        [["Tᕮ᙭T", "world", "LOL", "XD"], ["world", "hello", "ʕ•́ᴥ•̀ʔっ"], 3],
    ]
)
def test_simple_case_edit_distance(self, seq1, seq2, distance):
    """Check F.edit_distance against known distances, and that it is symmetric."""
    assert F.edit_distance(seq1, seq2) == distance
    # Levenshtein distance is a metric, so swapping the arguments must
    # yield the same result.
    assert F.edit_distance(seq2, seq1) == distance
class FunctionalCPUOnly(TestBaseMixin):
def test_create_fb_matrix_no_warning_high_n_freq(self):
......
......@@ -20,6 +20,7 @@ from .functional import (
spectral_centroid,
apply_codec,
resample,
edit_distance,
)
from .filtering import (
allpass_biquad,
......@@ -88,4 +89,5 @@ __all__ = [
'vad',
'apply_codec',
'resample',
'edit_distance',
]
# -*- coding: utf-8 -*-
from collections.abc import Sequence
import io
import math
import warnings
......@@ -34,6 +35,7 @@ __all__ = [
"spectral_centroid",
"apply_codec",
"resample",
"edit_distance",
]
......@@ -1444,3 +1446,45 @@ def resample(
resampling_method, beta, waveform.device, waveform.dtype)
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
return resampled
@torch.jit.unused
def edit_distance(seq1: Sequence, seq2: Sequence) -> int:
    """
    Calculate the word level edit (Levenshtein) distance between two sequences.

    The function computes an edit distance allowing deletion, insertion and
    substitution. The result is an integer.

    For most applications, the two input sequences should be the same type. If
    two strings are given, the output is the edit distance between the two
    strings (character edit distance). If two lists of strings are given, the
    output is the edit distance between sentences (word edit distance). Users
    may want to normalize the output by the length of the reference sequence.

    TorchScript is not supported for this function.

    Args:
        seq1 (Sequence): the first sequence to compare.
        seq2 (Sequence): the second sequence to compare.
    Returns:
        int: The distance between the first and second sequences.
    """
    width = len(seq2)
    # Two-row dynamic programming: prev_row is the distances for the previous
    # prefix of seq1, curr_row is being filled. Row 0 is the cost of turning
    # the empty sequence into each prefix of seq2 (pure insertions).
    prev_row = list(range(width + 1))
    curr_row = [0] * (width + 1)

    for row, token in enumerate(seq1, start=1):
        curr_row[0] = row  # cost of deleting the first `row` items of seq1
        for col in range(1, width + 1):
            if token == seq2[col - 1]:
                # Matching items carry the diagonal value forward unchanged.
                curr_row[col] = prev_row[col - 1]
            else:
                # One plus the cheapest of substitution, insertion, deletion.
                curr_row[col] = min(prev_row[col - 1], curr_row[col - 1], prev_row[col]) + 1
        # Swap buffers instead of reallocating each row.
        prev_row, curr_row = curr_row, prev_row

    return int(prev_row[-1])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment