Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
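The changes below are mechanical: import sorting plus black-style reformatting (double quotes, ~120-character lines, trailing commas). For contributors without access to `arc`, a rough stand-in for the run above — assuming the lint configuration maps to the open-source black and usort formatters, which this commit does not state — would be something like:

    # Hypothetical equivalent of `arc lint --apply-patches`; the tools and the
    # 120-column limit are inferred from the resulting diff, not from the commit.
    pip install black usort
    usort format .
    black --line-length 120 .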
@@ -14,22 +14,25 @@ See this comment for design rationale:
https://github.com/pytorch/vision/pull/1321#issuecomment-531033978
"""
import os.path

import jinja2
import yaml
from jinja2 import select_autoescape

PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]

CU_VERSIONS_DICT = {
    "linux": ["cpu", "cu102", "cu111", "cu113", "cu115", "rocm4.1"],
    "windows": ["cpu", "cu113", "cu115"],
    "macos": ["cpu"],
}

DOC_VERSION = ("linux", "3.8")


def build_workflows(prefix="", upload=False, filter_branch=None, indentation=6):
    w = []
    w += build_download_job(filter_branch)
    for btype in ["wheel", "conda"]:
@@ -37,23 +40,21 @@ def build_workflows(prefix='', upload=False, filter_branch=None, indentation=6):
        for python_version in PYTHON_VERSIONS:
            for cu_version in CU_VERSIONS_DICT[os_type]:
                fb = filter_branch
                if cu_version.startswith("rocm") and btype == "conda":
                    continue
                if not fb and (
                    os_type == "linux" and btype == "wheel" and python_version == "3.8" and cu_version == "cpu"
                ):
                    # the fields must match the build_docs "requires" dependency
                    fb = "/.*/"
                w += build_workflow_pair(btype, os_type, python_version, cu_version, fb, prefix, upload)

    if not filter_branch:
        # Build on every pull request, but upload only on nightly and tags
        w += build_doc_job("/.*/")
        w += upload_doc_job("nightly")
        w += docstring_parameters_sync_job(None)

    return indent(indentation, w)
@@ -67,7 +68,7 @@ def build_download_job(filter_branch):
    return [{"download_third_parties_nix": job}]


def build_workflow_pair(btype, os_type, python_version, cu_version, filter_branch, prefix="", upload=False):
    w = []
    base_workflow_name = f"{prefix}binary_{os_type}_{btype}_py{python_version}_{cu_version}"
@@ -77,9 +78,13 @@ def build_workflow_pair(btype, os_type, python_version, cu_version, filter_branc
        w.append(generate_upload_workflow(base_workflow_name, filter_branch, os_type, btype, cu_version))

        if filter_branch == "nightly" and os_type != "macos":
            pydistro = "pip" if btype == "wheel" else "conda"
            w.append(
                generate_smoketest_workflow(
                    pydistro, base_workflow_name, filter_branch, python_version, cu_version, os_type
                )
            )

    return w
@@ -88,7 +93,9 @@ def build_doc_job(filter_branch):
    job = {
        "name": "build_docs",
        "python_version": "3.8",
        "requires": [
            "binary_linux_wheel_py3.8_cpu",
        ],
    }

    if filter_branch:
@@ -101,7 +108,9 @@ def upload_doc_job(filter_branch):
        "name": "upload_docs",
        "context": "org-member",
        "python_version": "3.8",
        "requires": [
            "build_docs",
        ],
    }

    if filter_branch:
@@ -113,7 +122,9 @@ def docstring_parameters_sync_job(filter_branch):
    job = {
        "name": "docstring_parameters_sync",
        "python_version": "3.8",
        "requires": [
            "binary_linux_wheel_py3.8_cpu",
        ],
    }

    if filter_branch:
@@ -129,13 +140,13 @@ def generate_base_workflow(base_workflow_name, python_version, cu_version, filte
        "cuda_version": cu_version,
    }

    if os_type in ["linux", "macos"]:
        d["requires"] = ["download_third_parties_nix"]

    if btype == "conda":
        d["conda_docker_image"] = f'pytorch/conda-builder:{cu_version.replace("cu1","cuda1")}'
    elif cu_version.startswith("cu"):
        d["wheel_docker_image"] = f'pytorch/manylinux-{cu_version.replace("cu1","cuda1")}'
    elif cu_version.startswith("rocm"):
        d["wheel_docker_image"] = f"pytorch/manylinux-rocm:{cu_version[len('rocm'):]}"

    if filter_branch:
@@ -153,7 +164,7 @@ def gen_filter_branch_tree(*branches):
            # Using a raw string here to avoid having to escape
            # anything
            "only": r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/"
        },
    }
@@ -164,9 +175,8 @@ def generate_upload_workflow(base_workflow_name, filter_branch, os_type, btype,
        "requires": [base_workflow_name],
    }

    if btype == "wheel":
        d["subfolder"] = "" if os_type == "macos" else cu_version + "/"

    if filter_branch:
        d["filters"] = gen_filter_branch_tree(filter_branch)
@@ -212,22 +222,24 @@ def unittest_workflows(indentation=6):
                job = {
                    "name": f"unittest_{os_type}_{device_type}_py{python_version}",
                    "python_version": python_version,
                    "cuda_version": "cpu" if device_type == "cpu" else "cu113",
                }

                if os_type != "windows":
                    job["requires"] = ["download_third_parties_nix"]

                jobs.append({f"unittest_{os_type}_{device_type}": job})

                if i == 0 and os_type == "linux" and device_type == "cpu":
                    jobs.append(
                        {
                            "stylecheck": {
                                "name": f"stylecheck_py{python_version}",
                                "python_version": python_version,
                                "cuda_version": "cpu",
                            }
                        }
                    )
    return indent(indentation, jobs)
@@ -236,12 +248,14 @@ if __name__ == "__main__":
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(d),
        lstrip_blocks=True,
        autoescape=select_autoescape(enabled_extensions=("html", "xml")),
    )

    with open(os.path.join(d, "config.yml"), "w") as f:
        f.write(
            env.get_template("config.yml.in").render(
                build_workflows=build_workflows,
                unittest_workflows=unittest_workflows,
            )
        )
        f.write("\n")
@@ -19,7 +19,6 @@ import signal
import subprocess
import sys
import traceback
from functools import partial

try:
@@ -28,7 +27,7 @@ except ImportError:
    DEVNULL = open(os.devnull, "wb")


DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"


class ExitStatus:
@@ -52,14 +51,8 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
                    # os.walk() supports trimming down the dnames list
                    # by modifying it in-place,
                    # to avoid unnecessary directory listings.
                    dnames[:] = [x for x in dnames if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)]
                    fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)]
                for f in fpaths:
                    ext = os.path.splitext(f)[1][1:]
                    if ext in extensions:
@@ -72,11 +65,9 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
def make_diff(file, original, reformatted):
    return list(
        difflib.unified_diff(
            original, reformatted, fromfile="{}\t(original)".format(file), tofile="{}\t(reformatted)".format(file), n=3
        )
    )


class DiffError(Exception):
@@ -99,13 +90,12 @@ def run_clang_format_diff_wrapper(args, file):
    except DiffError:
        raise
    except Exception as e:
        raise UnexpectedError("{}: {}: {}".format(file, e.__class__.__name__, e), e)


def run_clang_format_diff(args, file):
    try:
        with io.open(file, "r", encoding="utf-8") as f:
            original = f.readlines()
    except IOError as exc:
        raise DiffError(str(exc))
@@ -130,17 +120,10 @@ def run_clang_format_diff(args, file):
    try:
        proc = subprocess.Popen(
            invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, encoding="utf-8"
        )
    except OSError as exc:
        raise DiffError("Command '{}' failed to start: {}".format(subprocess.list2cmdline(invocation), exc))
    proc_stdout = proc.stdout
    proc_stderr = proc.stderr
@@ -159,30 +142,30 @@ def run_clang_format_diff(args, file):
def bold_red(s):
    return "\x1b[1m\x1b[31m" + s + "\x1b[0m"


def colorize(diff_lines):
    def bold(s):
        return "\x1b[1m" + s + "\x1b[0m"

    def cyan(s):
        return "\x1b[36m" + s + "\x1b[0m"

    def green(s):
        return "\x1b[32m" + s + "\x1b[0m"

    def red(s):
        return "\x1b[31m" + s + "\x1b[0m"

    for line in diff_lines:
        if line[:4] in ["--- ", "+++ "]:
            yield bold(line)
        elif line.startswith("@@ "):
            yield cyan(line)
        elif line.startswith("+"):
            yield green(line)
        elif line.startswith("-"):
            yield red(line)
        else:
            yield line
@@ -195,7 +178,7 @@ def print_diff(diff_lines, use_color):
def print_trouble(prog, message, use_colors):
    error_text = "error:"
    if use_colors:
        error_text = bold_red(error_text)
    print("{}: {} {}".format(prog, error_text, message), file=sys.stderr)
@@ -204,45 +187,37 @@ def print_trouble(prog, message, use_colors):
def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--clang-format-executable",
        metavar="EXECUTABLE",
        help="path to the clang-format executable",
        default="clang-format",
    )
    parser.add_argument(
        "--extensions",
        help="comma separated list of file extensions (default: {})".format(DEFAULT_EXTENSIONS),
        default=DEFAULT_EXTENSIONS,
    )
    parser.add_argument("-r", "--recursive", action="store_true", help="run recursively over directories")
    parser.add_argument("files", metavar="file", nargs="+")
    parser.add_argument("-q", "--quiet", action="store_true")
    parser.add_argument(
        "-j",
        metavar="N",
        type=int,
        default=0,
        help="run N clang-format jobs in parallel" " (default number of cpus + 1)",
    )
    parser.add_argument(
        "--color", default="auto", choices=["auto", "always", "never"], help="show colored diff (default: auto)"
    )
    parser.add_argument(
        "-e",
        "--exclude",
        metavar="PATTERN",
        action="append",
        default=[],
        help="exclude paths matching the given glob-like pattern(s)" " from recursive search",
    )

    args = parser.parse_args()
@@ -259,10 +234,10 @@ def main():
    colored_stdout = False
    colored_stderr = False
    if args.color == "always":
        colored_stdout = True
        colored_stderr = True
    elif args.color == "auto":
        colored_stdout = sys.stdout.isatty()
        colored_stderr = sys.stderr.isatty()
@@ -275,19 +250,15 @@ def main():
    except OSError as e:
        print_trouble(
            parser.prog,
            "Command '{}' failed to start: {}".format(subprocess.list2cmdline(version_invocation), e),
            use_colors=colored_stderr,
        )
        return ExitStatus.TROUBLE

    retcode = ExitStatus.SUCCESS
    files = list_files(
        args.files, recursive=args.recursive, exclude=args.exclude, extensions=args.extensions.split(",")
    )

    if not files:
        return
@@ -304,8 +275,7 @@ def main():
        pool = None
    else:
        pool = multiprocessing.Pool(njobs)
        it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files)
    while True:
        try:
            outs, errs = next(it)
@@ -336,5 +306,5 @@ def main():
    return retcode


if __name__ == "__main__":
    sys.exit(main())
@@ -45,13 +45,13 @@ def query_torchaudio(cmd: str, *, accept) -> Any:
def get_pr_merger_and_number(commit_hash: str) -> Optional[str]:
    data = query_torchaudio(f"commits/{commit_hash}", accept="application/vnd.github.v3+json")

    commit_message = data["commit"]["message"]
    pulled_by = commit_message.split("Pulled By: ")
    pulled_by = pulled_by[1].split("\n")[0] if len(pulled_by) > 1 else None
    pr_number = commit_message.split("Pull Request resolved: https://github.com/pytorch/audio/pull/")
    pr_number = pr_number[1].split("\n")[0] if len(pr_number) > 1 else None
    return pulled_by, pr_number

...
@@ -22,97 +22,98 @@
# sys.path.insert(0, os.path.abspath('.'))
import os
import re

import pytorch_sphinx_theme

# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
needs_sphinx = "1.6"

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.doctest",
    "sphinx.ext.intersphinx",
    "sphinx.ext.todo",
    "sphinx.ext.coverage",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinxcontrib.katex",
    "sphinxcontrib.bibtex",
    "sphinx_gallery.gen_gallery",
]

# katex options
#
#
katex_options = r"""
delimiters : [
    {left: "$$", right: "$$", display: true},
    {left: "\\(", right: "\\)", display: false},
    {left: "\\[", right: "\\]", display: true}
]
"""

bibtex_bibfiles = ["refs.bib"]


def _get_var(var, default=False):
    if var not in os.environ:
        return default

    val = os.environ.get(var, "0")
    trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"]
    falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"]
    if val in trues:
        return True
    if val not in falses:
        print(
            f" --- WARNING: Unexpected environment variable value `{var}={val}`. " f"Expected one of {trues + falses}"
        )
    return False


def _get_pattern():
    pattern = os.getenv("GALLERY_PATTERN")
    # If BUILD_GALLERY is falsy -> no build
    # If BUILD_GALLERY is truey -> build
    # If BUILD_GALLERY is undefined
    #    If GALLERY_PATTERN is defined -> build
    #    If GALLERY_PATTERN is not defined -> not build
    if not _get_var("BUILD_GALLERY", default=False if pattern is None else True):
        if pattern is not None:
            print(
                ' --- WARNING: "GALLERY_PATTERN" is provided, but "BUILD_GALLERY" value is falsy. '
                "Sphinx galleries are not built. To build galleries, set `BUILD_GALLERY=1`."
            )
        return {
            "ignore_pattern": r"\.py",
        }

    ret = {"filename_pattern": "tutorial.py"}
    if os.getenv("GALLERY_PATTERN"):
        # See https://github.com/pytorch/tutorials/blob/cbf2238df0e78d84c15bd94288966d2f4b2e83ae/conf.py#L75-L83
        ret["ignore_pattern"] = r"/(?!" + re.escape(os.getenv("GALLERY_PATTERN")) + r")[^/]+$"
    return ret


sphinx_gallery_conf = {
    "examples_dirs": [
        "../../examples/tutorials",
    ],
    "gallery_dirs": [
        "tutorials",
    ],
    **_get_pattern(),
    "backreferences_dir": "gen_modules/backreferences",
    "first_notebook_cell": None,
    "doc_module": ("torchaudio",),
}

autosummary_generate = True
@@ -121,21 +122,21 @@ napoleon_numpy_docstring = False
napoleon_google_docstring = True

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = ".rst"

# The master toctree document.
master_doc = "index"

# General information about the project.
project = "Torchaudio"
copyright = "2018, Torchaudio Contributors"
author = "Torchaudio Contributors"

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
@@ -143,10 +144,10 @@ author = 'Torchaudio Contributors'
#
# The short X.Y version.
# TODO: change to [:2] at v1.0
version = "main "
# The full version, including alpha/beta/rc tags.
# TODO: verify this works as expected
release = "main"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@@ -158,10 +159,10 @@ language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ["*/index.rst"]

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
@@ -172,7 +173,7 @@ todo_include_todos = True
# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = "pytorch_sphinx_theme"
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]

# Theme options are theme-specific and customize the look and feel of a theme
@@ -180,28 +181,26 @@ html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# documentation.
#
html_theme_options = {
    "pytorch_project": "audio",
    "collapse_navigation": False,
    "display_version": True,
    "logo_only": True,
    "navigation_with_keys": True,
    "analytics_id": "UA-117752657-2",
}

html_logo = "_static/img/pytorch-logo-dark.svg"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

html_css_files = ["https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css"]

# -- Options for HTMLHelp output ------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = "TorchAudiodoc"

# -- Options for LaTeX output ---------------------------------------------
@@ -210,15 +209,12 @@ latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',
    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',
    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',
    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
@@ -228,8 +224,7 @@ latex_elements = {
#  (source start file, target name, title,
#   author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, "pytorch.tex", "Torchaudio Documentation", "Torch Contributors", "manual"),
]
@@ -237,10 +232,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "Torchaudio", "Torchaudio Documentation", [author], 1)]


# -- Options for Texinfo output -------------------------------------------
@@ -249,25 +241,31 @@ man_pages = [
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (
        master_doc,
        "Torchaudio",
        "Torchaudio Documentation",
        author,
        "Torchaudio",
        "Load audio files into pytorch tensors.",
        "Miscellaneous",
    ),
]

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
    "python": ("https://docs.python.org/3/", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "torch": ("https://pytorch.org/docs/stable/", None),
}

# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043
from docutils import nodes
from sphinx import addnodes
from sphinx.util.docfields import TypedField


def patched_make_field(self, types, domain, items, **kw):
@@ -277,39 +275,39 @@ def patched_make_field(self, types, domain, items, **kw):
    # type: (list, str, tuple) -> nodes.field
    def handle_item(fieldarg, content):
        par = nodes.paragraph()
        par += addnodes.literal_strong("", fieldarg)  # Patch: this line added
        # par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
        #                            addnodes.literal_strong))
        if fieldarg in types:
            par += nodes.Text(" (")
            # NOTE: using .pop() here to prevent a single type node to be
            # inserted twice into the doctree, which leads to
            # inconsistencies later when references are resolved
            fieldtype = types.pop(fieldarg)
            if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
                typename = "".join(n.astext() for n in fieldtype)
                typename = typename.replace("int", "python:int")
                typename = typename.replace("long", "python:long")
                typename = typename.replace("float", "python:float")
                typename = typename.replace("type", "python:type")
                par.extend(self.make_xrefs(self.typerolename, domain, typename, addnodes.literal_emphasis, **kw))
            else:
                par += fieldtype
            par += nodes.Text(")")
        par += nodes.Text(" -- ")
        par += content
        return par

    fieldname = nodes.field_name("", self.label)
    if len(items) == 1 and self.can_collapse:
        fieldarg, content = items[0]
        bodynode = handle_item(fieldarg, content)
    else:
        bodynode = self.list_type()
        for fieldarg, content in items:
            bodynode += nodes.list_item("", handle_item(fieldarg, content))
    fieldbody = nodes.field_body("", bodynode)
    return nodes.field("", fieldname, fieldbody)


TypedField.make_field = patched_make_field
import logging
import pathlib
from argparse import ArgumentParser

import torch
import torchaudio
from lightning import RNNTModule
@@ -12,9 +11,7 @@ logger = logging.getLogger()
def compute_word_level_distance(seq1, seq2):
    return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())


def run_eval(args):
@@ -38,9 +35,7 @@ def run_eval(args):
            total_edit_distance += compute_word_level_distance(actual, predicted)
            total_length += len(actual.split())
            if idx % 100 == 0:
                logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
    logger.info(f"Final WER: {total_edit_distance / total_length}")
@@ -58,13 +53,20 @@ def cli_main():
        help="Path to JSON file containing feature means and stddevs.",
    )
    parser.add_argument(
        "--librispeech_path",
        type=pathlib.Path,
        help="Path to LibriSpeech datasets.",
    )
    parser.add_argument(
        "--sp_model_path",
        type=pathlib.Path,
        help="Path to SentencePiece model.",
    )
    parser.add_argument(
        "--use_cuda",
        action="store_true",
        default=False,
        help="Run using CUDA.",
    )
    args = parser.parse_args()
    run_eval(args)

...
import json
import math
import os
from collections import namedtuple
from typing import List, Tuple

import sentencepiece as spm
import torch
import torchaudio
import torchaudio.functional as F
from pytorch_lightning import LightningModule
from torchaudio.prototype.rnnt import emformer_rnnt_base
from torchaudio.prototype.rnnt_decoder import Hypothesis, RNNTBeamSearch

Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])

_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)

_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)


def _batch_by_token_count(idx_target_lengths, token_limit):
@@ -61,9 +56,7 @@ class CustomDataset(torch.utils.data.Dataset):
        assert len(idx_target_lengths) > 0

        idx_target_lengths = sorted(idx_target_lengths, key=lambda x: x[1], reverse=True)

        assert max_token_limit >= idx_target_lengths[0][1]
@@ -74,9 +67,7 @@ class CustomDataset(torch.utils.data.Dataset):
            speaker_id, chapter_id, _ = fileid.split("-")

            file_text = speaker_id + "-" + chapter_id + self.base_dataset._ext_txt
            file_text = os.path.join(self.base_dataset._path, speaker_id, chapter_id, file_text)

            with open(file_text) as ft:
                for line in ft:
@@ -93,24 +84,16 @@ class CustomDataset(torch.utils.data.Dataset):
class TimeMasking(torchaudio.transforms._AxisMasking):
    def __init__(self, time_mask_param: int, min_mask_p: float, iid_masks: bool = False) -> None:
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
        self.min_mask_p = min_mask_p

    def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor:
        if self.iid_masks and specgram.dim() == 4:
            mask_param = min(self.mask_param, self.min_mask_p * specgram.shape[self.axis + 1])
            return F.mask_along_axis_iid(specgram, mask_param, mask_value, self.axis + 1)
        else:
            mask_param = min(self.mask_param, self.min_mask_p * specgram.shape[self.axis])
            return F.mask_along_axis(specgram, mask_param, mask_value, self.axis)
@@ -149,10 +132,7 @@ class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs]


def post_process_hypos(
@@ -164,12 +144,7 @@ def post_process_hypos(
        sp_model.pad_id(),
    ]
    filtered_hypo_tokens = [
        [token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
    ]
    hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
    hypos_ali = [h.alignment[1:] for h in hypos]
@@ -193,12 +168,8 @@ class RNNTModule(LightningModule):
        self.model = emformer_rnnt_base()
        self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.96, patience=0)
        self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)

        self.train_data_pipeline = torch.nn.Sequential(
@@ -236,27 +207,17 @@ class RNNTModule(LightningModule):
        return targets, lengths

    def _train_extract_features(self, samples: List):
        mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
        features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
        features = self.train_data_pipeline(features)
        lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
        return features, lengths

    def _valid_extract_features(self, samples: List):
        mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
        features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
        features = self.valid_data_pipeline(features)
        lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
        return features, lengths

    def _train_collate_fn(self, samples: List):
@@ -276,9 +237,7 @@ class RNNTModule(LightningModule):
        if batch is None:
            return None
        prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
        prepended_targets[:, 1:] = batch.targets
        prepended_targets[:, 0] = self.blank_idx
        prepended_target_lengths = batch.target_lengths + 1
@@ -307,9 +266,7 @@ class RNNTModule(LightningModule):
    def forward(self, batch: Batch):
        decoder = RNNTBeamSearch(self.model, self.blank_idx)
        hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
        return post_process_hypos(hypotheses, self.sp_model)[0][0]

    def training_step(self, batch: Batch, batch_idx):
@@ -325,21 +282,15 @@ class RNNTModule(LightningModule):
        dataset = torch.utils.data.ConcatDataset(
            [
                CustomDataset(
                    torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"),
                    1000,
                ),
                CustomDataset(
                    torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"),
                    1000,
                ),
                CustomDataset(
                    torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"),
                    1000,
                ),
            ]
@@ -357,29 +308,24 @@ class RNNTModule(LightningModule):
        dataset = torch.utils.data.ConcatDataset(
            [
                CustomDataset(
                    torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"),
                    1000,
                ),
                CustomDataset(
                    torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"),
                    1000,
                ),
            ]
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=None,
            collate_fn=self._valid_collate_fn,
            num_workers=10,
        )
        return dataloader

    def test_dataloader(self):
        dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean")
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
        return dataloader
import pathlib
from argparse import ArgumentParser

from lightning import RNNTModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


def run_train(args):
    checkpoint_dir = args.exp_dir / "checkpoints"
@@ -63,10 +62,14 @@ def cli_main():
        help="Path to JSON file containing feature means and stddevs.",
    )
    parser.add_argument(
        "--librispeech_path",
        type=pathlib.Path,
        help="Path to LibriSpeech datasets.",
    )
    parser.add_argument(
        "--sp_model_path",
        type=pathlib.Path,
        help="Path to SentencePiece model.",
    )
    parser.add_argument(
        "--num_nodes",

...
import random
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Dict, Dict,
...@@ -7,11 +8,11 @@ from typing import ( ...@@ -7,11 +8,11 @@ from typing import (
Tuple, Tuple,
Union, Union,
) )
from torch import Tensor
import numpy as np import numpy as np
import random
import torch import torch
import torchaudio import torchaudio
from torch import Tensor
from torch.utils.data import Dataset, BatchSampler from torch.utils.data import Dataset, BatchSampler
...@@ -30,27 +31,22 @@ class BucketizeSampler(BatchSampler): ...@@ -30,27 +31,22 @@ class BucketizeSampler(BatchSampler):
the lengths of samples are unknown, the batch size may be different for different the lengths of samples are unknown, the batch size may be different for different
mini-batches. mini-batches.
""" """
def __init__( def __init__(
self, self,
data_source: Dataset, data_source: Dataset,
num_buckets: int, num_buckets: int,
max_token_count: Optional[int] = None, max_token_count: Optional[int] = None,
batch_size: Optional[int] = None batch_size: Optional[int] = None,
) -> None: ) -> None:
if max_token_count is not None and batch_size is not None: if max_token_count is not None and batch_size is not None:
raise AssertionError( raise AssertionError("The ``max_token_count`` and ``batch_size`` can't be both set.")
"The ``max_token_count`` and ``batch_size`` can't be both set."
)
self.data_source = data_source self.data_source = data_source
self.max_token_count = max_token_count self.max_token_count = max_token_count
self.batch_size = batch_size self.batch_size = batch_size
self.buckets = self._get_buckets(self.data_source, num_buckets) self.buckets = self._get_buckets(self.data_source, num_buckets)
def _get_buckets( def _get_buckets(self, data_source: Dataset, num_buckets: int) -> Dict[int, Tensor]:
self,
data_source: Dataset,
num_buckets: int
) -> Dict[int, Tensor]:
"""Generate buckets based on the dataset. """Generate buckets based on the dataset.
Args: Args:
data_source (Dataset): The dataset object to bucketize. data_source (Dataset): The dataset object to bucketize.
...@@ -126,6 +122,7 @@ class HuBERTDataSet(Dataset): ...@@ -126,6 +122,7 @@ class HuBERTDataSet(Dataset):
min_sample (int): The minimum number of audio samples in the dataset. (Default: 32000) min_sample (int): The minimum number of audio samples in the dataset. (Default: 32000)
max_sample (int): The maximum number of audio samples in the dataset. (Default: 250000) max_sample (int): The maximum number of audio samples in the dataset. (Default: 250000)
""" """
def __init__( def __init__(
self, self,
exp_dir: Union[str, Path], exp_dir: Union[str, Path],
...@@ -137,13 +134,7 @@ class HuBERTDataSet(Dataset): ...@@ -137,13 +134,7 @@ class HuBERTDataSet(Dataset):
self.exp_dir = Path(exp_dir) self.exp_dir = Path(exp_dir)
tsv_dir = self.exp_dir / "tsv" tsv_dir = self.exp_dir / "tsv"
label_dir = self.exp_dir / "label" label_dir = self.exp_dir / "label"
f_list, ind_list, len_list = self._get_lists( f_list, ind_list, len_list = self._get_lists(tsv_dir, dataset, subset, min_sample, max_sample)
tsv_dir,
dataset,
subset,
min_sample,
max_sample
)
self.f_list, self.ind_list, self.len_list = f_list, ind_list, len_list self.f_list, self.ind_list, self.len_list = f_list, ind_list, len_list
self.labels = self._load_labels(label_dir, dataset, subset) self.labels = self._load_labels(label_dir, dataset, subset)
...@@ -188,10 +179,7 @@ class HuBERTDataSet(Dataset): ...@@ -188,10 +179,7 @@ class HuBERTDataSet(Dataset):
len_list.append(ele[2]) len_list.append(ele[2])
return np.asarray(f_list), np.asarray(ind_list), np.asarray(len_list) return np.asarray(f_list), np.asarray(ind_list), np.asarray(len_list)
def _load_audio( def _load_audio(self, index: int) -> Tensor:
self,
index: int
) -> Tensor:
"""Load waveform given the sample index of the dataset. """Load waveform given the sample index of the dataset.
Args: Args:
index (int): The sample index. index (int): The sample index.
...@@ -204,12 +192,7 @@ class HuBERTDataSet(Dataset): ...@@ -204,12 +192,7 @@ class HuBERTDataSet(Dataset):
assert waveform.shape[1] == self.len_list[index] assert waveform.shape[1] == self.len_list[index]
return waveform return waveform
    def _load_labels(self, label_dir: Path, dataset: str, subset: str) -> np.array:
        """Load all labels to memory into a numpy array.
        Args:
            label_dir (Path): The directory that contains the label file.
@@ -245,6 +228,7 @@ class CollateFnHubert:
        waveform and label is random if the length is longer than the minimum
        length in the mini-batch.
    """

    def __init__(
        self,
        feature_type: str,
@@ -284,7 +268,7 @@ class CollateFnHubert:
        data = torch.zeros(len(batch), audio_size)
        for i in range(len(waveforms)):
            data[i][0 : waveforms[i].shape[1]] = waveforms[i][0]
        lengths = torch.tensor(lengths)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
        return data, labels, lengths
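The padding pattern above can be reproduced in isolation: audio is zero-padded into a dense (batch, time) tensor, labels go through pad_sequence, and the true lengths are kept so the padding can be masked downstream. A minimal sketch with illustrative shapes, not the recipe's exact collate code:

import torch

waveforms = [torch.randn(1, 300), torch.randn(1, 200)]  # (channel, time) per sample
labels = [torch.randint(0, 100, (15,)), torch.randint(0, 100, (10,))]

audio_size = max(w.shape[1] for w in waveforms)
data = torch.zeros(len(waveforms), audio_size)  # zero-padded audio batch
for i, w in enumerate(waveforms):
    data[i][0 : w.shape[1]] = w[0]
padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
lengths = torch.tensor([w.shape[1] for w in waveforms])  # true lengths for masking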
@@ -318,16 +302,10 @@ class CollateFnHubert:
            diff = waveform.size(1) - audio_size
            audio_start = torch.randint(diff, size=(1,)) if rand_crop else 0
            label_start = torch.div(
                audio_start - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor"
            )
            label_size = torch.div(audio_size - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor")
            waveform = waveform[:, audio_start : audio_start + audio_size]
            label = label[label_start : label_start + label_size]
            length = audio_size
            return waveform, label, length
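The two torch.div calls above convert an offset measured in audio samples into an offset measured in label frames, mirroring the feature extractor's window (kernel_size * sample_rate samples) and hop (stride * sample_rate samples). A worked sketch assuming a 400-sample window and 320-sample hop, which are illustrative numbers rather than the recipe's configuration:

import torch

window, hop = 400, 320  # assumed kernel_size * sample_rate and stride * sample_rate
audio_start, audio_size = torch.tensor(1000), torch.tensor(3200)

label_start = torch.div(audio_start - window, hop, rounding_mode="floor")
label_size = torch.div(audio_size - window, hop, rounding_mode="floor")
print(label_start.item(), label_size.item())  # 1 8: crop labels [1:9] for audio [1000:4200]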
@@ -21,9 +21,7 @@ from utils import (
def _init_logger(debug=False):
    message_fmt = "%(levelname)5s: %(funcName)10s: %(message)s" if debug else "%(message)s"
    logging.basicConfig(
        level=logging.DEBUG if debug else logging.INFO,
        format=f"%(asctime)s: {message_fmt}",
@@ -84,15 +82,17 @@ def main(args):
    for split in ["train", "valid"]:
        p = Pool(args.num_rank)
        inputs = [
            (
                tsv_dir / f"{args.dataset}_{split}.tsv",
                feat_dir,
                split,
                rank,
                args.num_rank,
                device,
                args.feat_type,
                16_000,
            )
            for rank in range(args.num_rank)
        ]
        _ = p.starmap(dump_features, inputs)
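The loop above fans the feature dump out over args.num_rank worker processes, one argument tuple per rank, via Pool.starmap. A self-contained sketch of the same pattern with a stub worker (names are illustrative):

from multiprocessing import Pool


def dump_features_stub(tsv_path, rank, num_rank):
    # Stand-in for the real dump_features: each rank handles its own shard.
    return f"{tsv_path}: rank {rank} of {num_rank}"


if __name__ == "__main__":
    num_rank = 4
    inputs = [("librispeech_train.tsv", rank, num_rank) for rank in range(num_rank)]
    with Pool(num_rank) as p:
        results = p.starmap(dump_features_stub, inputs)  # one call per argument tuple
    print(results)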
...
@@ -37,7 +37,7 @@ def create_tsv(
            [``librispeech``, ``libri-light``]. (Default: ``librispeech``)
        valid_percent (float, optional): The percentage of data for validation. (Default: 0.01)
        seed (int): The seed for randomly selecting the validation files.
        extension (str, optional): The extension of audio files. (Default: ``flac``)
    Returns:
        None
@@ -51,11 +51,7 @@ def create_tsv(
    if not out_dir.exists():
        out_dir.mkdir()

    valid_f = open(out_dir / f"{dataset}_valid.tsv", "w") if valid_percent > 0 else None
    search_pattern = ".*train.*"
    with open(out_dir / f"{dataset}_train.tsv", "w") as train_f:
        print(root_dir, file=train_f)
@@ -67,20 +63,13 @@ def create_tsv(
                if re.match(search_pattern, str(fname)):
                    frames = torchaudio.info(fname).num_frames
                    dest = train_f if torch.rand(1) > valid_percent else valid_f
                    print(f"{fname.relative_to(root_dir)}\t{frames}", file=dest)
    if valid_f is not None:
        valid_f.close()
    _LG.info("Finished creating the file lists successfully")
def _get_feat_lens_paths(feat_dir: Path, split: str, rank: int, num_rank: int) -> Tuple[Path, Path]:
    r"""Get the feature and lengths paths based on feature directory,
    data split, rank, and number of ranks.
    Args:
@@ -99,9 +88,7 @@ def _get_feat_lens_paths(
    return feat_path, len_path
def _get_model_path(km_dir: Path) -> Path:
    r"""Get the file path of the KMeans clustering model
    Args:
        km_dir (Path): The directory to store the KMeans clustering model.
...
@@ -19,11 +19,7 @@ from .common_utils import _get_feat_lens_paths
_LG = logging.getLogger(__name__)


def get_shard_range(num_lines: int, num_rank: int, rank: int) -> Tuple[int, int]:
    r"""Get the range of indices for the current rank in multi-processing.
    Args:
        num_lines (int): The number of lines to process.
@@ -39,10 +35,7 @@ def get_shard_range(
    assert num_lines > 0, f"Found {num_lines} files, make sure you specify the correct root directory"
    start = round(num_lines / num_rank * rank)
    end = round(num_lines / num_rank * (rank + 1))
    _LG.info(f"rank {rank} of {num_rank}, process {end-start} " f"({start}-{end}) out of {num_lines}")
    return start, end
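Rounding the proportional boundaries keeps the shards contiguous and exhaustive even when num_lines is not divisible by num_rank. A worked example:

# 10 lines split over 3 ranks with the arithmetic above.
num_lines, num_rank = 10, 3
for rank in range(num_rank):
    start = round(num_lines / num_rank * rank)
    end = round(num_lines / num_rank * (rank + 1))
    print(rank, start, end)
# rank 0 -> (0, 3), rank 1 -> (3, 7), rank 2 -> (7, 10): every index covered exactly once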
@@ -68,9 +61,7 @@ def extract_feature(
    waveform = waveform[0].to(device)
    if feature_type == "mfcc":
        feature_extractor = torchaudio.transforms.MFCC(
            sample_rate=sample_rate, n_mfcc=13, melkwargs={"n_fft": 400, "hop_length": 160, "center": False}
        ).to(device)
        mfccs = feature_extractor(waveform)  # (freq, time)
        # mfccs = torchaudio.compliance.kaldi.mfcc(
...
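The MFCC configuration in the hunk above corresponds to 25 ms analysis windows (n_fft=400 at 16 kHz) with a 10 ms hop (hop_length=160) and 13 coefficients per frame. A standalone sketch:

import torch
import torchaudio

waveform = torch.randn(1, 16000)  # one second of dummy 16 kHz audio
mfcc = torchaudio.transforms.MFCC(
    sample_rate=16000,
    n_mfcc=13,
    melkwargs={"n_fft": 400, "hop_length": 160, "center": False},
)
feats = mfcc(waveform)  # (channel, n_mfcc, time) -> (1, 13, 98)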
@@ -126,11 +126,7 @@ class ApplyKmeans(object):
        self.Cnorm = torch.from_numpy(self.Cnorm_np).to(device)

    def __call__(self, x):
        dist = x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
        return dist.argmin(dim=1).cpu().numpy()
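The one-liner in __call__ is the expanded squared Euclidean distance ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2 evaluated against all centroids at once, with self.C holding the centroids transposed and self.Cnorm their squared norms. A quick check of the identity against torch.cdist, with illustrative shapes:

import torch

x = torch.randn(5, 8)  # 5 feature vectors of dimension 8
centroids = torch.randn(3, 8)  # 3 cluster centers
C = centroids.t()  # (dim, n_clusters), as ApplyKmeans stores it
Cnorm = C.pow(2).sum(0, keepdim=True)

dist = x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, C) + Cnorm
assert torch.allclose(dist, torch.cdist(x, centroids).pow(2), atol=1e-4)
labels = dist.argmin(dim=1)  # nearest-centroid assignment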
@@ -171,7 +167,7 @@ def get_km_label(
    assert feats.shape[0] == lens.sum()
    with open(label_path, "w") as f:
        for i in range(lens.shape[0]):
            feat = feats[offset : offset + lens[i]].to(device)
            offset += lens[i]
            label = apply_kmeans(feat).tolist()
            f.write(" ".join(map(str, label)) + "\n")
...
from . import utils, vad

__all__ = ["utils", "vad"]
@@ -13,7 +13,6 @@ import datetime as dt
import logging

from fairseq import options
from interactive_asr.utils import add_asr_eval_argument, setup_asr, get_microphone_transcription, transcribe_file
@@ -29,11 +28,7 @@ def main(args):
        print("transcription_time:", transcription_time)
    else:
        for transcription in get_microphone_transcription(args, task, generator, models, sp, tgt_dict):
            print("{}: {}".format(dt.datetime.now().strftime("%H:%M:%S"), transcription[0][0]))


def cli_main():
...
@@ -9,13 +9,11 @@ import os
import sys
import time

import sentencepiece as spm
import torch
import torchaudio
from fairseq import tasks
from fairseq.utils import load_ensemble_for_inference, import_user_module
from interactive_asr.vad import get_microphone_chunks
@@ -24,9 +22,7 @@ def add_asr_eval_argument(parser):
    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
    parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument("--wfstlm", default=None, help="wfstlm on dictionary output units")
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
@@ -37,20 +33,14 @@ def add_asr_eval_argument(parser):
        default=0.2,
        help="weight for wfstlm while interpolating with neural score",
    )
    parser.add_argument("--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level")
    return parser
def check_args(args):
    assert args.path is not None, "--path required for generation!"
    assert not args.sampling or args.nbest == args.beam, "--sampling requires --nbest to be equal to --beam"
    assert args.replace_unk is None or args.raw_text, "--replace-unk requires a raw text dataset (--raw-text)"


def process_predictions(args, hypos, sp, tgt_dict):
@@ -64,8 +54,7 @@ def process_predictions(args, hypos, sp, tgt_dict):
def optimize_models(args, use_cuda, models):
    """Optimize ensemble for generation"""
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
@@ -166,24 +155,16 @@ def transcribe_file(args, task, generator, models, sp, tgt_dict):
        raise FileNotFoundError("Audio file not found: {}".format(path))
    waveform, sample_rate = torchaudio.load_wav(path)
    waveform = waveform.mean(0, True)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    start = time.time()
    transcription = transcribe(waveform, args, task, generator, models, sp, tgt_dict)
    transcription_time = time.time() - start
    return transcription_time, transcription
def get_microphone_transcription(args, task, generator, models, sp, tgt_dict):
    for (waveform, sample_rate) in get_microphone_chunks():
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform.reshape(1, -1))
        transcription = transcribe(waveform, args, task, generator, models, sp, tgt_dict)
        yield transcription
@@ -17,14 +17,13 @@ speech sequences. In the online case here, inertia is added before switching
from speech to silence or vice versa.
"""
import queue
from collections import deque

import librosa
import numpy as np
import pyaudio
import torch
import torchaudio
@@ -95,9 +94,7 @@ class VoiceActivityDetection:
        elif self.n < self.num_init_frames:
            self.min_energy = min(energy, self.min_energy)
            self.min_frequency = min(frequency, self.min_frequency)
            self.min_spectral_flatness = min(spectral_flatness, self.min_spectral_flatness)
            self.n += 1
@@ -121,10 +118,7 @@ class VoiceActivityDetection:
            # Speech detected
            self.speech_count += 1
            # Inertia against switching
            if self.n >= self.num_init_frames and self.speech_count <= self.ignore_speech_count:
                # Too soon to change
                return self.silence_mark
            else:
@@ -132,15 +126,10 @@ class VoiceActivityDetection:
                return self.speech_mark
        else:
            # Silence detected
            self.min_energy = ((self.silent_count * self.min_energy) + energy) / (self.silent_count + 1)
            self.silent_count += 1
            # Inertia against switching
            if self.n >= self.num_init_frames and self.silent_count <= self.ignore_silent_count:
                # Too soon to change
                return self.speech_mark
            else:
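The speech_count/silent_count bookkeeping above is a hysteresis: a state flip is reported only once it has persisted for more than the configured number of frames, which suppresses frame-level flicker. A minimal sketch of the same idea, detached from the audio features (names and threshold are illustrative):

class Hysteresis:
    def __init__(self, hold=5):
        self.hold = hold  # frames a proposed state change must persist
        self.state = "silence"
        self.count = 0

    def update(self, raw_is_speech: bool) -> str:
        proposed = "speech" if raw_is_speech else "silence"
        if proposed == self.state:
            self.count = 0  # no change requested: reset the counter
        else:
            self.count += 1
            if self.count > self.hold:
                self.state, self.count = proposed, 0  # change persisted long enough
        return self.state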
@@ -260,9 +249,7 @@ def get_microphone_chunks(
        else:
            precumulated.append(chunk)

        if (not is_speech and len(cumulated) >= min_to_cumulate) or (len(cumulated) > max_to_cumulate):
            waveform = torch.cat(list(precumulated) + cumulated, -1)
            yield (waveform * stream._rate, stream._rate)
            cumulated = []
...
@@ -2,8 +2,8 @@
"""
Create a data preprocess pipeline that can be run with libtorchaudio
"""
import argparse
import os

import torch
import torchaudio
@@ -14,10 +14,11 @@ class Pipeline(torch.nn.Module):
    This example loads a waveform from a file, applies effects, and saves the result to a file.
    """

    def __init__(self, rir_path: str):
        super().__init__()
        rir, sample_rate = torchaudio.load(rir_path)
        self.register_buffer("rir", rir)
        self.rir_sample_rate: int = sample_rate

    def forward(self, input_path: str, output_path: str):
@@ -32,7 +33,8 @@ class Pipeline(torch.nn.Module):
        # 3. Resample the RIR filter to match the audio sample rate
        rir, _ = torchaudio.sox_effects.apply_effects_tensor(
            self.rir, self.rir_sample_rate, effects=[["rate", str(sample_rate)]]
        )
        rir = rir / torch.norm(rir, p=2)
        rir = torch.flip(rir, [1])
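Normalizing the RIR to unit energy and flipping it along time prepares it for use as a convolution kernel: conv1d computes cross-correlation, so the flip yields true convolution. A hedged sketch of that application, with illustrative shapes rather than the pipeline's exact code:

import torch

waveform = torch.randn(1, 16000)  # (channel, time)
rir = torch.randn(1, 800)
rir = rir / torch.norm(rir, p=2)  # unit-energy RIR
kernel = torch.flip(rir, [1]).unsqueeze(0)  # (out_ch, in_ch, kernel_size)
padded = torch.nn.functional.pad(waveform, (rir.shape[1] - 1, 0))  # keep output length
reverbed = torch.nn.functional.conv1d(padded.unsqueeze(0), kernel).squeeze(0)  # (1, 16000)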
@@ -62,15 +64,9 @@ def _get_path(*paths):
def _parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--rir-path", default=_get_path("..", "data", "rir.wav"), help="Audio data for room impulse response."
    )
    parser.add_argument("--output-path", default=_get_path("pipeline.zip"), help="Output JIT file.")
    return parser.parse_args()
@@ -79,5 +75,5 @@ def _main():
    _create_jit_pipeline(args.rir_path, args.output_path)


if __name__ == "__main__":
    _main()
...
@@ -3,18 +3,17 @@
To use this script, you need `fairseq`.
"""
import argparse
import logging
import os
from typing import Tuple

import fairseq
import torch
import torchaudio
from greedy_decoder import Decoder
from torch.utils.mobile_optimizer import optimize_for_mobile
from torchaudio.models.wav2vec2.utils.import_fairseq import import_fairseq_model

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 10):
@@ -29,44 +28,31 @@ def _parse_args():
    parser = argparse.ArgumentParser(
        description=__doc__,
    )
    parser.add_argument("--model-file", required=True, help="Path to the input pretrained weight file.")
    parser.add_argument(
        "--dict-dir",
        help=(
            "Path to the directory in which `dict.ltr.txt` file is found. " "Required only when the model is finetuned."
        ),
    )
    parser.add_argument(
        "--output-path",
        help="Path to the directory, where the TorchScript-ed pipelines are saved.",
    )
    parser.add_argument(
        "--test-file",
        help="Path to a test audio file.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help=(
            "When enabled, individual components are separately tested "
            "for the numerical compatibility and TorchScript compatibility."
        ),
    )
    parser.add_argument("--quantize", action="store_true", help="Apply quantization to model.")
    parser.add_argument("--optimize-for-mobile", action="store_true", help="Apply optimization for mobile.")
    return parser.parse_args()
@@ -74,7 +60,7 @@ class Loader(torch.nn.Module):
    def forward(self, audio_path: str) -> torch.Tensor:
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != 16000:
            waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.0)
        return waveform
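Loader resamples any input that is not already at 16 kHz, the rate the wav2vec2 models expect. For instance:

import torch
import torchaudio

waveform = torch.randn(1, 44100)  # one second of dummy 44.1 kHz audio
resampled = torchaudio.functional.resample(waveform, 44100.0, 16000.0)
print(resampled.shape)  # torch.Size([1, 16000])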
@@ -129,11 +115,9 @@ def _get_decoder():
def _load_fairseq_model(input_file, data_dir=None):
    overrides = {}
    if data_dir:
        overrides["data"] = data_dir

    model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([input_file], arg_overrides=overrides)
    model = model[0]
    return model
@@ -154,36 +138,32 @@ def _main():
    _LG.info(encoder)

    if args.quantize:
        _LG.info("Quantizing the model")
        model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
        encoder = tq.quantize_dynamic(encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
        _LG.info(encoder)
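Here tq is the quantization namespace imported earlier in the script; quantize_dynamic rewrites the listed module types (torch.nn.Linear) to int8 weights while activations stay in floating point. A minimal sketch on a toy module:

import torch
import torch.quantization as tq  # assumed alias, matching the script's tq

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))
quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
out = quantized(torch.randn(1, 64))  # same interface, int8 Linear weights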
    # test
    if args.test_file:
        _LG.info("Testing with %s", args.test_file)
        waveform = loader(args.test_file)
        emission = encoder(waveform)
        transcript = decoder(emission)
        _LG.info(transcript)

    torch.jit.script(loader).save(os.path.join(args.output_path, "loader.zip"))
    torch.jit.script(decoder).save(os.path.join(args.output_path, "decoder.zip"))
    scripted = torch.jit.script(encoder)
    if args.optimize_for_mobile:
        scripted = optimize_for_mobile(scripted)
    scripted.save(os.path.join(args.output_path, "encoder.zip"))
def _init_logging(debug=False):
    level = logging.DEBUG if debug else logging.INFO
    format_ = "%(message)s" if not debug else "%(asctime)s: %(levelname)7s: %(funcName)10s: %(message)s"
    logging.basicConfig(level=level, format=format_)


if __name__ == "__main__":
    _main()
...
@@ -6,8 +6,8 @@ from typing import Tuple
import torch
import torchaudio
from greedy_decoder import Decoder
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 10):
@@ -22,31 +22,27 @@ def _parse_args():
    parser = argparse.ArgumentParser(
        description=__doc__,
    )
    parser.add_argument("--model", required=True, help="Path to the input pretrained weight file.")
    parser.add_argument(
        "--output-path",
        help="Path to the directory, where the Torchscript-ed pipelines are saved.",
    )
    parser.add_argument(
        "--test-file",
        help="Path to a test audio file.",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Quantize the model.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help=(
            "When enabled, individual components are separately tested "
            "for the numerical compatibility and TorchScript compatibility."
        ),
    )
    return parser.parse_args()
@@ -55,7 +51,7 @@ class Loader(torch.nn.Module):
    def forward(self, audio_path: str) -> torch.Tensor:
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != 16000:
            waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.0)
        return waveform
@@ -71,6 +67,7 @@ class Encoder(torch.nn.Module):
def _get_model(model_id):
    from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

    tokenizer = Wav2Vec2Processor.from_pretrained(model_id).tokenizer
    labels = [k for k, v in sorted(tokenizer.get_vocab().items(), key=lambda kv: kv[1])]
    original = Wav2Vec2ForCTC.from_pretrained(model_id)
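_get_model recovers the label list by sorting the tokenizer vocabulary by index and hands the Hugging Face module to torchaudio's import_huggingface_model, which returns a native, TorchScript-able Wav2Vec2 model. A hedged usage sketch with an illustrative checkpoint id:

from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
from transformers import Wav2Vec2ForCTC

# Illustrative checkpoint; any Wav2Vec2ForCTC model should convert the same way.
original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model = import_huggingface_model(original).eval()  # native torchaudio module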
@@ -85,43 +82,39 @@ def _get_decoder(labels):
def _main():
    args = _parse_args()
    _init_logging(args.debug)
    _LG.info("Loading model: %s", args.model)
    model, labels = _get_model(args.model)
    _LG.info("Labels: %s", labels)
    _LG.info("Building pipeline")
    loader = Loader()
    encoder = Encoder(model)
    decoder = _get_decoder(labels)
    _LG.info(encoder)

    if args.quantize:
        _LG.info("Quantizing the model")
        model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
        encoder = tq.quantize_dynamic(encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
        _LG.info(encoder)
    # test
    if args.test_file:
        _LG.info("Testing with %s", args.test_file)
        waveform = loader(args.test_file)
        emission = encoder(waveform)
        transcript = decoder(emission)
        _LG.info(transcript)

    torch.jit.script(loader).save(os.path.join(args.output_path, "loader.zip"))
    torch.jit.script(encoder).save(os.path.join(args.output_path, "encoder.zip"))
    torch.jit.script(decoder).save(os.path.join(args.output_path, "decoder.zip"))


def _init_logging(debug=False):
    level = logging.DEBUG if debug else logging.INFO
    format_ = "%(message)s" if not debug else "%(asctime)s: %(levelname)7s: %(funcName)10s: %(message)s"
    logging.basicConfig(level=level, format=format_)


if __name__ == "__main__":
    _main()
@@ -17,12 +17,12 @@ class Decoder(torch.nn.Module):
        """
        best_path = torch.argmax(logits, dim=-1)  # [num_seq,]
        best_path = torch.unique_consecutive(best_path, dim=-1)
        hypothesis = ""
        for i in best_path:
            char = self.labels[i]
            if char in ["<s>", "<pad>"]:
                continue
            if char == "|":
                char = " "
            hypothesis += char
        return hypothesis
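This forward is standard greedy CTC decoding: per-frame argmax, unique_consecutive to collapse repeats, then dropping blank/pad symbols and mapping the word delimiter "|" to a space. A toy example with an illustrative label set:

import torch

labels = ["<pad>", "|", "A", "C", "T", "<s>"]
logits = torch.tensor(
    [
        [0.1, 0.0, 2.0, 0.0, 0.0, 0.0],  # A
        [0.1, 0.0, 2.0, 0.0, 0.0, 0.0],  # A (repeat, collapsed)
        [2.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # <pad> (dropped)
        [0.0, 0.0, 0.0, 2.0, 0.0, 0.0],  # C
        [0.0, 2.0, 0.0, 0.0, 0.0, 0.0],  # | -> space
        [0.0, 0.0, 0.0, 0.0, 2.0, 0.0],  # T
    ]
)
best_path = torch.unique_consecutive(torch.argmax(logits, dim=-1))
hypothesis = "".join(" " if labels[i] == "|" else labels[i] for i in best_path if labels[i] not in ["<s>", "<pad>"])
print(hypothesis)  # "AC T"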