Unverified Commit e14a2e0c authored by yangarbiter's avatar yangarbiter Committed by GitHub
Browse files

Refactor text preprocessing tests in Tacotron2 example (#1641)

parent ec3ab990
......@@ -56,7 +56,7 @@ fi
(
set -x
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest
pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect
)
# Install fairseq
git clone https://github.com/pytorch/fairseq
......
......@@ -44,7 +44,7 @@ fi
(
set -x
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest
pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect
)
# Install fairseq
git clone https://github.com/pytorch/fairseq
......
......@@ -87,6 +87,8 @@ Optional packages to install if you want to run related tests:
- `transformers`
- `fairseq` (it has to be newer than `0.10.2`, so you will need to install from
source. Commit `e6eddd80` is known to work.)
- `unidecode` (dependency for testing text preprocessing functions for examples/pipeline_tacotron2)
- `inflect` (dependency for testing text preprocessing functions for examples/pipeline_tacotron2)
## Development Process
......
......@@ -30,28 +30,36 @@ import re
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m: re.Match) -> str:
return m.group(1).replace(',', '')
def _remove_commas(text: str) -> str:
return re.sub(_comma_number_re, lambda m: m.group(1).replace(',', ''), text)
def _expand_decimal_point(m: re.Match) -> str:
return m.group(1).replace('.', ' point ')
def _expand_pounds(text: str) -> str:
return re.sub(_pounds_re, r'\1 pounds', text)
def _expand_dollars(m: re.Match) -> str:
def _expand_dollars_repl_fn(m):
"""The replacement function for expanding dollars."""
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if len(parts) > 1 and parts[1]:
if len(parts[1]) == 1:
# handle the case where we have one digit after the decimal point
cents = int(parts[1]) * 10
else:
cents = int(parts[1])
else:
cents = 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
......@@ -66,11 +74,20 @@ def _expand_dollars(m: re.Match) -> str:
return 'zero dollars'
def _expand_ordinal(m: re.Match) -> str:
return _inflect.number_to_words(m.group(0))
def _expand_dollars(text: str) -> str:
return re.sub(_dollars_re, _expand_dollars_repl_fn, text)
def _expand_decimal_point(text: str) -> str:
return re.sub(_decimal_number_re, lambda m: m.group(1).replace('.', ' point '), text)
def _expand_number(m: re.Match) -> str:
def _expand_ordinal(text: str) -> str:
return re.sub(_ordinal_re, lambda m: _inflect.number_to_words(m.group(0)), text)
def _expand_number_repl_fn(m):
"""The replacement function for expanding number."""
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
......@@ -85,11 +102,15 @@ def _expand_number(m: re.Match) -> str:
return _inflect.number_to_words(num, andword='')
def _expand_number(text: str) -> str:
return re.sub(_number_re, _expand_number_repl_fn, text)
def normalize_numbers(text: str) -> str:
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
text = _remove_commas(text)
text = _expand_pounds(text)
text = _expand_dollars(text)
text = _expand_decimal_point(text)
text = _expand_ordinal(text)
text = _expand_number(text)
return text
import unittest
from parameterized import parameterized
from .text_preprocessing import text_to_sequence
class TestTextPreprocessor(unittest.TestCase):
@parameterized.expand(
[
["dr. Strange?", [15, 26, 14, 31, 26, 29, 11, 30, 31, 29, 12, 25, 18, 16, 10]],
["ML, is fun.", [24, 23, 6, 11, 20, 30, 11, 17, 32, 25, 7]],
["I love torchaudio!", [20, 11, 23, 26, 33, 16, 11, 31, 26, 29, 14, 19, 12, 32, 15, 20, 26, 2]],
# 'one thousand dollars, twenty cents'
["$1,000.20", [26, 25, 16, 11, 31, 19, 26, 32, 30, 12, 25, 15, 11, 15, 26, 23, 23,
12, 29, 30, 6, 11, 31, 34, 16, 25, 31, 36, 11, 14, 16, 25, 31, 30]],
]
)
def test_text_to_sequence(self, sent, seq):
assert (text_to_sequence(sent) == seq)
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule
if is_module_available("unidecode") and is_module_available("inflect"):
from pipeline_tacotron2.text.text_preprocessing import text_to_sequence
from pipeline_tacotron2.text.numbers import (
_remove_commas,
_expand_pounds,
_expand_dollars,
_expand_decimal_point,
_expand_ordinal,
_expand_number,
)
@skipIfNoModule("unidecode")
@skipIfNoModule("inflect")
class TestTextPreprocessor(TorchaudioTestCase):
@parameterized.expand(
[
["dr. Strange?", [15, 26, 14, 31, 26, 29, 11, 30, 31, 29, 12, 25, 18, 16, 10]],
["ML, is fun.", [24, 23, 6, 11, 20, 30, 11, 17, 32, 25, 7]],
["I love torchaudio!", [20, 11, 23, 26, 33, 16, 11, 31, 26, 29, 14, 19, 12, 32, 15, 20, 26, 2]],
# 'one thousand dollars, twenty cents'
["$1,000.20", [26, 25, 16, 11, 31, 19, 26, 32, 30, 12, 25, 15, 11, 15, 26, 23, 23,
12, 29, 30, 6, 11, 31, 34, 16, 25, 31, 36, 11, 14, 16, 25, 31, 30]],
]
)
def test_text_to_sequence(self, sent, seq):
assert (text_to_sequence(sent) == seq)
@parameterized.expand(
[
["He, she, and I have $1,000", "He, she, and I have $1000"],
]
)
def test_remove_commas(self, sent, truth):
assert (_remove_commas(sent) == truth)
@parameterized.expand(
[
["He, she, and I have £1000", "He, she, and I have 1000 pounds"],
]
)
def test_expand_pounds(self, sent, truth):
assert (_expand_pounds(sent) == truth)
@parameterized.expand(
[
["He, she, and I have $1000", "He, she, and I have 1000 dollars"],
["He, she, and I have $3000.01", "He, she, and I have 3000 dollars, 1 cent"],
["He has $500.20 and she has $1000.50.",
"He has 500 dollars, 20 cents and she has 1000 dollars, 50 cents."],
]
)
def test_expand_dollars(self, sent, truth):
assert (_expand_dollars(sent) == truth)
@parameterized.expand(
[
["1000.20", "1000 point 20"],
["1000.1", "1000 point 1"],
]
)
def test_expand_decimal_point(self, sent, truth):
assert (_expand_decimal_point(sent) == truth)
@parameterized.expand(
[
["21st centry", "twenty-first centry"],
["20th centry", "twentieth centry"],
["2nd place.", "second place."],
]
)
def test_expand_ordinal(self, sent, truth):
assert (_expand_ordinal(sent) == truth)
_expand_ordinal,
@parameterized.expand(
[
["100020 dollars.", "one hundred thousand twenty dollars."],
["1234567890!", "one billion, two hundred thirty-four million, "
"five hundred sixty-seven thousand, eight hundred ninety!"],
]
)
def test_expand_number(self, sent, truth):
assert (_expand_number(sent) == truth)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment