Commit 0e0d1e59 authored by David Pollack, committed by Soumith Chintala
Browse files

docs update and fixes from pr comments

parent 8647f903
sphinx sphinx
sphinxcontrib-googleanalytics -e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
-e git://github.com/snide/sphinx_rtd_theme.git#egg=sphinx_rtd_theme sphinxcontrib.katex
matplotlib
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 21.0.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) --> <!-- Generator: Adobe Illustrator 22.1.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" <svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
viewBox="0 0 199.7 40.2" style="enable-background:new 0 0 199.7 40.2;" xml:space="preserve"> viewBox="0 0 199.7 40.2" style="enable-background:new 0 0 199.7 40.2;" xml:space="preserve">
<style type="text/css"> <style type="text/css">
.st0{fill:#F05732;} .st0{fill:#EE4C2C;}
.st1{fill:#9E529F;} .st1{fill:#252525;}
.st2{fill:#333333;}
</style> </style>
<path class="st0" d="M102.7,12.2c-1.3-1-1.8,3.9-4.4,3.9c-3,0-4-13-6.3-13c-0.7,0-0.8-0.4-7.9,21.3c-2.9,9,4.4,15.8,11.8,15.8 <g>
c4.6,0,12.3-3,12.3-12.6C108.2,20.5,104.7,13.7,102.7,12.2z M95.8,35.3c-3.7,0-6.7-3.1-6.7-7c0-3.9,3-7,6.7-7s6.7,3.1,6.7,7 <path class="st0" d="M40.8,9.3l-2.1,2.1c3.5,3.5,3.5,9.2,0,12.7c-3.5,3.5-9.2,3.5-12.7,0c-3.5-3.5-3.5-9.2,0-12.7l0,0l5.6-5.6
C102.5,32.1,99.5,35.3,95.8,35.3z"/> L32.3,5l0,0V0.8l-8.5,8.5c-4.7,4.7-4.7,12.2,0,16.9s12.2,4.7,16.9,0C45.5,21.5,45.5,13.9,40.8,9.3z"/>
<path class="st1" d="M99.8,0c-0.5,0-1.8,2.5-1.8,3.6c0,1.5,1,2,1.8,2c0.8,0,1.8-0.5,1.8-2C101.5,2.5,100.2,0,99.8,0z"/> <circle class="st0" cx="36.6" cy="7.1" r="1.6"/>
<path class="st2" d="M0,39.5V14.9h11.5c5.3,0,8.3,3.6,8.3,7.9c0,4.3-3,7.9-8.3,7.9H5.2v8.8H0z M14.4,22.8c0-2.1-1.6-3.3-3.7-3.3H5.2 </g>
v6.6h5.5C12.8,26.1,14.4,24.8,14.4,22.8z"/> <g>
<path class="st2" d="M35.2,39.5V29.4l-9.4-14.5h6l6.1,9.8l6.1-9.8h5.9l-9.4,14.5v10.1H35.2z"/> <g>
<path class="st2" d="M63.3,39.5v-20h-7.2v-4.6h19.6v4.6h-7.2v20H63.3z"/> <path class="st1" d="M62.6,20l-3.6,0v9.3h-2.7V2.9c0,0,6.3,0,6.6,0c7,0,10.3,3.4,10.3,8.3C73.2,17,69.1,19.9,62.6,20z M62.8,5.4
<path class="st2" d="M131.4,39.5l-4.8-8.7h-3.8v8.7h-5.2V14.9H129c5.1,0,8.3,3.4,8.3,7.9c0,4.3-2.8,6.7-5.4,7.3l5.6,9.4H131.4z c-0.3,0-3.9,0-3.9,0v12.1l3.8-0.1c5-0.1,7.7-2.1,7.7-6.2C70.4,7.5,67.8,5.4,62.8,5.4z"/>
M131.9,22.8c0-2-1.6-3.3-3.7-3.3h-5.5v6.6h5.5C130.3,26.1,131.9,24.9,131.9,22.8z"/> <path class="st1" d="M85.4,29.2l-1.6,4.2c-1.8,4.7-3.6,6.1-6.3,6.1c-1.5,0-2.6-0.4-3.8-0.9l0.8-2.4c0.9,0.5,1.9,0.8,3,0.8
<path class="st2" d="M145.6,27.2c0-7.6,5.7-12.7,13.1-12.7c5.4,0,8.5,2.9,10.3,6l-4.5,2.2c-1-2-3.2-3.6-5.8-3.6 c1.5,0,2.6-0.8,4-4.5l1.3-3.4L75.3,10h2.8l6.1,16l6-16h2.7L85.4,29.2z"/>
c-4.5,0-7.7,3.4-7.7,8.1c0,4.6,3.2,8.1,7.7,8.1c2.5,0,4.7-1.6,5.8-3.6l4.5,2.2c-1.7,3.1-4.9,6-10.3,6 <path class="st1" d="M101.9,5.5v23.9h-2.7V5.5h-9.3V2.9h21.3v2.5H101.9z"/>
C151.3,39.9,145.6,34.7,145.6,27.2z"/> <path class="st1" d="M118.8,29.9c-5.4,0-9.4-4-9.4-10.2c0-6.2,4.1-10.3,9.6-10.3c5.4,0,9.3,4,9.3,10.2
<path class="st2" d="M194.5,39.5V29.1h-11.6v10.4h-5.2V14.9h5.2v9.7h11.6v-9.7h5.3v24.6H194.5z"/> C128.3,25.8,124.2,29.9,118.8,29.9z M118.9,11.8c-4.1,0-6.8,3.3-6.8,7.8c0,4.7,2.8,7.9,6.9,7.9s6.8-3.3,6.8-7.8
C125.8,15,123,11.8,118.9,11.8z"/>
<path class="st1" d="M135,29.4h-2.6V10l2.6-0.5v4.1c1.3-2.5,3.2-4.1,5.7-4.1c1.3,0,2.5,0.4,3.4,0.9l-0.7,2.5
c-0.8-0.5-1.9-0.8-3-0.8c-2,0-3.9,1.5-5.5,5V29.4z"/>
<path class="st1" d="M154.4,29.9c-5.8,0-9.5-4.2-9.5-10.2c0-6.1,4-10.3,9.5-10.3c2.4,0,4.4,0.6,6.1,1.7l-0.7,2.4
c-1.5-1-3.3-1.6-5.4-1.6c-4.2,0-6.8,3.1-6.8,7.7c0,4.7,2.8,7.8,6.9,7.8c1.9,0,3.9-0.6,5.4-1.6l0.5,2.4
C158.7,29.3,156.6,29.9,154.4,29.9z"/>
<path class="st1" d="M176.7,29.4V16.9c0-3.4-1.4-4.9-4.1-4.9c-2.2,0-4.4,1.1-6,2.8v14.7h-2.6V0.9l2.6-0.5c0,0,0,12.1,0,12.2
c2-2,4.6-3.1,6.7-3.1c3.8,0,6.1,2.4,6.1,6.6v13.3H176.7z"/>
</g>
</g>
</svg> </svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:cc="http://creativecommons.org/ns#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg"
height="40.200001"
width="40.200001"
xml:space="preserve"
viewBox="0 0 40.200002 40.2"
y="0px"
x="0px"
id="Layer_1"
version="1.1"><metadata
id="metadata4717"><rdf:RDF><cc:Work
rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
id="defs4715" /><style
id="style4694"
type="text/css">
.st0{fill:#F05732;}
.st1{fill:#9E529F;}
.st2{fill:#333333;}
</style><path
style="fill:#f05732"
id="path4696"
d="m 26.975479,12.199999 c -1.3,-1 -1.8,3.9 -4.4,3.9 -3,0 -4,-12.9999998 -6.3,-12.9999998 -0.7,0 -0.8,-0.4 -7.9000003,21.2999998 -2.9000001,9 4.4000003,15.8 11.8000003,15.8 4.6,0 12.3,-3 12.3,-12.6 0,-7.1 -3.5,-13.9 -5.5,-15.4 z m -6.9,23.1 c -3.7,0 -6.7,-3.1 -6.7,-7 0,-3.9 3,-7 6.7,-7 3.7,0 6.7,3.1 6.7,7 0,3.8 -3,7 -6.7,7 z"
class="st0" /><path
style="fill:#9e529f"
id="path4698"
d="m 24.075479,-7.6293945e-7 c -0.5,0 -1.8,2.49999996293945 -1.8,3.59999996293945 0,1.5 1,2 1.8,2 0.8,0 1.8,-0.5 1.8,-2 -0.1,-1.1 -1.4,-3.59999996293945 -1.8,-3.59999996293945 z"
class="st1" /></svg>
\ No newline at end of file
...@@ -22,14 +22,13 @@ ...@@ -22,14 +22,13 @@
# sys.path.insert(0, os.path.abspath('.')) # sys.path.insert(0, os.path.abspath('.'))
import torch import torch
import torchaudio import torchaudio
import sphinx_rtd_theme import pytorch_sphinx_theme
# -- General configuration ------------------------------------------------ # -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here. # If your documentation needs a minimal Sphinx version, state it here.
# #
# needs_sphinx = '1.0' needs_sphinx = '1.6'
# Add any Sphinx extension module names here, as strings. They can be # Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
...@@ -41,16 +40,24 @@ extensions = [ ...@@ -41,16 +40,24 @@ extensions = [
'sphinx.ext.intersphinx', 'sphinx.ext.intersphinx',
'sphinx.ext.todo', 'sphinx.ext.todo',
'sphinx.ext.coverage', 'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
'sphinx.ext.viewcode', 'sphinx.ext.viewcode',
'sphinxcontrib.googleanalytics', 'sphinxcontrib.katex',
] ]
napoleon_use_ivar = True # katex options
#
#
katex_options = r'''
delimiters : [
{left: "$$", right: "$$", display: true},
{left: "\\(", right: "\\)", display: false},
{left: "\\[", right: "\\]", display: true}
]
'''
googleanalytics_id = 'UA-90545585-1' napoleon_use_ivar = True
googleanalytics_enabled = True
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ['_templates']
...@@ -66,8 +73,8 @@ master_doc = 'index' ...@@ -66,8 +73,8 @@ master_doc = 'index'
# General information about the project. # General information about the project.
project = 'Torchaudio' project = 'Torchaudio'
copyright = '2017, Torch Contributors' copyright = '2018, Torchaudio Contributors'
author = 'Torch Contributors' author = 'Torchaudio Contributors'
# The version info for the project you're documenting, acts as replacement for # The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the # |version| and |release|, also used in various other places throughout the
...@@ -104,14 +111,15 @@ todo_include_todos = True ...@@ -104,14 +111,15 @@ todo_include_todos = True
# The theme to use for HTML and HTML Help pages. See the documentation for # The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes. # a list of builtin themes.
# #
html_theme = 'sphinx_rtd_theme' html_theme = 'pytorch_sphinx_theme'
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# Theme options are theme-specific and customize the look and feel of a theme # Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the # further. For a list of options available for each theme, see the
# documentation. # documentation.
# #
html_theme_options = { html_theme_options = {
'pytorch_project': 'audio',
'collapse_navigation': False, 'collapse_navigation': False,
'display_version': True, 'display_version': True,
'logo_only': True, 'logo_only': True,
...@@ -124,19 +132,25 @@ html_logo = '_static/img/pytorch-logo-dark.svg' ...@@ -124,19 +132,25 @@ html_logo = '_static/img/pytorch-logo-dark.svg'
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static'] html_static_path = ['_static']
# html_style_path = 'css/pytorch_theme.css' def setup(app):
html_context = { # NOTE: in Sphinx 1.8+ `html_css_files` is an official configuration value
'css_files': [ # and can be moved outside of this function (and the setup(app) function
'https://fonts.googleapis.com/css?family=Lato', # can be deleted).
'_static/css/pytorch_theme.css' html_css_files = [
], 'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css'
} ]
# In Sphinx 1.8 it was renamed to `add_css_file`, 1.7 and prior it is
# `add_stylesheet` (deprecated in 1.8).
add_css = getattr(app, 'add_css_file', getattr(app, 'add_stylesheet'))
for css_file in html_css_files:
add_css(css_file)
# -- Options for HTMLHelp output ------------------------------------------ # -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = 'PyTorchdoc' htmlhelp_basename = 'TorchAudiodoc'
# -- Options for LaTeX output --------------------------------------------- # -- Options for LaTeX output ---------------------------------------------
...@@ -163,7 +177,7 @@ latex_elements = { ...@@ -163,7 +177,7 @@ latex_elements = {
# (source start file, target name, title, # (source start file, target name, title,
# author, documentclass [howto, manual, or own class]). # author, documentclass [howto, manual, or own class]).
latex_documents = [ latex_documents = [
(master_doc, 'pytorch.tex', 'torchaudio Documentation', (master_doc, 'pytorch.tex', 'Torchaudio Documentation',
'Torch Contributors', 'manual'), 'Torch Contributors', 'manual'),
] ]
...@@ -173,7 +187,7 @@ latex_documents = [ ...@@ -173,7 +187,7 @@ latex_documents = [
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [ man_pages = [
(master_doc, 'torchaudio', 'torchaudio Documentation', (master_doc, 'Torchaudio', 'Torchaudio Documentation',
[author], 1) [author], 1)
] ]
...@@ -184,8 +198,8 @@ man_pages = [ ...@@ -184,8 +198,8 @@ man_pages = [
# (source start file, target name, title, author, # (source start file, target name, title, author,
# dir menu entry, description, category) # dir menu entry, description, category)
texinfo_documents = [ texinfo_documents = [
(master_doc, 'torchaudio', 'torchaudio Documentation', (master_doc, 'Torchaudio', 'Torchaudio Documentation',
author, 'torchaudio', 'One line description of project.', author, 'Torchaudio', 'Load audio files into pytorch tensors.',
'Miscellaneous'), 'Miscellaneous'),
] ]
...@@ -193,7 +207,7 @@ texinfo_documents = [ ...@@ -193,7 +207,7 @@ texinfo_documents = [
# Example configuration for intersphinx: refer to the Python standard library. # Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = { intersphinx_mapping = {
'python': ('https://docs.python.org/', None), 'python': ('https://docs.python.org/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None), 'numpy': ('https://docs.scipy.org/doc/numpy/', None),
} }
# -- A patch that prevents Sphinx from cross-referencing ivar tags ------- # -- A patch that prevents Sphinx from cross-referencing ivar tags -------
...@@ -246,5 +260,4 @@ def patched_make_field(self, types, domain, items, **kw): ...@@ -246,5 +260,4 @@ def patched_make_field(self, types, domain, items, **kw):
fieldbody = nodes.field_body('', bodynode) fieldbody = nodes.field_body('', bodynode)
return nodes.field('', fieldname, fieldbody) return nodes.field('', fieldname, fieldbody)
TypedField.make_field = patched_make_field TypedField.make_field = patched_make_field
...@@ -18,7 +18,7 @@ class TORCHAUDIODS(Dataset): ...@@ -18,7 +18,7 @@ class TORCHAUDIODS(Dataset):
self.si.precision = 16 self.si.precision = 16
self.E = torchaudio.sox_effects.SoxEffectsChain() self.E = torchaudio.sox_effects.SoxEffectsChain()
self.E.append_effect_to_chain("rate", [self.si.rate]) # resample to 16000hz self.E.append_effect_to_chain("rate", [self.si.rate]) # resample to 16000hz
self.E.append_effect_to_chain("channels", [self.si.channels]) # mono singal self.E.append_effect_to_chain("channels", [self.si.channels]) # mono signal
self.E.append_effect_to_chain("trim", [0, "16000s"]) # first 16000 samples of audio self.E.append_effect_to_chain("trim", [0, "16000s"]) # first 16000 samples of audio
def __getitem__(self, index): def __getitem__(self, index):
...@@ -30,7 +30,7 @@ class TORCHAUDIODS(Dataset): ...@@ -30,7 +30,7 @@ class TORCHAUDIODS(Dataset):
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
class Test_LoadSave(unittest.TestCase): class Test_DataLoader(unittest.TestCase):
def test_1(self): def test_1(self):
expected_size = (2, 1, 16000) expected_size = (2, 1, 16000)
ds = TORCHAUDIODS() ds = TORCHAUDIODS()
......
...@@ -20,15 +20,13 @@ class Tester(unittest.TestCase): ...@@ -20,15 +20,13 @@ class Tester(unittest.TestCase):
audio_orig = self.sig.clone() audio_orig = self.sig.clone()
result = transforms.Scale()(audio_orig) result = transforms.Scale()(audio_orig)
self.assertTrue(result.min() >= -1. and result.max() <= 1., self.assertTrue(result.min() >= -1. and result.max() <= 1.)
print("min: {}, max: {}".format(result.min(), result.max())))
maxminmax = np.abs( maxminmax = np.abs(
[audio_orig.min(), audio_orig.max()]).max().astype(np.float) [audio_orig.min(), audio_orig.max()]).max().astype(np.float)
result = transforms.Scale(factor=maxminmax)(audio_orig) result = transforms.Scale(factor=maxminmax)(audio_orig)
self.assertTrue((result.min() == -1. or result.max() == 1.) and self.assertTrue((result.min() == -1. or result.max() == 1.) and
result.min() >= -1. and result.max() <= 1., result.min() >= -1. and result.max() <= 1.)
print("min: {}, max: {}".format(result.min(), result.max())))
repr_test = transforms.Scale() repr_test = transforms.Scale()
repr_test.__repr__() repr_test.__repr__()
...@@ -39,21 +37,19 @@ class Tester(unittest.TestCase): ...@@ -39,21 +37,19 @@ class Tester(unittest.TestCase):
length_orig = audio_orig.size(0) length_orig = audio_orig.size(0)
length_new = int(length_orig * 1.2) length_new = int(length_orig * 1.2)
result = transforms.PadTrim(max_len=length_new)(audio_orig) result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
self.assertTrue(result.size(0) == length_new, self.assertEqual(result.size(0), length_new)
print("old size: {}, new size: {}".format(audio_orig.size(0), result.size(0))))
audio_orig = self.sig.clone() audio_orig = self.sig.clone()
length_orig = audio_orig.size(0) length_orig = audio_orig.size(0)
length_new = int(length_orig * 0.8) length_new = int(length_orig * 0.8)
result = transforms.PadTrim(max_len=length_new)(audio_orig) result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
self.assertTrue(result.size(0) == length_new, self.assertEqual(result.size(0), length_new)
print("old size: {}, new size: {}".format(audio_orig.size(0), result.size(0))))
repr_test = transforms.PadTrim(max_len=length_new) repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
repr_test.__repr__() repr_test.__repr__()
def test_downmix_mono(self): def test_downmix_mono(self):
...@@ -67,11 +63,11 @@ class Tester(unittest.TestCase): ...@@ -67,11 +63,11 @@ class Tester(unittest.TestCase):
self.assertTrue(audio_Stereo.size(1) == 2) self.assertTrue(audio_Stereo.size(1) == 2)
result = transforms.DownmixMono()(audio_Stereo) result = transforms.DownmixMono(channels_first=False)(audio_Stereo)
self.assertTrue(result.size(1) == 1) self.assertTrue(result.size(1) == 1)
repr_test = transforms.DownmixMono() repr_test = transforms.DownmixMono(channels_first=False)
repr_test.__repr__() repr_test.__repr__()
def test_lc2cl(self): def test_lc2cl(self):
...@@ -107,7 +103,7 @@ class Tester(unittest.TestCase): ...@@ -107,7 +103,7 @@ class Tester(unittest.TestCase):
[audio_orig.min(), audio_orig.max()]).max().astype(np.float) [audio_orig.min(), audio_orig.max()]).max().astype(np.float)
tset = (transforms.Scale(factor=maxminmax), tset = (transforms.Scale(factor=maxminmax),
transforms.PadTrim(max_len=length_new)) transforms.PadTrim(max_len=length_new, channels_first=False))
result = transforms.Compose(tset)(audio_orig) result = transforms.Compose(tset)(audio_orig)
self.assertTrue(np.abs([result.min(), result.max()]).max() == 1.) self.assertTrue(np.abs([result.min(), result.max()]).max() == 1.)
......
...@@ -34,17 +34,18 @@ def load(filepath, ...@@ -34,17 +34,18 @@ def load(filepath,
If `callable`, then the output is passed as a parameter If `callable`, then the output is passed as a parameter
to the given function, then the output is divided by to the given function, then the output is divided by
the result. the result.
channels_first (bool): Set channels first or length first in result. Default: ``True``
num_frames (int, optional): number of frames to load. 0 to load everything after the offset. num_frames (int, optional): number of frames to load. 0 to load everything after the offset.
offset (int, optional): number of frames from the start of the file to begin data loading. offset (int, optional): number of frames from the start of the file to begin data loading.
signalinfo (sox_signalinfo_t, optional): a sox_signalinfo_t type, which could be helpful if the signalinfo (sox_signalinfo_t, optional): a sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determine audio type cannot be automatically determined
encodinginfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the encodinginfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined audio type cannot be automatically determined
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
Returns: tuple(Tensor, int) Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames, C is the number of channels - Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames, C is the number of channels
- int: the sample-rate of the audio (as listed in the metadata of the file) - int: the sample rate of the audio (as listed in the metadata of the file)
Example:: Example::
...@@ -113,8 +114,9 @@ def save_encinfo(filepath, ...@@ -113,8 +114,9 @@ def save_encinfo(filepath,
filepath (string): path to audio file filepath (string): path to audio file
src (Tensor): an input 2D Tensor of shape `[C x L]` or `[L x C]` where L is src (Tensor): an input 2D Tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels the number of audio frames, C is the number of channels
channels_first (bool): Set channels first or length first in result. Default: ``True``
signalinfo (sox_signalinfo_t): a sox_signalinfo_t type, which could be helpful if the signalinfo (sox_signalinfo_t): a sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determine audio type cannot be automatically determined
encodinginfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the encodinginfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined audio type cannot be automatically determined
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
......
...@@ -28,6 +28,48 @@ def SoxEffect(): ...@@ -28,6 +28,48 @@ def SoxEffect():
class SoxEffectsChain(object): class SoxEffectsChain(object):
"""SoX effects chain class. """SoX effects chain class.
Args:
normalization (bool, number, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
(assumes signed 32-bit audio), and normalizes to `[-1, 1]`.
If `number`, then output is divided by that number
If `callable`, then the output is passed as a parameter
to the given function, then the output is divided by
the result.
channels_first (bool, optional): Set channels first or length first in result. Default: ``True``
out_siginfo (sox_signalinfo_t, optional): a sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined
out_encinfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames, C is the number of channels
- int: the sample rate of the audio (as listed in the metadata of the file)
Example::
class MyDataset(Dataset):
def __init__(self, audiodir_path):
self.data = [fn for fn in os.listdir(audiodir_path)]
self.E = torchaudio.sox_effects.SoxEffectsChain()
self.E.append_effect_to_chain("rate", [16000]) # resample to 16000hz
self.E.append_effect_to_chain("channels", ["1"]) # mono signal
def __getitem__(self, index):
fn = self.data[index]
self.E.set_input_file(fn)
x, sr = self.E.sox_build_flow_effects()
return x, sr
def __len__(self):
return len(self.data)
>>> torchaudio.initialize_sox()
>>> ds = MyDataset(path_to_audio_files)
>>> for sig, sr in ds:
>>> [do something here]
>>> torchaudio.shutdown_sox()
""" """
EFFECTS_AVAILABLE = set(effect_names()) EFFECTS_AVAILABLE = set(effect_names())
......
...@@ -20,12 +20,6 @@ def _check_is_variable(tensor): ...@@ -20,12 +20,6 @@ def _check_is_variable(tensor):
return tensor, is_variable return tensor, is_variable
def _tlog10(x):
"""Pytorch Log10
"""
return torch.log(x) / torch.log(x.new([10]))
class Compose(object): class Compose(object):
"""Composes several transforms together. """Composes several transforms together.
...@@ -92,29 +86,35 @@ class PadTrim(object): ...@@ -92,29 +86,35 @@ class PadTrim(object):
"""Pad/Trim a 1d-Tensor (Signal or Labels) """Pad/Trim a 1d-Tensor (Signal or Labels)
Args: Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels) tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
max_len (int): Length to which the tensor will be padded max_len (int): Length to which the tensor will be padded
channels_first (bool): Pad for channels first tensors. Default: `True`
""" """
def __init__(self, max_len, fill_value=0): def __init__(self, max_len, fill_value=0, channels_first=True):
self.max_len = max_len self.max_len = max_len
self.fill_value = fill_value self.fill_value = fill_value
self.len_dim, self.ch_dim = int(channels_first), int(not channels_first)
def __call__(self, tensor): def __call__(self, tensor):
""" """
Returns: Returns:
Tensor: (max_len x Channels) Tensor: (c x n) or (n x c)
""" """
if self.max_len > tensor.size(0): assert tensor.size(self.ch_dim) < 128, \
pad = torch.ones((self.max_len - tensor.size(0), "Too many channels ({}) detected, look at channels_first param.".format(tensor.size(self.ch_dim))
tensor.size(1))) * self.fill_value if self.max_len > tensor.size(self.len_dim):
pad = pad.type_as(tensor)
tensor = torch.cat((tensor, pad), dim=0) padding_size = [self.max_len - tensor.size(self.len_dim) if i == self.len_dim
elif self.max_len < tensor.size(0): else tensor.size(self.ch_dim)
tensor = tensor[:self.max_len, :] for i in range(2)]
pad = torch.empty(padding_size, dtype=tensor.dtype).fill_(self.fill_value)
tensor = torch.cat((tensor, pad), dim=self.len_dim)
elif self.max_len < tensor.size(self.len_dim):
tensor = tensor.narrow(self.len_dim, 0, self.max_len)
return tensor return tensor
def __repr__(self): def __repr__(self):
...@@ -122,25 +122,26 @@ class PadTrim(object): ...@@ -122,25 +122,26 @@ class PadTrim(object):
class DownmixMono(object): class DownmixMono(object):
"""Downmix any stereo signals to mono """Downmix any stereo signals to mono. Consider using a `SoxEffectsChain` with
the `channels` effect instead of this transformation.
Inputs: Inputs:
tensor (Tensor): Tensor of audio of size (Samples x Channels) tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
channels_first (bool): Downmix across channels dimension. Default: `None` (falsy, so the mean is taken over dim 1, i.e. the input is treated as (n x c))
Returns: Returns:
tensor (Tensor) (Samples x 1): tensor (Tensor) (Samples x 1):
""" """
def __init__(self): def __init__(self, channels_first=None):
pass self.ch_dim = int(not channels_first)
def __call__(self, tensor): def __call__(self, tensor):
if isinstance(tensor, (torch.LongTensor, torch.IntTensor)): if isinstance(tensor, (torch.LongTensor, torch.IntTensor)):
tensor = tensor.float() tensor = tensor.float()
if tensor.size(1) > 1: tensor = torch.mean(tensor, self.ch_dim, True)
tensor = torch.mean(tensor.float(), 1, True)
return tensor return tensor
def __repr__(self): def __repr__(self):
...@@ -148,8 +149,7 @@ class DownmixMono(object): ...@@ -148,8 +149,7 @@ class DownmixMono(object):
class LC2CL(object): class LC2CL(object):
"""Permute a 2d tensor from samples (Length) x Channels to Channels x """Permute a 2d tensor from samples (n x c) to (c x n)
samples (Length)
""" """
def __call__(self, tensor): def __call__(self, tensor):
...@@ -162,7 +162,6 @@ class LC2CL(object): ...@@ -162,7 +162,6 @@ class LC2CL(object):
tensor (Tensor): Tensor of audio signal with shape (CxL) tensor (Tensor): Tensor of audio signal with shape (CxL)
""" """
return tensor.transpose(0, 1).contiguous() return tensor.transpose(0, 1).contiguous()
def __repr__(self): def __repr__(self):
...@@ -292,7 +291,7 @@ class SPEC2DB(object): ...@@ -292,7 +291,7 @@ class SPEC2DB(object):
def __call__(self, spec): def __call__(self, spec):
spec, is_variable = _check_is_variable(spec) spec, is_variable = _check_is_variable(spec)
spec_db = self.multiplier * _tlog10(spec / spec.max()) # power -> dB spec_db = self.multiplier * torch.log10(spec / spec.max()) # power -> dB
if self.top_db is not None: if self.top_db is not None:
spec_db = torch.max(spec_db, spec_db.new([self.top_db])) spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
return spec_db if is_variable else spec_db.data return spec_db if is_variable else spec_db.data
...@@ -320,7 +319,6 @@ class MEL2(object): ...@@ -320,7 +319,6 @@ class MEL2(object):
Example: Example:
>>> sig, sr = torchaudio.load("test.wav", normalization=True) >>> sig, sr = torchaudio.load("test.wav", normalization=True)
>>> sig = transforms.LC2CL()(sig) # (n, c) -> (c, n)
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m) >>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
""" """
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
...@@ -406,8 +404,8 @@ class MEL(object): ...@@ -406,8 +404,8 @@ class MEL(object):
class BLC2CBL(object): class BLC2CBL(object):
"""Permute a 3d tensor from Bands x samples (Length) x Channels to Channels x """Permute a 3d tensor from Bands x Sample length x Channels to Channels x
Bands x samples (Length) Bands x Samples length
""" """
def __call__(self, tensor): def __call__(self, tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment