Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
6bfd83b4
Unverified
Commit
6bfd83b4
authored
Jun 25, 2021
by
yangarbiter
Committed by
GitHub
Jun 25, 2021
Browse files
Add edit_distance
parent
bac32ec1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
89 additions
and
41 deletions
+89
-41
examples/pipeline_wav2letter/main.py
examples/pipeline_wav2letter/main.py
+3
-3
examples/pipeline_wav2letter/metrics.py
examples/pipeline_wav2letter/metrics.py
+0
-38
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+40
-0
torchaudio/functional/__init__.py
torchaudio/functional/__init__.py
+2
-0
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+44
-0
No files found.
examples/pipeline_wav2letter/main.py
View file @
6bfd83b4
...
...
@@ -11,12 +11,12 @@ from torch.optim import SGD, Adadelta, Adam, AdamW
from
torch.optim.lr_scheduler
import
ExponentialLR
,
ReduceLROnPlateau
from
torch.utils.data
import
DataLoader
from
torchaudio.datasets.utils
import
bg_iterator
from
torchaudio.functional
import
edit_distance
from
torchaudio.models.wav2letter
import
Wav2Letter
from
ctc_decoders
import
GreedyDecoder
from
datasets
import
collate_factory
,
split_process_librispeech
from
languagemodels
import
LanguageModel
from
metrics
import
levenshtein_distance
from
transforms
import
Normalize
,
UnsqueezeFirst
from
utils
import
MetricLogger
,
count_parameters
,
save_checkpoint
...
...
@@ -217,7 +217,7 @@ def compute_error_rates(outputs, targets, decoder, language_model, metric):
target_print
=
target
[
i
].
ljust
(
print_length
)[:
print_length
]
logging
.
info
(
"Target: %s Output: %s"
,
target_print
,
output_print
)
cers
=
[
levenshtein
_distance
(
t
,
o
)
for
t
,
o
in
zip
(
target
,
output
)]
cers
=
[
edit
_distance
(
t
,
o
)
for
t
,
o
in
zip
(
target
,
output
)]
cers
=
sum
(
cers
)
n
=
sum
(
len
(
t
)
for
t
in
target
)
metric
[
"batch char error"
]
=
cers
...
...
@@ -232,7 +232,7 @@ def compute_error_rates(outputs, targets, decoder, language_model, metric):
output
=
[
o
.
split
(
language_model
.
char_space
)
for
o
in
output
]
target
=
[
t
.
split
(
language_model
.
char_space
)
for
t
in
target
]
wers
=
[
levenshtein
_distance
(
t
,
o
)
for
t
,
o
in
zip
(
target
,
output
)]
wers
=
[
edit
_distance
(
t
,
o
)
for
t
,
o
in
zip
(
target
,
output
)]
wers
=
sum
(
wers
)
n
=
sum
(
len
(
t
)
for
t
in
target
)
metric
[
"batch word error"
]
=
wers
...
...
examples/pipeline_wav2letter/metrics.py
deleted
100644 → 0
View file @
bac32ec1
from
typing
import
List
,
Union
def
levenshtein_distance
(
r
:
Union
[
str
,
List
[
str
]],
h
:
Union
[
str
,
List
[
str
]]):
"""
Calculate the Levenshtein distance between two lists or strings.
"""
# Initialisation
dold
=
list
(
range
(
len
(
h
)
+
1
))
dnew
=
list
(
0
for
_
in
range
(
len
(
h
)
+
1
))
# Computation
for
i
in
range
(
1
,
len
(
r
)
+
1
):
dnew
[
0
]
=
i
for
j
in
range
(
1
,
len
(
h
)
+
1
):
if
r
[
i
-
1
]
==
h
[
j
-
1
]:
dnew
[
j
]
=
dold
[
j
-
1
]
else
:
substitution
=
dold
[
j
-
1
]
+
1
insertion
=
dnew
[
j
-
1
]
+
1
deletion
=
dold
[
j
]
+
1
dnew
[
j
]
=
min
(
substitution
,
insertion
,
deletion
)
dnew
,
dold
=
dold
,
dnew
return
dold
[
-
1
]
if
__name__
==
"__main__"
:
assert
levenshtein_distance
(
"abc"
,
"abc"
)
==
0
assert
levenshtein_distance
(
"aaa"
,
"aba"
)
==
1
assert
levenshtein_distance
(
"aba"
,
"aaa"
)
==
1
assert
levenshtein_distance
(
"aa"
,
"aaa"
)
==
1
assert
levenshtein_distance
(
"aaa"
,
"aa"
)
==
1
assert
levenshtein_distance
(
"abc"
,
"bcd"
)
==
2
assert
levenshtein_distance
([
"hello"
,
"world"
],
[
"hello"
,
"world"
,
"!"
])
==
1
assert
levenshtein_distance
([
"hello"
,
"world"
],
[
"world"
,
"hello"
,
"!"
])
==
2
test/torchaudio_unittest/functional/functional_impl.py
View file @
6bfd83b4
...
...
@@ -382,6 +382,46 @@ class Functional(TestBaseMixin):
output_shape
=
(
torch
.
view_as_complex
(
spec_stretch
)
if
test_pseudo_complex
else
spec_stretch
).
shape
assert
output_shape
==
expected_shape
@
parameterized
.
expand
(
[
# words
[
""
,
""
,
0
],
# equal
[
"abc"
,
"abc"
,
0
],
[
"ᑌᑎIᑕO"
,
"ᑌᑎIᑕO"
,
0
],
[
"abc"
,
""
,
3
],
# deletion
[
"aa"
,
"aaa"
,
1
],
[
"aaa"
,
"aa"
,
1
],
[
"ᑌᑎI"
,
"ᑌᑎIᑕO"
,
2
],
[
"aaa"
,
"aba"
,
1
],
# substitution
[
"aba"
,
"aaa"
,
1
],
[
"aba"
,
" "
,
3
],
[
"abc"
,
"bcd"
,
2
],
# mix deletion and substitution
[
"0ᑌᑎI"
,
"ᑌᑎIᑕO"
,
3
],
# sentences
[[
"hello"
,
""
,
"Tᕮ᙭T"
],
[
"hello"
,
""
,
"Tᕮ᙭T"
],
0
],
# equal
[[],
[],
0
],
[[
"hello"
,
"world"
],
[
"hello"
,
"world"
,
"!"
],
1
],
# deletion
[[
"hello"
,
"world"
],
[
"world"
],
1
],
[[
"hello"
,
"world"
],
[],
2
],
[[
"Tᕮ᙭T"
,
],
[
"world"
],
1
],
# substitution
[[
"Tᕮ᙭T"
,
"XD"
],
[
"world"
,
"hello"
],
2
],
[[
""
,
"XD"
],
[
"world"
,
""
],
2
],
[
"aba"
,
" "
,
3
],
[[
"hello"
,
"world"
],
[
"world"
,
"hello"
,
"!"
],
2
],
# mix deletion and substitution
[[
"Tᕮ᙭T"
,
"world"
,
"LOL"
,
"XD"
],
[
"world"
,
"hello"
,
"ʕ•́ᴥ•̀ʔっ"
],
3
],
]
)
def
test_simple_case_edit_distance
(
self
,
seq1
,
seq2
,
distance
):
assert
F
.
edit_distance
(
seq1
,
seq2
)
==
distance
assert
F
.
edit_distance
(
seq2
,
seq1
)
==
distance
class
FunctionalCPUOnly
(
TestBaseMixin
):
def
test_create_fb_matrix_no_warning_high_n_freq
(
self
):
...
...
torchaudio/functional/__init__.py
View file @
6bfd83b4
...
...
@@ -20,6 +20,7 @@ from .functional import (
spectral_centroid
,
apply_codec
,
resample
,
edit_distance
,
)
from
.filtering
import
(
allpass_biquad
,
...
...
@@ -88,4 +89,5 @@ __all__ = [
'vad'
,
'apply_codec'
,
'resample'
,
'edit_distance'
,
]
torchaudio/functional/functional.py
View file @
6bfd83b4
# -*- coding: utf-8 -*-
from
collections.abc
import
Sequence
import
io
import
math
import
warnings
...
...
@@ -34,6 +35,7 @@ __all__ = [
"spectral_centroid"
,
"apply_codec"
,
"resample"
,
"edit_distance"
,
]
...
...
@@ -1444,3 +1446,45 @@ def resample(
resampling_method
,
beta
,
waveform
.
device
,
waveform
.
dtype
)
resampled
=
_apply_sinc_resample_kernel
(
waveform
,
orig_freq
,
new_freq
,
gcd
,
kernel
,
width
)
return
resampled
@
torch
.
jit
.
unused
def
edit_distance
(
seq1
:
Sequence
,
seq2
:
Sequence
)
->
int
:
"""
Calculate the word level edit (Levenshtein) distance between two sequences.
The function computes an edit distance allowing deletion, insertion and
substitution. The result is an integer.
For most applications, the two input sequences should be the same type. If
two strings are given, the output is the edit distance between the two
strings (character edit distance). If two lists of strings are given, the
output is the edit distance between sentences (word edit distance). Users
may want to normalize the output by the length of the reference sequence.
torchscipt is not supported for this function.
Args:
seq1 (Sequence): the first sequence to compare.
seq2 (Sequence): the second sequence to compare.
Returns:
int: The distance between the first and second sequences.
"""
len_sent2
=
len
(
seq2
)
dold
=
list
(
range
(
len_sent2
+
1
))
dnew
=
[
0
for
_
in
range
(
len_sent2
+
1
)]
for
i
in
range
(
1
,
len
(
seq1
)
+
1
):
dnew
[
0
]
=
i
for
j
in
range
(
1
,
len_sent2
+
1
):
if
seq1
[
i
-
1
]
==
seq2
[
j
-
1
]:
dnew
[
j
]
=
dold
[
j
-
1
]
else
:
substitution
=
dold
[
j
-
1
]
+
1
insertion
=
dnew
[
j
-
1
]
+
1
deletion
=
dold
[
j
]
+
1
dnew
[
j
]
=
min
(
substitution
,
insertion
,
deletion
)
dnew
,
dold
=
dold
,
dnew
return
int
(
dold
[
-
1
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment