Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
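
The hunks that follow are mechanical reformatting from the lint run above: single quotes become double quotes, over-long statements are wrapped with trailing commas, and slice bounds that are expressions gain explicit spacing around the colon. As a rough before/after sketch of those conventions (the function below is illustrative, not taken from this diff):

```python
import textwrap

# Before: single quotes, compact slice, one over-long call.
def head_before(text, start, size):
    return textwrap.shorten(text, width=120, placeholder='...', break_long_words=False, break_on_hyphens=False)[start:start + size]

# After: double quotes, wrapped call with trailing comma, spaced slice bounds.
def head_after(text, start, size):
    return textwrap.shorten(
        text,
        width=120,
        placeholder="...",
        break_long_words=False,
        break_on_hyphens=False,
    )[start : start + size]
```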
......@@ -14,22 +14,25 @@ See this comment for design rationale:
https://github.com/pytorch/vision/pull/1321#issuecomment-531033978
"""
import os.path
import jinja2
from jinja2 import select_autoescape
import yaml
import os.path
from jinja2 import select_autoescape
PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]
CU_VERSIONS_DICT = {"linux": ["cpu", "cu102", "cu111","cu113", "cu115", "rocm4.1"],
CU_VERSIONS_DICT = {
"linux": ["cpu", "cu102", "cu111", "cu113", "cu115", "rocm4.1"],
"windows": ["cpu", "cu113", "cu115"],
"macos": ["cpu"]}
"macos": ["cpu"],
}
DOC_VERSION = ('linux', '3.8')
DOC_VERSION = ("linux", "3.8")
def build_workflows(prefix='', upload=False, filter_branch=None, indentation=6):
def build_workflows(prefix="", upload=False, filter_branch=None, indentation=6):
w = []
w += build_download_job(filter_branch)
for btype in ["wheel", "conda"]:
......@@ -37,23 +40,21 @@ def build_workflows(prefix='', upload=False, filter_branch=None, indentation=6):
for python_version in PYTHON_VERSIONS:
for cu_version in CU_VERSIONS_DICT[os_type]:
fb = filter_branch
if cu_version.startswith("rocm") and btype=="conda":
if cu_version.startswith("rocm") and btype == "conda":
continue
if not fb and (os_type == 'linux' and
btype == 'wheel' and
python_version == '3.8' and
cu_version == 'cpu'):
if not fb and (
os_type == "linux" and btype == "wheel" and python_version == "3.8" and cu_version == "cpu"
):
# the fields must match the build_docs "requires" dependency
fb = '/.*/'
fb = "/.*/"
w += build_workflow_pair(btype, os_type, python_version, cu_version, fb, prefix, upload)
if not filter_branch:
# Build on every pull request, but upload only on nightly and tags
w += build_doc_job('/.*/')
w += upload_doc_job('nightly')
w += build_doc_job("/.*/")
w += upload_doc_job("nightly")
w += docstring_parameters_sync_job(None)
return indent(indentation, w)
......@@ -67,7 +68,7 @@ def build_download_job(filter_branch):
return [{"download_third_parties_nix": job}]
def build_workflow_pair(btype, os_type, python_version, cu_version, filter_branch, prefix='', upload=False):
def build_workflow_pair(btype, os_type, python_version, cu_version, filter_branch, prefix="", upload=False):
w = []
base_workflow_name = f"{prefix}binary_{os_type}_{btype}_py{python_version}_{cu_version}"
......@@ -77,9 +78,13 @@ def build_workflow_pair(btype, os_type, python_version, cu_version, filter_branc
w.append(generate_upload_workflow(base_workflow_name, filter_branch, os_type, btype, cu_version))
if filter_branch == 'nightly' and os_type != 'macos':
pydistro = 'pip' if btype == 'wheel' else 'conda'
w.append(generate_smoketest_workflow(pydistro, base_workflow_name, filter_branch, python_version, cu_version, os_type))
if filter_branch == "nightly" and os_type != "macos":
pydistro = "pip" if btype == "wheel" else "conda"
w.append(
generate_smoketest_workflow(
pydistro, base_workflow_name, filter_branch, python_version, cu_version, os_type
)
)
return w
......@@ -88,7 +93,9 @@ def build_doc_job(filter_branch):
job = {
"name": "build_docs",
"python_version": "3.8",
"requires": ["binary_linux_wheel_py3.8_cpu", ],
"requires": [
"binary_linux_wheel_py3.8_cpu",
],
}
if filter_branch:
......@@ -101,7 +108,9 @@ def upload_doc_job(filter_branch):
"name": "upload_docs",
"context": "org-member",
"python_version": "3.8",
"requires": ["build_docs", ],
"requires": [
"build_docs",
],
}
if filter_branch:
......@@ -113,7 +122,9 @@ def docstring_parameters_sync_job(filter_branch):
job = {
"name": "docstring_parameters_sync",
"python_version": "3.8",
"requires": ["binary_linux_wheel_py3.8_cpu", ],
"requires": [
"binary_linux_wheel_py3.8_cpu",
],
}
if filter_branch:
......@@ -129,13 +140,13 @@ def generate_base_workflow(base_workflow_name, python_version, cu_version, filte
"cuda_version": cu_version,
}
if os_type in ['linux', 'macos']:
d['requires'] = ['download_third_parties_nix']
if btype == 'conda':
d['conda_docker_image'] = f'pytorch/conda-builder:{cu_version.replace("cu1","cuda1")}'
elif cu_version.startswith('cu'):
d['wheel_docker_image'] = f'pytorch/manylinux-{cu_version.replace("cu1","cuda1")}'
elif cu_version.startswith('rocm'):
if os_type in ["linux", "macos"]:
d["requires"] = ["download_third_parties_nix"]
if btype == "conda":
d["conda_docker_image"] = f'pytorch/conda-builder:{cu_version.replace("cu1","cuda1")}'
elif cu_version.startswith("cu"):
d["wheel_docker_image"] = f'pytorch/manylinux-{cu_version.replace("cu1","cuda1")}'
elif cu_version.startswith("rocm"):
d["wheel_docker_image"] = f"pytorch/manylinux-rocm:{cu_version[len('rocm'):]}"
if filter_branch:
......@@ -153,7 +164,7 @@ def gen_filter_branch_tree(*branches):
# Using a raw string here to avoid having to escape
# anything
"only": r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/"
}
},
}
......@@ -164,9 +175,8 @@ def generate_upload_workflow(base_workflow_name, filter_branch, os_type, btype,
"requires": [base_workflow_name],
}
if btype == 'wheel':
d["subfolder"] = "" if os_type == 'macos' else cu_version + "/"
if btype == "wheel":
d["subfolder"] = "" if os_type == "macos" else cu_version + "/"
if filter_branch:
d["filters"] = gen_filter_branch_tree(filter_branch)
......@@ -212,22 +222,24 @@ def unittest_workflows(indentation=6):
job = {
"name": f"unittest_{os_type}_{device_type}_py{python_version}",
"python_version": python_version,
"cuda_version": 'cpu' if device_type == "cpu" else "cu113",
"cuda_version": "cpu" if device_type == "cpu" else "cu113",
}
if os_type != "windows":
job['requires'] = ['download_third_parties_nix']
job["requires"] = ["download_third_parties_nix"]
jobs.append({f"unittest_{os_type}_{device_type}": job})
if i == 0 and os_type == "linux" and device_type == "cpu":
jobs.append({
jobs.append(
{
"stylecheck": {
"name": f"stylecheck_py{python_version}",
"python_version": python_version,
"cuda_version": "cpu",
}
})
}
)
return indent(indentation, jobs)
......@@ -236,12 +248,14 @@ if __name__ == "__main__":
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(d),
lstrip_blocks=True,
autoescape=select_autoescape(enabled_extensions=('html', 'xml')),
autoescape=select_autoescape(enabled_extensions=("html", "xml")),
)
with open(os.path.join(d, 'config.yml'), 'w') as f:
f.write(env.get_template('config.yml.in').render(
with open(os.path.join(d, "config.yml"), "w") as f:
f.write(
env.get_template("config.yml.in").render(
build_workflows=build_workflows,
unittest_workflows=unittest_workflows,
))
)
)
f.write("\n")
......@@ -19,7 +19,6 @@ import signal
import subprocess
import sys
import traceback
from functools import partial
try:
......@@ -28,7 +27,7 @@ except ImportError:
DEVNULL = open(os.devnull, "wb")
DEFAULT_EXTENSIONS = 'c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu'
DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
class ExitStatus:
......@@ -52,14 +51,8 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
# os.walk() supports trimming down the dnames list
# by modifying it in-place,
# to avoid unnecessary directory listings.
dnames[:] = [
x for x in dnames
if
not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)
]
fpaths = [
x for x in fpaths if not fnmatch.fnmatch(x, pattern)
]
dnames[:] = [x for x in dnames if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)]
fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)]
for f in fpaths:
ext = os.path.splitext(f)[1][1:]
if ext in extensions:
......@@ -72,11 +65,9 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
def make_diff(file, original, reformatted):
return list(
difflib.unified_diff(
original,
reformatted,
fromfile='{}\t(original)'.format(file),
tofile='{}\t(reformatted)'.format(file),
n=3))
original, reformatted, fromfile="{}\t(original)".format(file), tofile="{}\t(reformatted)".format(file), n=3
)
)
class DiffError(Exception):
......@@ -99,13 +90,12 @@ def run_clang_format_diff_wrapper(args, file):
except DiffError:
raise
except Exception as e:
raise UnexpectedError('{}: {}: {}'.format(file, e.__class__.__name__,
e), e)
raise UnexpectedError("{}: {}: {}".format(file, e.__class__.__name__, e), e)
def run_clang_format_diff(args, file):
try:
with io.open(file, 'r', encoding='utf-8') as f:
with io.open(file, "r", encoding="utf-8") as f:
original = f.readlines()
except IOError as exc:
raise DiffError(str(exc))
......@@ -130,17 +120,10 @@ def run_clang_format_diff(args, file):
try:
proc = subprocess.Popen(
invocation,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
encoding='utf-8')
except OSError as exc:
raise DiffError(
"Command '{}' failed to start: {}".format(
subprocess.list2cmdline(invocation), exc
)
invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, encoding="utf-8"
)
except OSError as exc:
raise DiffError("Command '{}' failed to start: {}".format(subprocess.list2cmdline(invocation), exc))
proc_stdout = proc.stdout
proc_stderr = proc.stderr
......@@ -159,30 +142,30 @@ def run_clang_format_diff(args, file):
def bold_red(s):
return '\x1b[1m\x1b[31m' + s + '\x1b[0m'
return "\x1b[1m\x1b[31m" + s + "\x1b[0m"
def colorize(diff_lines):
def bold(s):
return '\x1b[1m' + s + '\x1b[0m'
return "\x1b[1m" + s + "\x1b[0m"
def cyan(s):
return '\x1b[36m' + s + '\x1b[0m'
return "\x1b[36m" + s + "\x1b[0m"
def green(s):
return '\x1b[32m' + s + '\x1b[0m'
return "\x1b[32m" + s + "\x1b[0m"
def red(s):
return '\x1b[31m' + s + '\x1b[0m'
return "\x1b[31m" + s + "\x1b[0m"
for line in diff_lines:
if line[:4] in ['--- ', '+++ ']:
if line[:4] in ["--- ", "+++ "]:
yield bold(line)
elif line.startswith('@@ '):
elif line.startswith("@@ "):
yield cyan(line)
elif line.startswith('+'):
elif line.startswith("+"):
yield green(line)
elif line.startswith('-'):
elif line.startswith("-"):
yield red(line)
else:
yield line
......@@ -195,7 +178,7 @@ def print_diff(diff_lines, use_color):
def print_trouble(prog, message, use_colors):
error_text = 'error:'
error_text = "error:"
if use_colors:
error_text = bold_red(error_text)
print("{}: {} {}".format(prog, error_text, message), file=sys.stderr)
......@@ -204,45 +187,37 @@ def print_trouble(prog, message, use_colors):
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--clang-format-executable',
metavar='EXECUTABLE',
help='path to the clang-format executable',
default='clang-format')
parser.add_argument(
'--extensions',
help='comma separated list of file extensions (default: {})'.format(
DEFAULT_EXTENSIONS),
default=DEFAULT_EXTENSIONS)
parser.add_argument(
'-r',
'--recursive',
action='store_true',
help='run recursively over directories')
parser.add_argument('files', metavar='file', nargs='+')
"--clang-format-executable",
metavar="EXECUTABLE",
help="path to the clang-format executable",
default="clang-format",
)
parser.add_argument(
'-q',
'--quiet',
action='store_true')
"--extensions",
help="comma separated list of file extensions (default: {})".format(DEFAULT_EXTENSIONS),
default=DEFAULT_EXTENSIONS,
)
parser.add_argument("-r", "--recursive", action="store_true", help="run recursively over directories")
parser.add_argument("files", metavar="file", nargs="+")
parser.add_argument("-q", "--quiet", action="store_true")
parser.add_argument(
'-j',
metavar='N',
"-j",
metavar="N",
type=int,
default=0,
help='run N clang-format jobs in parallel'
' (default number of cpus + 1)')
help="run N clang-format jobs in parallel" " (default number of cpus + 1)",
)
parser.add_argument(
'--color',
default='auto',
choices=['auto', 'always', 'never'],
help='show colored diff (default: auto)')
"--color", default="auto", choices=["auto", "always", "never"], help="show colored diff (default: auto)"
)
parser.add_argument(
'-e',
'--exclude',
metavar='PATTERN',
action='append',
"-e",
"--exclude",
metavar="PATTERN",
action="append",
default=[],
help='exclude paths matching the given glob-like pattern(s)'
' from recursive search')
help="exclude paths matching the given glob-like pattern(s)" " from recursive search",
)
args = parser.parse_args()
......@@ -259,10 +234,10 @@ def main():
colored_stdout = False
colored_stderr = False
if args.color == 'always':
if args.color == "always":
colored_stdout = True
colored_stderr = True
elif args.color == 'auto':
elif args.color == "auto":
colored_stdout = sys.stdout.isatty()
colored_stderr = sys.stderr.isatty()
......@@ -275,19 +250,15 @@ def main():
except OSError as e:
print_trouble(
parser.prog,
"Command '{}' failed to start: {}".format(
subprocess.list2cmdline(version_invocation), e
),
"Command '{}' failed to start: {}".format(subprocess.list2cmdline(version_invocation), e),
use_colors=colored_stderr,
)
return ExitStatus.TROUBLE
retcode = ExitStatus.SUCCESS
files = list_files(
args.files,
recursive=args.recursive,
exclude=args.exclude,
extensions=args.extensions.split(','))
args.files, recursive=args.recursive, exclude=args.exclude, extensions=args.extensions.split(",")
)
if not files:
return
......@@ -304,8 +275,7 @@ def main():
pool = None
else:
pool = multiprocessing.Pool(njobs)
it = pool.imap_unordered(
partial(run_clang_format_diff_wrapper, args), files)
it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files)
while True:
try:
outs, errs = next(it)
......@@ -336,5 +306,5 @@ def main():
return retcode
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main())
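
Worth noting in the `list_files` hunk above: assigning to `dnames[:]` (rather than rebinding `dnames`) mutates the list that `os.walk` is iterating over, so excluded directories are pruned before they are ever listed. A self-contained sketch of that idiom, with a hypothetical exclude pattern:

```python
import fnmatch
import os

def iter_files(root, exclude=("*/third_party/*",)):
    """Yield paths under root, pruning whole directories that match an exclude glob."""
    for dirpath, dnames, fnames in os.walk(root):
        for pattern in exclude:
            # In-place slice assignment is what makes os.walk skip the pruned dirs.
            dnames[:] = [d for d in dnames if not fnmatch.fnmatch(os.path.join(dirpath, d), pattern)]
            fnames = [f for f in fnames if not fnmatch.fnmatch(os.path.join(dirpath, f), pattern)]
        for name in fnames:
            yield os.path.join(dirpath, name)
```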
......@@ -45,13 +45,13 @@ def query_torchaudio(cmd: str, *, accept) -> Any:
def get_pr_merger_and_number(commit_hash: str) -> Optional[str]:
data = query_torchaudio(f"commits/{commit_hash}", accept="application/vnd.github.v3+json")
commit_message = data['commit']['message']
commit_message = data["commit"]["message"]
pulled_by = commit_message.split('Pulled By: ')
pulled_by = pulled_by[1].split('\n')[0] if len(pulled_by) > 1 else None
pulled_by = commit_message.split("Pulled By: ")
pulled_by = pulled_by[1].split("\n")[0] if len(pulled_by) > 1 else None
pr_number = commit_message.split('Pull Request resolved: https://github.com/pytorch/audio/pull/')
pr_number = pr_number[1].split('\n')[0] if len(pr_number) > 1 else None
pr_number = commit_message.split("Pull Request resolved: https://github.com/pytorch/audio/pull/")
pr_number = pr_number[1].split("\n")[0] if len(pr_number) > 1 else None
return pulled_by, pr_number
......
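
A quick check of the split-based parsing in `get_pr_merger_and_number` against a synthetic commit message (the reviewer name below is made up; only the URL prefix matches the real code):

```python
commit_message = (
    "Apply arc lint to pytorch audio (#2096)\n\n"
    "Pull Request resolved: https://github.com/pytorch/audio/pull/2096\n"
    "Pulled By: someuser\n"
)

pulled_by = commit_message.split("Pulled By: ")
pulled_by = pulled_by[1].split("\n")[0] if len(pulled_by) > 1 else None

pr_number = commit_message.split("Pull Request resolved: https://github.com/pytorch/audio/pull/")
pr_number = pr_number[1].split("\n")[0] if len(pr_number) > 1 else None

assert (pulled_by, pr_number) == ("someuser", "2096")
```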
......@@ -22,97 +22,98 @@
# sys.path.insert(0, os.path.abspath('.'))
import os
import re
import pytorch_sphinx_theme
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
needs_sphinx = '1.6'
needs_sphinx = "1.6"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinxcontrib.katex',
'sphinxcontrib.bibtex',
'sphinx_gallery.gen_gallery',
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
"sphinx.ext.coverage",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinxcontrib.katex",
"sphinxcontrib.bibtex",
"sphinx_gallery.gen_gallery",
]
# katex options
#
#
katex_options = r'''
katex_options = r"""
delimiters : [
{left: "$$", right: "$$", display: true},
{left: "\\(", right: "\\)", display: false},
{left: "\\[", right: "\\]", display: true}
]
'''
"""
bibtex_bibfiles = ['refs.bib']
bibtex_bibfiles = ["refs.bib"]
def _get_var(var, default=False):
if var not in os.environ:
return default
val = os.environ.get(var, '0')
trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
val = os.environ.get(var, "0")
trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"]
falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"]
if val in trues:
return True
if val not in falses:
print(
f' --- WARNING: Unexpected environment variable value `{var}={val}`. '
f'Expected one of {trues + falses}')
f" --- WARNING: Unexpected environment variable value `{var}={val}`. " f"Expected one of {trues + falses}"
)
return False
def _get_pattern():
pattern = os.getenv('GALLERY_PATTERN')
pattern = os.getenv("GALLERY_PATTERN")
# If BUILD_GALLERY is falsy -> no build
# If BUILD_GALLERY is truey -> build
# If BUILD_GALLERY is undefined
# If GALLERY_PATTERN is defined -> build
# If GALLERY_PATTERN is not defined -> not build
if not _get_var('BUILD_GALLERY', default=False if pattern is None else True):
if not _get_var("BUILD_GALLERY", default=False if pattern is None else True):
if pattern is not None:
print(
' --- WARNING: "GALLERY_PATTERN" is provided, but "BUILD_GALLERY" value is falsy. '
'Sphinx galleries are not built. To build galleries, set `BUILD_GALLERY=1`.'
"Sphinx galleries are not built. To build galleries, set `BUILD_GALLERY=1`."
)
return {
'ignore_pattern': r'\.py',
"ignore_pattern": r"\.py",
}
ret = {'filename_pattern': 'tutorial.py'}
if os.getenv('GALLERY_PATTERN'):
ret = {"filename_pattern": "tutorial.py"}
if os.getenv("GALLERY_PATTERN"):
# See https://github.com/pytorch/tutorials/blob/cbf2238df0e78d84c15bd94288966d2f4b2e83ae/conf.py#L75-L83
ret['ignore_pattern'] = r'/(?!' + re.escape(os.getenv('GALLERY_PATTERN')) + r')[^/]+$'
ret["ignore_pattern"] = r"/(?!" + re.escape(os.getenv("GALLERY_PATTERN")) + r")[^/]+$"
return ret
sphinx_gallery_conf = {
'examples_dirs': [
'../../examples/tutorials',
"examples_dirs": [
"../../examples/tutorials",
],
'gallery_dirs': [
'tutorials',
"gallery_dirs": [
"tutorials",
],
**_get_pattern(),
'backreferences_dir': 'gen_modules/backreferences',
'first_notebook_cell': None,
'doc_module': ('torchaudio',),
"backreferences_dir": "gen_modules/backreferences",
"first_notebook_cell": None,
"doc_module": ("torchaudio",),
}
autosummary_generate = True
......@@ -121,21 +122,21 @@ napoleon_numpy_docstring = False
napoleon_google_docstring = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
source_suffix = ".rst"
# The master toctree document.
master_doc = 'index'
master_doc = "index"
# General information about the project.
project = 'Torchaudio'
copyright = '2018, Torchaudio Contributors'
author = 'Torchaudio Contributors'
project = "Torchaudio"
copyright = "2018, Torchaudio Contributors"
author = "Torchaudio Contributors"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
......@@ -143,10 +144,10 @@ author = 'Torchaudio Contributors'
#
# The short X.Y version.
# TODO: change to [:2] at v1.0
version = 'main '
version = "main "
# The full version, including alpha/beta/rc tags.
# TODO: verify this works as expected
release = 'main'
release = "main"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
......@@ -158,10 +159,10 @@ language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['*/index.rst']
exclude_patterns = ["*/index.rst"]
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
pygments_style = "sphinx"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
......@@ -172,7 +173,7 @@ todo_include_todos = True
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'pytorch_sphinx_theme'
html_theme = "pytorch_sphinx_theme"
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# Theme options are theme-specific and customize the look and feel of a theme
......@@ -180,28 +181,26 @@ html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# documentation.
#
html_theme_options = {
'pytorch_project': 'audio',
'collapse_navigation': False,
'display_version': True,
'logo_only': True,
'navigation_with_keys': True,
'analytics_id': 'UA-117752657-2',
"pytorch_project": "audio",
"collapse_navigation": False,
"display_version": True,
"logo_only": True,
"navigation_with_keys": True,
"analytics_id": "UA-117752657-2",
}
html_logo = '_static/img/pytorch-logo-dark.svg'
html_logo = "_static/img/pytorch-logo-dark.svg"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_css_files = [
'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css'
]
html_static_path = ["_static"]
html_css_files = ["https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css"]
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'TorchAudiodoc'
htmlhelp_basename = "TorchAudiodoc"
# -- Options for LaTeX output ---------------------------------------------
......@@ -210,15 +209,12 @@ latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
......@@ -228,8 +224,7 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'pytorch.tex', 'Torchaudio Documentation',
'Torch Contributors', 'manual'),
(master_doc, "pytorch.tex", "Torchaudio Documentation", "Torch Contributors", "manual"),
]
......@@ -237,10 +232,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'Torchaudio', 'Torchaudio Documentation',
[author], 1)
]
man_pages = [(master_doc, "Torchaudio", "Torchaudio Documentation", [author], 1)]
# -- Options for Texinfo output -------------------------------------------
......@@ -249,25 +241,31 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'Torchaudio', 'Torchaudio Documentation',
author, 'Torchaudio', 'Load audio files into pytorch tensors.',
'Miscellaneous'),
(
master_doc,
"Torchaudio",
"Torchaudio Documentation",
author,
"Torchaudio",
"Load audio files into pytorch tensors.",
"Miscellaneous",
),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/3/', None),
'numpy': ('https://numpy.org/doc/stable/', None),
'torch': ('https://pytorch.org/docs/stable/', None),
"python": ("https://docs.python.org/3/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
}
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
# See http://stackoverflow.com/a/41184353/3343043
from docutils import nodes
from sphinx.util.docfields import TypedField
from sphinx import addnodes
from sphinx.util.docfields import TypedField
def patched_make_field(self, types, domain, items, **kw):
......@@ -277,39 +275,39 @@ def patched_make_field(self, types, domain, items, **kw):
# type: (list, str, tuple) -> nodes.field
def handle_item(fieldarg, content):
par = nodes.paragraph()
par += addnodes.literal_strong('', fieldarg) # Patch: this line added
par += addnodes.literal_strong("", fieldarg) # Patch: this line added
# par.extend(self.make_xrefs(self.rolename, domain, fieldarg,
# addnodes.literal_strong))
if fieldarg in types:
par += nodes.Text(' (')
par += nodes.Text(" (")
# NOTE: using .pop() here to prevent a single type node to be
# inserted twice into the doctree, which leads to
# inconsistencies later when references are resolved
fieldtype = types.pop(fieldarg)
if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text):
typename = u''.join(n.astext() for n in fieldtype)
typename = typename.replace('int', 'python:int')
typename = typename.replace('long', 'python:long')
typename = typename.replace('float', 'python:float')
typename = typename.replace('type', 'python:type')
par.extend(self.make_xrefs(self.typerolename, domain, typename,
addnodes.literal_emphasis, **kw))
typename = "".join(n.astext() for n in fieldtype)
typename = typename.replace("int", "python:int")
typename = typename.replace("long", "python:long")
typename = typename.replace("float", "python:float")
typename = typename.replace("type", "python:type")
par.extend(self.make_xrefs(self.typerolename, domain, typename, addnodes.literal_emphasis, **kw))
else:
par += fieldtype
par += nodes.Text(')')
par += nodes.Text(' -- ')
par += nodes.Text(")")
par += nodes.Text(" -- ")
par += content
return par
fieldname = nodes.field_name('', self.label)
fieldname = nodes.field_name("", self.label)
if len(items) == 1 and self.can_collapse:
fieldarg, content = items[0]
bodynode = handle_item(fieldarg, content)
else:
bodynode = self.list_type()
for fieldarg, content in items:
bodynode += nodes.list_item('', handle_item(fieldarg, content))
fieldbody = nodes.field_body('', bodynode)
return nodes.field('', fieldname, fieldbody)
bodynode += nodes.list_item("", handle_item(fieldarg, content))
fieldbody = nodes.field_body("", bodynode)
return nodes.field("", fieldname, fieldbody)
TypedField.make_field = patched_make_field
from argparse import ArgumentParser
import logging
import pathlib
from argparse import ArgumentParser
import torch
import torchaudio
from lightning import RNNTModule
......@@ -12,9 +11,7 @@ logger = logging.getLogger()
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(
seq1.lower().split(), seq2.lower().split()
)
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def run_eval(args):
......@@ -38,9 +35,7 @@ def run_eval(args):
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.info(
f"Processed elem {idx}; WER: {total_edit_distance / total_length}"
)
logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.info(f"Final WER: {total_edit_distance / total_length}")
......@@ -58,13 +53,20 @@ def cli_main():
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech_path", type=pathlib.Path, help="Path to LibriSpeech datasets.",
"--librispeech_path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
)
parser.add_argument(
"--sp_model_path", type=pathlib.Path, help="Path to SentencePiece model.",
"--sp_model_path",
type=pathlib.Path,
help="Path to SentencePiece model.",
)
parser.add_argument(
"--use_cuda", action="store_true", default=False, help="Run using CUDA.",
"--use_cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
args = parser.parse_args()
run_eval(args)
......
from collections import namedtuple
import json
import math
import os
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
import torchaudio.functional as F
from pytorch_lightning import LightningModule
from torchaudio.prototype.rnnt import emformer_rnnt_base
from torchaudio.prototype.rnnt_decoder import Hypothesis, RNNTBeamSearch
from pytorch_lightning import LightningModule
Batch = namedtuple(
"Batch", ["features", "feature_lengths", "targets", "target_lengths"]
)
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=16000, n_fft=400, n_mels=80, hop_length=160
)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
def _batch_by_token_count(idx_target_lengths, token_limit):
......@@ -61,9 +56,7 @@ class CustomDataset(torch.utils.data.Dataset):
assert len(idx_target_lengths) > 0
idx_target_lengths = sorted(
idx_target_lengths, key=lambda x: x[1], reverse=True
)
idx_target_lengths = sorted(idx_target_lengths, key=lambda x: x[1], reverse=True)
assert max_token_limit >= idx_target_lengths[0][1]
......@@ -74,9 +67,7 @@ class CustomDataset(torch.utils.data.Dataset):
speaker_id, chapter_id, _ = fileid.split("-")
file_text = speaker_id + "-" + chapter_id + self.base_dataset._ext_txt
file_text = os.path.join(
self.base_dataset._path, speaker_id, chapter_id, file_text
)
file_text = os.path.join(self.base_dataset._path, speaker_id, chapter_id, file_text)
with open(file_text) as ft:
for line in ft:
......@@ -93,24 +84,16 @@ class CustomDataset(torch.utils.data.Dataset):
class TimeMasking(torchaudio.transforms._AxisMasking):
def __init__(
self, time_mask_param: int, min_mask_p: float, iid_masks: bool = False
) -> None:
def __init__(self, time_mask_param: int, min_mask_p: float, iid_masks: bool = False) -> None:
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
self.min_mask_p = min_mask_p
def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor:
if self.iid_masks and specgram.dim() == 4:
mask_param = min(
self.mask_param, self.min_mask_p * specgram.shape[self.axis + 1]
)
return F.mask_along_axis_iid(
specgram, mask_param, mask_value, self.axis + 1
)
mask_param = min(self.mask_param, self.min_mask_p * specgram.shape[self.axis + 1])
return F.mask_along_axis_iid(specgram, mask_param, mask_value, self.axis + 1)
else:
mask_param = min(
self.mask_param, self.min_mask_p * specgram.shape[self.axis]
)
mask_param = min(self.mask_param, self.min_mask_p * specgram.shape[self.axis])
return F.mask_along_axis(specgram, mask_param, mask_value, self.axis)
......@@ -149,10 +132,7 @@ class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
return [
(min(1.0, self._step_count / self.warmup_updates)) * base_lr
for base_lr in self.base_lrs
]
return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs]
def post_process_hypos(
......@@ -164,12 +144,7 @@ def post_process_hypos(
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[
token_index
for token_index in h.tokens[1:]
if token_index not in post_process_remove_list
]
for h in hypos
[token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ali = [h.alignment[1:] for h in hypos]
......@@ -193,12 +168,8 @@ class RNNTModule(LightningModule):
self.model = emformer_rnnt_base()
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0)
self.optimizer = torch.optim.Adam(
self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8
)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, factor=0.96, patience=0
)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.96, patience=0)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)
self.train_data_pipeline = torch.nn.Sequential(
......@@ -236,27 +207,17 @@ class RNNTModule(LightningModule):
return targets, lengths
def _train_extract_features(self, samples: List):
mel_features = [
_spectrogram_transform(sample[0].squeeze()).transpose(1, 0)
for sample in samples
]
mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.train_data_pipeline(features)
lengths = torch.tensor(
[elem.shape[0] for elem in mel_features], dtype=torch.int32
)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _valid_extract_features(self, samples: List):
mel_features = [
_spectrogram_transform(sample[0].squeeze()).transpose(1, 0)
for sample in samples
]
mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.valid_data_pipeline(features)
lengths = torch.tensor(
[elem.shape[0] for elem in mel_features], dtype=torch.int32
)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _train_collate_fn(self, samples: List):
......@@ -276,9 +237,7 @@ class RNNTModule(LightningModule):
if batch is None:
return None
prepended_targets = batch.targets.new_empty(
[batch.targets.size(0), batch.targets.size(1) + 1]
)
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
......@@ -307,9 +266,7 @@ class RNNTModule(LightningModule):
def forward(self, batch: Batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
hypotheses = decoder(
batch.features.to(self.device), batch.feature_lengths.to(self.device), 20
)
hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: Batch, batch_idx):
......@@ -325,21 +282,15 @@ class RNNTModule(LightningModule):
dataset = torch.utils.data.ConcatDataset(
[
CustomDataset(
torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="train-clean-360"
),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"),
1000,
),
CustomDataset(
torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="train-clean-100"
),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"),
1000,
),
CustomDataset(
torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="train-other-500"
),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"),
1000,
),
]
......@@ -357,29 +308,24 @@ class RNNTModule(LightningModule):
dataset = torch.utils.data.ConcatDataset(
[
CustomDataset(
torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="dev-clean"
),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"),
1000,
),
CustomDataset(
torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="dev-other"
),
torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"),
1000,
),
]
)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=None, collate_fn=self._valid_collate_fn, num_workers=10,
dataset,
batch_size=None,
collate_fn=self._valid_collate_fn,
num_workers=10,
)
return dataloader
def test_dataloader(self):
dataset = torchaudio.datasets.LIBRISPEECH(
self.librispeech_path, url="test-clean"
)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, collate_fn=self._test_collate_fn
)
dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
return dataloader
from argparse import ArgumentParser
import pathlib
from argparse import ArgumentParser
from lightning import RNNTModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from lightning import RNNTModule
def run_train(args):
checkpoint_dir = args.exp_dir / "checkpoints"
......@@ -63,10 +62,14 @@ def cli_main():
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech_path", type=pathlib.Path, help="Path to LibriSpeech datasets.",
"--librispeech_path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
)
parser.add_argument(
"--sp_model_path", type=pathlib.Path, help="Path to SentencePiece model.",
"--sp_model_path",
type=pathlib.Path,
help="Path to SentencePiece model.",
)
parser.add_argument(
"--num_nodes",
......
import random
from pathlib import Path
from typing import (
Dict,
......@@ -7,11 +8,11 @@ from typing import (
Tuple,
Union,
)
from torch import Tensor
import numpy as np
import random
import torch
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset, BatchSampler
......@@ -30,27 +31,22 @@ class BucketizeSampler(BatchSampler):
the lengths of samples are unknown, the batch size may be different for different
mini-batches.
"""
def __init__(
self,
data_source: Dataset,
num_buckets: int,
max_token_count: Optional[int] = None,
batch_size: Optional[int] = None
batch_size: Optional[int] = None,
) -> None:
if max_token_count is not None and batch_size is not None:
raise AssertionError(
"The ``max_token_count`` and ``batch_size`` can't be both set."
)
raise AssertionError("The ``max_token_count`` and ``batch_size`` can't be both set.")
self.data_source = data_source
self.max_token_count = max_token_count
self.batch_size = batch_size
self.buckets = self._get_buckets(self.data_source, num_buckets)
def _get_buckets(
self,
data_source: Dataset,
num_buckets: int
) -> Dict[int, Tensor]:
def _get_buckets(self, data_source: Dataset, num_buckets: int) -> Dict[int, Tensor]:
"""Generate buckets based on the dataset.
Args:
data_source (Dataset): The dataset object to bucketize.
......@@ -126,6 +122,7 @@ class HuBERTDataSet(Dataset):
min_sample (int): The minimum number of audio samples in the dataset. (Default: 32000)
max_sample (int): The maximum number of audio samples in the dataset. (Default: 250000)
"""
def __init__(
self,
exp_dir: Union[str, Path],
......@@ -137,13 +134,7 @@ class HuBERTDataSet(Dataset):
self.exp_dir = Path(exp_dir)
tsv_dir = self.exp_dir / "tsv"
label_dir = self.exp_dir / "label"
f_list, ind_list, len_list = self._get_lists(
tsv_dir,
dataset,
subset,
min_sample,
max_sample
)
f_list, ind_list, len_list = self._get_lists(tsv_dir, dataset, subset, min_sample, max_sample)
self.f_list, self.ind_list, self.len_list = f_list, ind_list, len_list
self.labels = self._load_labels(label_dir, dataset, subset)
......@@ -188,10 +179,7 @@ class HuBERTDataSet(Dataset):
len_list.append(ele[2])
return np.asarray(f_list), np.asarray(ind_list), np.asarray(len_list)
def _load_audio(
self,
index: int
) -> Tensor:
def _load_audio(self, index: int) -> Tensor:
"""Load waveform given the sample index of the dataset.
Args:
index (int): The sample index.
......@@ -204,12 +192,7 @@ class HuBERTDataSet(Dataset):
assert waveform.shape[1] == self.len_list[index]
return waveform
def _load_labels(
self,
label_dir: Path,
dataset: str,
subset: str
) -> np.array:
def _load_labels(self, label_dir: Path, dataset: str, subset: str) -> np.array:
"""Load all labels to memory into a numpy array.
Args:
label_dir (Path): The directory that contains the label file.
......@@ -245,6 +228,7 @@ class CollateFnHubert:
waveform and label is random if the length is longer than the minimum
length in the mini-batch.
"""
def __init__(
self,
feature_type: str,
......@@ -284,7 +268,7 @@ class CollateFnHubert:
data = torch.zeros(len(batch), audio_size)
for i in range(len(waveforms)):
data[i][0:waveforms[i].shape[1]] = waveforms[i][0]
data[i][0 : waveforms[i].shape[1]] = waveforms[i][0]
lengths = torch.tensor(lengths)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
return data, labels, lengths
......@@ -318,16 +302,10 @@ class CollateFnHubert:
diff = waveform.size(1) - audio_size
audio_start = torch.randint(diff, size=(1,)) if rand_crop else 0
label_start = torch.div(
audio_start - kernel_size * sample_rate,
stride * sample_rate,
rounding_mode='floor'
)
label_size = torch.div(
audio_size - kernel_size * sample_rate,
stride * sample_rate,
rounding_mode='floor'
audio_start - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor"
)
waveform = waveform[:, audio_start:audio_start + audio_size]
label = label[label_start:label_start + label_size]
label_size = torch.div(audio_size - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor")
waveform = waveform[:, audio_start : audio_start + audio_size]
label = label[label_start : label_start + label_size]
length = audio_size
return waveform, label, length
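
The two floor divisions above map a random audio crop onto its corresponding label frames. Plugging in assumed values (a 25 ms window and 20 ms hop at 16 kHz, so `kernel_size * sample_rate == 400` samples and `stride * sample_rate == 320`; none of these constants appear in this hunk):

```python
import torch

kernel_size, stride, sample_rate = 25, 20, 16  # assumed: ms, ms, kHz
audio_start = torch.tensor(800)   # crop begins 800 samples in
audio_size = torch.tensor(3200)   # crop spans 3200 samples

label_start = torch.div(audio_start - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor")
label_size = torch.div(audio_size - kernel_size * sample_rate, stride * sample_rate, rounding_mode="floor")

print(label_start.item(), label_size.item())  # 1 8 -> label[1:9] pairs with the cropped waveform
```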
......@@ -21,9 +21,7 @@ from utils import (
def _init_logger(debug=False):
message_fmt = (
"%(levelname)5s: %(funcName)10s: %(message)s" if debug else "%(message)s"
)
message_fmt = "%(levelname)5s: %(funcName)10s: %(message)s" if debug else "%(message)s"
logging.basicConfig(
level=logging.DEBUG if debug else logging.INFO,
format=f"%(asctime)s: {message_fmt}",
......@@ -84,7 +82,8 @@ def main(args):
for split in ["train", "valid"]:
p = Pool(args.num_rank)
inputs = [(
inputs = [
(
tsv_dir / f"{args.dataset}_{split}.tsv",
feat_dir,
split,
......@@ -92,7 +91,8 @@ def main(args):
args.num_rank,
device,
args.feat_type,
16_000,)
16_000,
)
for rank in range(args.num_rank)
]
_ = p.starmap(dump_features, inputs)
......
......@@ -37,7 +37,7 @@ def create_tsv(
[``librispeech``, ``libri-light``]. (Default: ``librispeech``)
valid_percent (float, optional): The percentage of data for validation. (Default: 0.01)
seed (int): The seed for randomly selecting the validation files.
extension (str, optional): The extention of audio files. (Default: ``flac``)
extension (str, optional): The extension of audio files. (Default: ``flac``)
Returns:
None
......@@ -51,11 +51,7 @@ def create_tsv(
if not out_dir.exists():
out_dir.mkdir()
valid_f = (
open(out_dir / f"{dataset}_valid.tsv", "w")
if valid_percent > 0
else None
)
valid_f = open(out_dir / f"{dataset}_valid.tsv", "w") if valid_percent > 0 else None
search_pattern = ".*train.*"
with open(out_dir / f"{dataset}_train.tsv", "w") as train_f:
print(root_dir, file=train_f)
......@@ -67,20 +63,13 @@ def create_tsv(
if re.match(search_pattern, str(fname)):
frames = torchaudio.info(fname).num_frames
dest = train_f if torch.rand(1) > valid_percent else valid_f
print(
f"{fname.relative_to(root_dir)}\t{frames}", file=dest
)
print(f"{fname.relative_to(root_dir)}\t{frames}", file=dest)
if valid_f is not None:
valid_f.close()
_LG.info("Finished creating the file lists successfully")
def _get_feat_lens_paths(
feat_dir: Path,
split: str,
rank: int,
num_rank: int
) -> Tuple[Path, Path]:
def _get_feat_lens_paths(feat_dir: Path, split: str, rank: int, num_rank: int) -> Tuple[Path, Path]:
r"""Get the feature and lengths paths based on feature directory,
data split, rank, and number of ranks.
Args:
......@@ -99,9 +88,7 @@ def _get_feat_lens_paths(
return feat_path, len_path
def _get_model_path(
km_dir: Path
) -> Path:
def _get_model_path(km_dir: Path) -> Path:
r"""Get the file path of the KMeans clustering model
Args:
km_dir (Path): The directory to store the KMeans clustering model.
......
......@@ -19,11 +19,7 @@ from .common_utils import _get_feat_lens_paths
_LG = logging.getLogger(__name__)
def get_shard_range(
num_lines: int,
num_rank: int,
rank: int
) -> Tuple[int, int]:
def get_shard_range(num_lines: int, num_rank: int, rank: int) -> Tuple[int, int]:
r"""Get the range of indices for the current rank in multi-processing.
Args:
num_lines (int): The number of lines to process.
......@@ -39,10 +35,7 @@ def get_shard_range(
assert num_lines > 0, f"Found {num_lines} files, make sure you specify the correct root directory"
start = round(num_lines / num_rank * rank)
end = round(num_lines / num_rank * (rank + 1))
_LG.info(
f"rank {rank} of {num_rank}, process {end-start} "
f"({start}-{end}) out of {num_lines}"
)
_LG.info(f"rank {rank} of {num_rank}, process {end-start} " f"({start}-{end}) out of {num_lines}")
return start, end
......@@ -68,9 +61,7 @@ def extract_feature(
waveform = waveform[0].to(device)
if feature_type == "mfcc":
feature_extractor = torchaudio.transforms.MFCC(
sample_rate=sample_rate,
n_mfcc=13,
melkwargs={'n_fft': 400, 'hop_length': 160, 'center': False}
sample_rate=sample_rate, n_mfcc=13, melkwargs={"n_fft": 400, "hop_length": 160, "center": False}
).to(device)
mfccs = feature_extractor(waveform) # (freq, time)
# mfccs = torchaudio.compliance.kaldi.mfcc(
......
......@@ -126,11 +126,7 @@ class ApplyKmeans(object):
self.Cnorm = torch.from_numpy(self.Cnorm_np).to(device)
def __call__(self, x):
dist = (
x.pow(2).sum(1, keepdim=True)
- 2 * torch.matmul(x, self.C)
+ self.Cnorm
)
dist = x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
return dist.argmin(dim=1).cpu().numpy()
......@@ -171,7 +167,7 @@ def get_km_label(
assert feats.shape[0] == lens.sum()
with open(label_path, "w") as f:
for i in range(lens.shape[0]):
feat = feats[offset:offset + lens[i]].to(device)
feat = feats[offset : offset + lens[i]].to(device)
offset += lens[i]
label = apply_kmeans(feat).tolist()
f.write(" ".join(map(str, label)) + "\n")
......
from . import utils, vad
__all__ = ['utils', 'vad']
__all__ = ["utils", "vad"]
......@@ -13,7 +13,6 @@ import datetime as dt
import logging
from fairseq import options
from interactive_asr.utils import add_asr_eval_argument, setup_asr, get_microphone_transcription, transcribe_file
......@@ -29,11 +28,7 @@ def main(args):
print("transcription_time:", transcription_time)
else:
for transcription in get_microphone_transcription(args, task, generator, models, sp, tgt_dict):
print(
"{}: {}".format(
dt.datetime.now().strftime("%H:%M:%S"), transcription[0][0]
)
)
print("{}: {}".format(dt.datetime.now().strftime("%H:%M:%S"), transcription[0][0]))
def cli_main():
......
......@@ -9,13 +9,11 @@ import os
import sys
import time
import sentencepiece as spm
import torch
import torchaudio
import sentencepiece as spm
from fairseq import tasks
from fairseq.utils import load_ensemble_for_inference, import_user_module
from interactive_asr.vad import get_microphone_chunks
......@@ -24,9 +22,7 @@ def add_asr_eval_argument(parser):
parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
parser.add_argument("--kspmodel", default=None, help="sentence piece model")
parser.add_argument(
"--wfstlm", default=None, help="wfstlm on dictonary output units"
)
parser.add_argument("--wfstlm", default=None, help="wfstlm on dictonary output units")
parser.add_argument(
"--rnnt_decoding_type",
default="greedy",
......@@ -37,20 +33,14 @@ def add_asr_eval_argument(parser):
default=0.2,
help="weight for wfstlm while interpolating with neural score",
)
parser.add_argument(
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
)
parser.add_argument("--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level")
return parser
def check_args(args):
assert args.path is not None, "--path required for generation!"
assert (
not args.sampling or args.nbest == args.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
args.replace_unk is None or args.raw_text
), "--replace-unk requires a raw text dataset (--raw-text)"
assert not args.sampling or args.nbest == args.beam, "--sampling requires --nbest to be equal to --beam"
assert args.replace_unk is None or args.raw_text, "--replace-unk requires a raw text dataset (--raw-text)"
def process_predictions(args, hypos, sp, tgt_dict):
......@@ -64,8 +54,7 @@ def process_predictions(args, hypos, sp, tgt_dict):
def optimize_models(args, use_cuda, models):
"""Optimize ensemble for generation
"""
"""Optimize ensemble for generation"""
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
......@@ -166,24 +155,16 @@ def transcribe_file(args, task, generator, models, sp, tgt_dict):
raise FileNotFoundError("Audio file not found: {}".format(path))
waveform, sample_rate = torchaudio.load_wav(path)
waveform = waveform.mean(0, True)
waveform = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000
)(waveform)
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
start = time.time()
transcription = transcribe(
waveform, args, task, generator, models, sp, tgt_dict
)
transcription = transcribe(waveform, args, task, generator, models, sp, tgt_dict)
transcription_time = time.time() - start
return transcription_time, transcription
def get_microphone_transcription(args, task, generator, models, sp, tgt_dict):
for (waveform, sample_rate) in get_microphone_chunks():
waveform = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000
)(waveform.reshape(1, -1))
transcription = transcribe(
waveform, args, task, generator, models, sp, tgt_dict
)
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform.reshape(1, -1))
transcription = transcribe(waveform, args, task, generator, models, sp, tgt_dict)
yield transcription
......@@ -17,14 +17,13 @@ speech sequences. In the online case here, inertia is added before switching
from speech to silence or vice versa.
"""
from collections import deque
import numpy as np
import torch
import queue
from collections import deque
import librosa
import numpy as np
import pyaudio
import torch
import torchaudio
......@@ -95,9 +94,7 @@ class VoiceActivityDetection:
elif self.n < self.num_init_frames:
self.min_energy = min(energy, self.min_energy)
self.min_frequency = min(frequency, self.min_frequency)
self.min_spectral_flatness = min(
spectral_flatness, self.min_spectral_flatness
)
self.min_spectral_flatness = min(spectral_flatness, self.min_spectral_flatness)
self.n += 1
......@@ -121,10 +118,7 @@ class VoiceActivityDetection:
# Speech detected
self.speech_count += 1
# Inertia against switching
if (
self.n >= self.num_init_frames
and self.speech_count <= self.ignore_speech_count
):
if self.n >= self.num_init_frames and self.speech_count <= self.ignore_speech_count:
# Too soon to change
return self.silence_mark
else:
......@@ -132,15 +126,10 @@ class VoiceActivityDetection:
return self.speech_mark
else:
# Silence detected
self.min_energy = ((self.silent_count * self.min_energy) + energy) / (
self.silent_count + 1
)
self.min_energy = ((self.silent_count * self.min_energy) + energy) / (self.silent_count + 1)
self.silent_count += 1
# Inertia against switching
if (
self.n >= self.num_init_frames
and self.silent_count <= self.ignore_silent_count
):
if self.n >= self.num_init_frames and self.silent_count <= self.ignore_silent_count:
# Too soon to change
return self.speech_mark
else:
......@@ -260,9 +249,7 @@ def get_microphone_chunks(
else:
precumulated.append(chunk)
if (not is_speech and len(cumulated) >= min_to_cumulate) or (
len(cumulated) > max_to_cumulate
):
if (not is_speech and len(cumulated) >= min_to_cumulate) or (len(cumulated) > max_to_cumulate):
waveform = torch.cat(list(precumulated) + cumulated, -1)
yield (waveform * stream._rate, stream._rate)
cumulated = []
......
......@@ -2,8 +2,8 @@
"""
Create a data preprocessing pipeline that can be run with libtorchaudio
"""
import os
import argparse
import os
import torch
import torchaudio
......@@ -14,10 +14,11 @@ class Pipeline(torch.nn.Module):
This example loads a waveform from a file, then applies effects and saves the result to a file.
"""
def __init__(self, rir_path: str):
super().__init__()
rir, sample_rate = torchaudio.load(rir_path)
self.register_buffer('rir', rir)
self.register_buffer("rir", rir)
self.rir_sample_rate: int = sample_rate
def forward(self, input_path: str, output_path: str):
......@@ -32,7 +33,8 @@ class Pipeline(torch.nn.Module):
# 3. Resample the RIR filter to match the audio sample rate
rir, _ = torchaudio.sox_effects.apply_effects_tensor(
self.rir, self.rir_sample_rate, effects=[["rate", str(sample_rate)]])
self.rir, self.rir_sample_rate, effects=[["rate", str(sample_rate)]]
)
rir = rir / torch.norm(rir, p=2)
rir = torch.flip(rir, [1])
......@@ -62,15 +64,9 @@ def _get_path(*paths):
def _parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--rir-path",
default=_get_path("..", "data", "rir.wav"),
help="Audio dara for room impulse response."
)
parser.add_argument(
"--output-path",
default=_get_path("pipeline.zip"),
help="Output JIT file."
"--rir-path", default=_get_path("..", "data", "rir.wav"), help="Audio dara for room impulse response."
)
parser.add_argument("--output-path", default=_get_path("pipeline.zip"), help="Output JIT file.")
return parser.parse_args()
......@@ -79,5 +75,5 @@ def _main():
_create_jit_pipeline(args.rir_path, args.output_path)
if __name__ == '__main__':
if __name__ == "__main__":
_main()
......@@ -3,18 +3,17 @@
To use this script, you need `fairseq`.
"""
import os
import argparse
import logging
import os
from typing import Tuple
import fairseq
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
import torchaudio
from torchaudio.models.wav2vec2.utils.import_fairseq import import_fairseq_model
import fairseq
from greedy_decoder import Decoder
from torch.utils.mobile_optimizer import optimize_for_mobile
from torchaudio.models.wav2vec2.utils.import_fairseq import import_fairseq_model
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 10):
......@@ -29,44 +28,31 @@ def _parse_args():
parser = argparse.ArgumentParser(
description=__doc__,
)
parser.add_argument("--model-file", required=True, help="Path to the input pretrained weight file.")
parser.add_argument(
'--model-file',
required=True,
help='Path to the input pretrained weight file.'
)
parser.add_argument(
'--dict-dir',
"--dict-dir",
help=(
'Path to the directory in which `dict.ltr.txt` file is found. '
'Required only when the model is finetuned.'
)
"Path to the directory in which `dict.ltr.txt` file is found. " "Required only when the model is finetuned."
),
)
parser.add_argument(
'--output-path',
help='Path to the directory, where the TorchScript-ed pipelines are saved.',
"--output-path",
help="Path to the directory, where the TorchScript-ed pipelines are saved.",
)
parser.add_argument(
'--test-file',
help='Path to a test audio file.',
"--test-file",
help="Path to a test audio file.",
)
parser.add_argument(
'--debug',
action='store_true',
"--debug",
action="store_true",
help=(
'When enabled, individual components are separately tested '
'for the numerical compatibility and TorchScript compatibility.'
)
)
parser.add_argument(
'--quantize',
action='store_true',
help='Apply quantization to model.'
)
parser.add_argument(
'--optimize-for-mobile',
action='store_true',
help='Apply optimization for mobile.'
"When enabled, individual components are separately tested "
"for the numerical compatibility and TorchScript compatibility."
),
)
parser.add_argument("--quantize", action="store_true", help="Apply quantization to model.")
parser.add_argument("--optimize-for-mobile", action="store_true", help="Apply optmization for mobile.")
return parser.parse_args()
......@@ -74,7 +60,7 @@ class Loader(torch.nn.Module):
def forward(self, audio_path: str) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(audio_path)
if sample_rate != 16000:
waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.)
waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.0)
return waveform
......@@ -129,11 +115,9 @@ def _get_decoder():
def _load_fairseq_model(input_file, data_dir=None):
overrides = {}
if data_dir:
overrides['data'] = data_dir
overrides["data"] = data_dir
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[input_file], arg_overrides=overrides
)
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([input_file], arg_overrides=overrides)
model = model[0]
return model
......@@ -154,36 +138,32 @@ def _main():
_LG.info(encoder)
if args.quantize:
_LG.info('Quantizing the model')
_LG.info("Quantizing the model")
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
encoder = tq.quantize_dynamic(
encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
encoder = tq.quantize_dynamic(encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
_LG.info(encoder)
# test
if args.test_file:
_LG.info('Testing with %s', args.test_file)
_LG.info("Testing with %s", args.test_file)
waveform = loader(args.test_file)
emission = encoder(waveform)
transcript = decoder(emission)
_LG.info(transcript)
torch.jit.script(loader).save(os.path.join(args.output_path, 'loader.zip'))
torch.jit.script(decoder).save(os.path.join(args.output_path, 'decoder.zip'))
torch.jit.script(loader).save(os.path.join(args.output_path, "loader.zip"))
torch.jit.script(decoder).save(os.path.join(args.output_path, "decoder.zip"))
scripted = torch.jit.script(encoder)
if args.optimize_for_mobile:
scripted = optimize_for_mobile(scripted)
scripted.save(os.path.join(args.output_path, 'encoder.zip'))
scripted.save(os.path.join(args.output_path, "encoder.zip"))
def _init_logging(debug=False):
level = logging.DEBUG if debug else logging.INFO
format_ = (
'%(message)s' if not debug else
'%(asctime)s: %(levelname)7s: %(funcName)10s: %(message)s'
)
format_ = "%(message)s" if not debug else "%(asctime)s: %(levelname)7s: %(funcName)10s: %(message)s"
logging.basicConfig(level=level, format=format_)
if __name__ == '__main__':
if __name__ == "__main__":
_main()
......@@ -6,8 +6,8 @@ from typing import Tuple
import torch
import torchaudio
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
from greedy_decoder import Decoder
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 10):
......@@ -22,31 +22,27 @@ def _parse_args():
parser = argparse.ArgumentParser(
description=__doc__,
)
parser.add_argument("--model", required=True, help="Path to the input pretrained weight file.")
parser.add_argument(
'--model',
required=True,
help='Path to the input pretrained weight file.'
)
parser.add_argument(
'--output-path',
help='Path to the directory, where the Torchscript-ed pipelines are saved.',
"--output-path",
help="Path to the directory, where the Torchscript-ed pipelines are saved.",
)
parser.add_argument(
'--test-file',
help='Path to a test audio file.',
"--test-file",
help="Path to a test audio file.",
)
parser.add_argument(
'--quantize',
action='store_true',
help='Quantize the model.',
"--quantize",
action="store_true",
help="Quantize the model.",
)
parser.add_argument(
'--debug',
action='store_true',
"--debug",
action="store_true",
help=(
'When enabled, individual components are separately tested '
'for the numerical compatibility and TorchScript compatibility.'
)
"When enabled, individual components are separately tested "
"for the numerical compatibility and TorchScript compatibility."
),
)
return parser.parse_args()
......@@ -55,7 +51,7 @@ class Loader(torch.nn.Module):
def forward(self, audio_path: str) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(audio_path)
if sample_rate != 16000:
waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.)
waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.0)
return waveform
......@@ -71,6 +67,7 @@ class Encoder(torch.nn.Module):
def _get_model(model_id):
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
tokenizer = Wav2Vec2Processor.from_pretrained(model_id).tokenizer
labels = [k for k, v in sorted(tokenizer.get_vocab().items(), key=lambda kv: kv[1])]
original = Wav2Vec2ForCTC.from_pretrained(model_id)
......@@ -85,43 +82,39 @@ def _get_decoder(labels):
def _main():
args = _parse_args()
_init_logging(args.debug)
_LG.info('Loading model: %s', args.model)
_LG.info("Loading model: %s", args.model)
model, labels = _get_model(args.model)
_LG.info('Labels: %s', labels)
_LG.info('Building pipeline')
_LG.info("Labels: %s", labels)
_LG.info("Building pipeline")
loader = Loader()
encoder = Encoder(model)
decoder = _get_decoder(labels)
_LG.info(encoder)
if args.quantize:
_LG.info('Quantizing the model')
_LG.info("Quantizing the model")
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
encoder = tq.quantize_dynamic(
encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
encoder = tq.quantize_dynamic(encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
_LG.info(encoder)
# test
if args.test_file:
_LG.info('Testing with %s', args.test_file)
_LG.info("Testing with %s", args.test_file)
waveform = loader(args.test_file)
emission = encoder(waveform)
transcript = decoder(emission)
_LG.info(transcript)
torch.jit.script(loader).save(os.path.join(args.output_path, 'loader.zip'))
torch.jit.script(encoder).save(os.path.join(args.output_path, 'encoder.zip'))
torch.jit.script(decoder).save(os.path.join(args.output_path, 'decoder.zip'))
torch.jit.script(loader).save(os.path.join(args.output_path, "loader.zip"))
torch.jit.script(encoder).save(os.path.join(args.output_path, "encoder.zip"))
torch.jit.script(decoder).save(os.path.join(args.output_path, "decoder.zip"))
def _init_logging(debug=False):
level = logging.DEBUG if debug else logging.INFO
format_ = (
'%(message)s' if not debug else
'%(asctime)s: %(levelname)7s: %(funcName)10s: %(message)s'
)
format_ = "%(message)s" if not debug else "%(asctime)s: %(levelname)7s: %(funcName)10s: %(message)s"
logging.basicConfig(level=level, format=format_)
if __name__ == '__main__':
if __name__ == "__main__":
_main()
......@@ -17,12 +17,12 @@ class Decoder(torch.nn.Module):
"""
best_path = torch.argmax(logits, dim=-1) # [num_seq,]
best_path = torch.unique_consecutive(best_path, dim=-1)
hypothesis = ''
hypothesis = ""
for i in best_path:
char = self.labels[i]
if char in ['<s>', '<pad>']:
if char in ["<s>", "<pad>"]:
continue
if char == '|':
char = ' '
if char == "|":
char = " "
hypothesis += char
return hypothesis