Commit 0e0d1e59 authored by David Pollack's avatar David Pollack Committed by Soumith Chintala
Browse files

docs update and fixes from pr comments

parent 8647f903
sphinx
sphinxcontrib-googleanalytics
-e git://github.com/snide/sphinx_rtd_theme.git#egg=sphinx_rtd_theme
-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinxcontrib.katex
matplotlib
<?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 21.0.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<!-- Generator: Adobe Illustrator 22.1.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
viewBox="0 0 199.7 40.2" style="enable-background:new 0 0 199.7 40.2;" xml:space="preserve">
<style type="text/css">
.st0{fill:#F05732;}
.st1{fill:#9E529F;}
.st2{fill:#333333;}
.st0{fill:#EE4C2C;}
.st1{fill:#252525;}
</style>
<path class="st0" d="M102.7,12.2c-1.3-1-1.8,3.9-4.4,3.9c-3,0-4-13-6.3-13c-0.7,0-0.8-0.4-7.9,21.3c-2.9,9,4.4,15.8,11.8,15.8
c4.6,0,12.3-3,12.3-12.6C108.2,20.5,104.7,13.7,102.7,12.2z M95.8,35.3c-3.7,0-6.7-3.1-6.7-7c0-3.9,3-7,6.7-7s6.7,3.1,6.7,7
C102.5,32.1,99.5,35.3,95.8,35.3z"/>
<path class="st1" d="M99.8,0c-0.5,0-1.8,2.5-1.8,3.6c0,1.5,1,2,1.8,2c0.8,0,1.8-0.5,1.8-2C101.5,2.5,100.2,0,99.8,0z"/>
<path class="st2" d="M0,39.5V14.9h11.5c5.3,0,8.3,3.6,8.3,7.9c0,4.3-3,7.9-8.3,7.9H5.2v8.8H0z M14.4,22.8c0-2.1-1.6-3.3-3.7-3.3H5.2
v6.6h5.5C12.8,26.1,14.4,24.8,14.4,22.8z"/>
<path class="st2" d="M35.2,39.5V29.4l-9.4-14.5h6l6.1,9.8l6.1-9.8h5.9l-9.4,14.5v10.1H35.2z"/>
<path class="st2" d="M63.3,39.5v-20h-7.2v-4.6h19.6v4.6h-7.2v20H63.3z"/>
<path class="st2" d="M131.4,39.5l-4.8-8.7h-3.8v8.7h-5.2V14.9H129c5.1,0,8.3,3.4,8.3,7.9c0,4.3-2.8,6.7-5.4,7.3l5.6,9.4H131.4z
M131.9,22.8c0-2-1.6-3.3-3.7-3.3h-5.5v6.6h5.5C130.3,26.1,131.9,24.9,131.9,22.8z"/>
<path class="st2" d="M145.6,27.2c0-7.6,5.7-12.7,13.1-12.7c5.4,0,8.5,2.9,10.3,6l-4.5,2.2c-1-2-3.2-3.6-5.8-3.6
c-4.5,0-7.7,3.4-7.7,8.1c0,4.6,3.2,8.1,7.7,8.1c2.5,0,4.7-1.6,5.8-3.6l4.5,2.2c-1.7,3.1-4.9,6-10.3,6
C151.3,39.9,145.6,34.7,145.6,27.2z"/>
<path class="st2" d="M194.5,39.5V29.1h-11.6v10.4h-5.2V14.9h5.2v9.7h11.6v-9.7h5.3v24.6H194.5z"/>
<g>
<path class="st0" d="M40.8,9.3l-2.1,2.1c3.5,3.5,3.5,9.2,0,12.7c-3.5,3.5-9.2,3.5-12.7,0c-3.5-3.5-3.5-9.2,0-12.7l0,0l5.6-5.6
L32.3,5l0,0V0.8l-8.5,8.5c-4.7,4.7-4.7,12.2,0,16.9s12.2,4.7,16.9,0C45.5,21.5,45.5,13.9,40.8,9.3z"/>
<circle class="st0" cx="36.6" cy="7.1" r="1.6"/>
</g>
<g>
<g>
<path class="st1" d="M62.6,20l-3.6,0v9.3h-2.7V2.9c0,0,6.3,0,6.6,0c7,0,10.3,3.4,10.3,8.3C73.2,17,69.1,19.9,62.6,20z M62.8,5.4
c-0.3,0-3.9,0-3.9,0v12.1l3.8-0.1c5-0.1,7.7-2.1,7.7-6.2C70.4,7.5,67.8,5.4,62.8,5.4z"/>
<path class="st1" d="M85.4,29.2l-1.6,4.2c-1.8,4.7-3.6,6.1-6.3,6.1c-1.5,0-2.6-0.4-3.8-0.9l0.8-2.4c0.9,0.5,1.9,0.8,3,0.8
c1.5,0,2.6-0.8,4-4.5l1.3-3.4L75.3,10h2.8l6.1,16l6-16h2.7L85.4,29.2z"/>
<path class="st1" d="M101.9,5.5v23.9h-2.7V5.5h-9.3V2.9h21.3v2.5H101.9z"/>
<path class="st1" d="M118.8,29.9c-5.4,0-9.4-4-9.4-10.2c0-6.2,4.1-10.3,9.6-10.3c5.4,0,9.3,4,9.3,10.2
C128.3,25.8,124.2,29.9,118.8,29.9z M118.9,11.8c-4.1,0-6.8,3.3-6.8,7.8c0,4.7,2.8,7.9,6.9,7.9s6.8-3.3,6.8-7.8
C125.8,15,123,11.8,118.9,11.8z"/>
<path class="st1" d="M135,29.4h-2.6V10l2.6-0.5v4.1c1.3-2.5,3.2-4.1,5.7-4.1c1.3,0,2.5,0.4,3.4,0.9l-0.7,2.5
c-0.8-0.5-1.9-0.8-3-0.8c-2,0-3.9,1.5-5.5,5V29.4z"/>
<path class="st1" d="M154.4,29.9c-5.8,0-9.5-4.2-9.5-10.2c0-6.1,4-10.3,9.5-10.3c2.4,0,4.4,0.6,6.1,1.7l-0.7,2.4
c-1.5-1-3.3-1.6-5.4-1.6c-4.2,0-6.8,3.1-6.8,7.7c0,4.7,2.8,7.8,6.9,7.8c1.9,0,3.9-0.6,5.4-1.6l0.5,2.4
C158.7,29.3,156.6,29.9,154.4,29.9z"/>
<path class="st1" d="M176.7,29.4V16.9c0-3.4-1.4-4.9-4.1-4.9c-2.2,0-4.4,1.1-6,2.8v14.7h-2.6V0.9l2.6-0.5c0,0,0,12.1,0,12.2
c2-2,4.6-3.1,6.7-3.1c3.8,0,6.1,2.4,6.1,6.6v13.3H176.7z"/>
</g>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:cc="http://creativecommons.org/ns#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg"
height="40.200001"
width="40.200001"
xml:space="preserve"
viewBox="0 0 40.200002 40.2"
y="0px"
x="0px"
id="Layer_1"
version="1.1"><metadata
id="metadata4717"><rdf:RDF><cc:Work
rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
id="defs4715" /><style
id="style4694"
type="text/css">
.st0{fill:#F05732;}
.st1{fill:#9E529F;}
.st2{fill:#333333;}
</style><path
style="fill:#f05732"
id="path4696"
d="m 26.975479,12.199999 c -1.3,-1 -1.8,3.9 -4.4,3.9 -3,0 -4,-12.9999998 -6.3,-12.9999998 -0.7,0 -0.8,-0.4 -7.9000003,21.2999998 -2.9000001,9 4.4000003,15.8 11.8000003,15.8 4.6,0 12.3,-3 12.3,-12.6 0,-7.1 -3.5,-13.9 -5.5,-15.4 z m -6.9,23.1 c -3.7,0 -6.7,-3.1 -6.7,-7 0,-3.9 3,-7 6.7,-7 3.7,0 6.7,3.1 6.7,7 0,3.8 -3,7 -6.7,7 z"
class="st0" /><path
style="fill:#9e529f"
id="path4698"
d="m 24.075479,-7.6293945e-7 c -0.5,0 -1.8,2.49999996293945 -1.8,3.59999996293945 0,1.5 1,2 1.8,2 0.8,0 1.8,-0.5 1.8,-2 -0.1,-1.1 -1.4,-3.59999996293945 -1.8,-3.59999996293945 z"
class="st1" /></svg>
\ No newline at end of file
......@@ -22,14 +22,13 @@
# sys.path.insert(0, os.path.abspath('.'))
import torch
import torchaudio
import sphinx_rtd_theme
import pytorch_sphinx_theme
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
needs_sphinx = '1.6'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
......@@ -41,16 +40,24 @@ extensions = [
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinxcontrib.googleanalytics',
'sphinxcontrib.katex',
]
napoleon_use_ivar = True
# katex options
#
#
katex_options = r'''
delimiters : [
{left: "$$", right: "$$", display: true},
{left: "\\(", right: "\\)", display: false},
{left: "\\[", right: "\\]", display: true}
]
'''
googleanalytics_id = 'UA-90545585-1'
googleanalytics_enabled = True
napoleon_use_ivar = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
......@@ -66,8 +73,8 @@ master_doc = 'index'
# General information about the project.
project = 'Torchaudio'
copyright = '2017, Torch Contributors'
author = 'Torch Contributors'
copyright = '2018, Torchaudio Contributors'
author = 'Torchaudio Contributors'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
......@@ -104,14 +111,15 @@ todo_include_todos = True
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
html_theme = 'pytorch_sphinx_theme'
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {
'pytorch_project': 'audio',
'collapse_navigation': False,
'display_version': True,
'logo_only': True,
......@@ -124,19 +132,25 @@ html_logo = '_static/img/pytorch-logo-dark.svg'
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# html_style_path = 'css/pytorch_theme.css'
html_context = {
'css_files': [
'https://fonts.googleapis.com/css?family=Lato',
'_static/css/pytorch_theme.css'
],
}
def setup(app):
# NOTE: in Sphinx 1.8+ `html_css_files` is an official configuration value
# and can be moved outside of this function (and the setup(app) function
# can be deleted).
html_css_files = [
'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css'
]
# In Sphinx 1.8 it was renamed to `add_css_file`, 1.7 and prior it is
# `add_stylesheet` (deprecated in 1.8).
add_css = getattr(app, 'add_css_file', getattr(app, 'add_stylesheet'))
for css_file in html_css_files:
add_css(css_file)
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'PyTorchdoc'
htmlhelp_basename = 'TorchAudiodoc'
# -- Options for LaTeX output ---------------------------------------------
......@@ -163,7 +177,7 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'pytorch.tex', 'torchaudio Documentation',
(master_doc, 'pytorch.tex', 'Torchaudio Documentation',
'Torch Contributors', 'manual'),
]
......@@ -173,7 +187,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'torchaudio', 'torchaudio Documentation',
(master_doc, 'Torchaudio', 'Torchaudio Documentation',
[author], 1)
]
......@@ -184,8 +198,8 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'torchaudio', 'torchaudio Documentation',
author, 'torchaudio', 'One line description of project.',
(master_doc, 'Torchaudio', 'Torchaudio Documentation',
author, 'Torchaudio', 'Load audio files into pytorch tensors.',
'Miscellaneous'),
]
......@@ -193,7 +207,7 @@ texinfo_documents = [
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'numpy': ('https://docs.scipy.org/doc/numpy/', None),
}
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
......@@ -246,5 +260,4 @@ def patched_make_field(self, types, domain, items, **kw):
fieldbody = nodes.field_body('', bodynode)
return nodes.field('', fieldname, fieldbody)
TypedField.make_field = patched_make_field
......@@ -18,7 +18,7 @@ class TORCHAUDIODS(Dataset):
self.si.precision = 16
self.E = torchaudio.sox_effects.SoxEffectsChain()
self.E.append_effect_to_chain("rate", [self.si.rate]) # resample to 16000hz
self.E.append_effect_to_chain("channels", [self.si.channels]) # mono singal
self.E.append_effect_to_chain("channels", [self.si.channels]) # mono signal
self.E.append_effect_to_chain("trim", [0, "16000s"]) # first 16000 samples of audio
def __getitem__(self, index):
......@@ -30,7 +30,7 @@ class TORCHAUDIODS(Dataset):
def __len__(self):
return len(self.data)
class Test_LoadSave(unittest.TestCase):
class Test_DataLoader(unittest.TestCase):
def test_1(self):
expected_size = (2, 1, 16000)
ds = TORCHAUDIODS()
......
......@@ -20,15 +20,13 @@ class Tester(unittest.TestCase):
audio_orig = self.sig.clone()
result = transforms.Scale()(audio_orig)
self.assertTrue(result.min() >= -1. and result.max() <= 1.,
print("min: {}, max: {}".format(result.min(), result.max())))
self.assertTrue(result.min() >= -1. and result.max() <= 1.)
maxminmax = np.abs(
[audio_orig.min(), audio_orig.max()]).max().astype(np.float)
result = transforms.Scale(factor=maxminmax)(audio_orig)
self.assertTrue((result.min() == -1. or result.max() == 1.) and
result.min() >= -1. and result.max() <= 1.,
print("min: {}, max: {}".format(result.min(), result.max())))
result.min() >= -1. and result.max() <= 1.)
repr_test = transforms.Scale()
repr_test.__repr__()
......@@ -39,21 +37,19 @@ class Tester(unittest.TestCase):
length_orig = audio_orig.size(0)
length_new = int(length_orig * 1.2)
result = transforms.PadTrim(max_len=length_new)(audio_orig)
result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
self.assertTrue(result.size(0) == length_new,
print("old size: {}, new size: {}".format(audio_orig.size(0), result.size(0))))
self.assertEqual(result.size(0), length_new)
audio_orig = self.sig.clone()
length_orig = audio_orig.size(0)
length_new = int(length_orig * 0.8)
result = transforms.PadTrim(max_len=length_new)(audio_orig)
result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
self.assertTrue(result.size(0) == length_new,
print("old size: {}, new size: {}".format(audio_orig.size(0), result.size(0))))
self.assertEqual(result.size(0), length_new)
repr_test = transforms.PadTrim(max_len=length_new)
repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
repr_test.__repr__()
def test_downmix_mono(self):
......@@ -67,11 +63,11 @@ class Tester(unittest.TestCase):
self.assertTrue(audio_Stereo.size(1) == 2)
result = transforms.DownmixMono()(audio_Stereo)
result = transforms.DownmixMono(channels_first=False)(audio_Stereo)
self.assertTrue(result.size(1) == 1)
repr_test = transforms.DownmixMono()
repr_test = transforms.DownmixMono(channels_first=False)
repr_test.__repr__()
def test_lc2cl(self):
......@@ -107,7 +103,7 @@ class Tester(unittest.TestCase):
[audio_orig.min(), audio_orig.max()]).max().astype(np.float)
tset = (transforms.Scale(factor=maxminmax),
transforms.PadTrim(max_len=length_new))
transforms.PadTrim(max_len=length_new, channels_first=False))
result = transforms.Compose(tset)(audio_orig)
self.assertTrue(np.abs([result.min(), result.max()]).max() == 1.)
......
......@@ -34,17 +34,18 @@ def load(filepath,
If `callable`, then the output is passed as a parameter
to the given function, then the output is divided by
the result.
channels_first (bool): Set channels first or length first in result. Default: ``True``
num_frames (int, optional): number of frames to load. 0 to load everything after the offset.
offset (int, optional): number of frames from the start of the file to begin data loading.
signalinfo (sox_signalinfo_t, optional): a sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determine
audio type cannot be automatically determined
encodinginfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames, C is the number of channels
- int: the sample-rate of the audio (as listed in the metadata of the file)
- int: the sample rate of the audio (as listed in the metadata of the file)
Example::
......@@ -113,8 +114,9 @@ def save_encinfo(filepath,
filepath (string): path to audio file
src (Tensor): an input 2D Tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels
channels_first (bool): Set channels first or length first in result. Default: ``True``
signalinfo (sox_signalinfo_t): a sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determine
audio type cannot be automatically determined
encodinginfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
......
......@@ -28,6 +28,48 @@ def SoxEffect():
class SoxEffectsChain(object):
"""SoX effects chain class.
Args:
normalization (bool, number, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
(assumes signed 32-bit audio), and normalizes to `[0, 1]`.
If `number`, then output is divided by that number
If `callable`, then the output is passed as a parameter
to the given function, then the output is divided by
the result.
channels_first (bool, optional): Set channels first or length first in result. Default: ``True``
out_siginfo (sox_signalinfo_t, optional): a sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined
out_encinfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames, C is the number of channels
- int: the sample rate of the audio (as listed in the metadata of the file)
Example::
class MyDataset(Dataset):
def __init__(self, audiodir_path):
self.data = [fn for fn in os.listdir(audiodir_path)]
self.E = torchaudio.sox_effects.SoxEffectsChain()
self.E.append_effect_to_chain("rate", [16000]) # resample to 16000hz
self.E.append_effect_to_chain("channels", ["1"]) # mono signal
def __getitem__(self, index):
fn = self.data[index]
self.E.set_input_file(fn)
x, sr = self.E.sox_build_flow_effects()
return x, sr
def __len__(self):
return len(self.data)
>>> torchaudio.initialize_sox()
>>> ds = MyDataset(path_to_audio_files)
>>> for sig, sr in ds:
>>> [do something here]
>>> torchaudio.shutdown_sox()
"""
EFFECTS_AVAILABLE = set(effect_names())
......
......@@ -20,12 +20,6 @@ def _check_is_variable(tensor):
return tensor, is_variable
def _tlog10(x):
"""Pytorch Log10
"""
return torch.log(x) / torch.log(x.new([10]))
class Compose(object):
"""Composes several transforms together.
......@@ -92,29 +86,35 @@ class PadTrim(object):
"""Pad/Trim a 1d-Tensor (Signal or Labels)
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
max_len (int): Length to which the tensor will be padded
channels_first (bool): Pad for channels first tensors. Default: `True`
"""
def __init__(self, max_len, fill_value=0):
def __init__(self, max_len, fill_value=0, channels_first=True):
self.max_len = max_len
self.fill_value = fill_value
self.len_dim, self.ch_dim = int(channels_first), int(not channels_first)
def __call__(self, tensor):
"""
Returns:
Tensor: (max_len x Channels)
Tensor: (c x Ln or (n x c)
"""
if self.max_len > tensor.size(0):
pad = torch.ones((self.max_len - tensor.size(0),
tensor.size(1))) * self.fill_value
pad = pad.type_as(tensor)
tensor = torch.cat((tensor, pad), dim=0)
elif self.max_len < tensor.size(0):
tensor = tensor[:self.max_len, :]
assert tensor.size(self.ch_dim) < 128, \
"Too many channels ({}) detected, look at channels_first param.".format(tensor.size(self.ch_dim))
if self.max_len > tensor.size(self.len_dim):
padding_size = [self.max_len - tensor.size(self.len_dim) if i == self.len_dim
else tensor.size(self.ch_dim)
for i in range(2)]
pad = torch.empty(padding_size, dtype=tensor.dtype).fill_(self.fill_value)
tensor = torch.cat((tensor, pad), dim=self.len_dim)
elif self.max_len < tensor.size(self.len_dim):
tensor = tensor.narrow(self.len_dim, 0, self.max_len)
return tensor
def __repr__(self):
......@@ -122,25 +122,26 @@ class PadTrim(object):
class DownmixMono(object):
"""Downmix any stereo signals to mono
"""Downmix any stereo signals to mono. Consider using a `SoxEffectsChain` with
the `channels` effect instead of this transformation.
Inputs:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
channels_first (bool): Downmix across channels dimension. Default: `True`
Returns:
tensor (Tensor) (Samples x 1):
"""
def __init__(self):
pass
def __init__(self, channels_first=None):
self.ch_dim = int(not channels_first)
def __call__(self, tensor):
if isinstance(tensor, (torch.LongTensor, torch.IntTensor)):
tensor = tensor.float()
if tensor.size(1) > 1:
tensor = torch.mean(tensor.float(), 1, True)
tensor = torch.mean(tensor, self.ch_dim, True)
return tensor
def __repr__(self):
......@@ -148,8 +149,7 @@ class DownmixMono(object):
class LC2CL(object):
"""Permute a 2d tensor from samples (Length) x Channels to Channels x
samples (Length)
"""Permute a 2d tensor from samples (n x c) to (c x n)
"""
def __call__(self, tensor):
......@@ -162,7 +162,6 @@ class LC2CL(object):
tensor (Tensor): Tensor of audio signal with shape (CxL)
"""
return tensor.transpose(0, 1).contiguous()
def __repr__(self):
......@@ -292,7 +291,7 @@ class SPEC2DB(object):
def __call__(self, spec):
spec, is_variable = _check_is_variable(spec)
spec_db = self.multiplier * _tlog10(spec / spec.max()) # power -> dB
spec_db = self.multiplier * torch.log10(spec / spec.max()) # power -> dB
if self.top_db is not None:
spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
return spec_db if is_variable else spec_db.data
......@@ -320,7 +319,6 @@ class MEL2(object):
Example:
>>> sig, sr = torchaudio.load("test.wav", normalization=True)
>>> sig = transforms.LC2CL()(sig) # (n, c) -> (c, n)
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
"""
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
......@@ -406,8 +404,8 @@ class MEL(object):
class BLC2CBL(object):
"""Permute a 3d tensor from Bands x samples (Length) x Channels to Channels x
Bands x samples (Length)
"""Permute a 3d tensor from Bands x Sample length x Channels to Channels x
Bands x Samples length
"""
def __call__(self, tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment