"src/vscode:/vscode.git/clone" did not exist on "fb07c307a1389a6a92a6e2e21dc243418f92ed82"
Commit 0e0d1e59 authored by David Pollack's avatar David Pollack Committed by Soumith Chintala
Browse files

docs update and fixes from pr comments

parent 8647f903
sphinx
sphinxcontrib-googleanalytics
-e git://github.com/snide/sphinx_rtd_theme.git#egg=sphinx_rtd_theme
-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinxcontrib.katex
matplotlib
<?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 21.0.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<!-- Generator: Adobe Illustrator 22.1.0, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
viewBox="0 0 199.7 40.2" style="enable-background:new 0 0 199.7 40.2;" xml:space="preserve">
<style type="text/css">
.st0{fill:#F05732;}
.st1{fill:#9E529F;}
.st2{fill:#333333;}
.st0{fill:#EE4C2C;}
.st1{fill:#252525;}
</style>
<path class="st0" d="M102.7,12.2c-1.3-1-1.8,3.9-4.4,3.9c-3,0-4-13-6.3-13c-0.7,0-0.8-0.4-7.9,21.3c-2.9,9,4.4,15.8,11.8,15.8
c4.6,0,12.3-3,12.3-12.6C108.2,20.5,104.7,13.7,102.7,12.2z M95.8,35.3c-3.7,0-6.7-3.1-6.7-7c0-3.9,3-7,6.7-7s6.7,3.1,6.7,7
C102.5,32.1,99.5,35.3,95.8,35.3z"/>
<path class="st1" d="M99.8,0c-0.5,0-1.8,2.5-1.8,3.6c0,1.5,1,2,1.8,2c0.8,0,1.8-0.5,1.8-2C101.5,2.5,100.2,0,99.8,0z"/>
<path class="st2" d="M0,39.5V14.9h11.5c5.3,0,8.3,3.6,8.3,7.9c0,4.3-3,7.9-8.3,7.9H5.2v8.8H0z M14.4,22.8c0-2.1-1.6-3.3-3.7-3.3H5.2
v6.6h5.5C12.8,26.1,14.4,24.8,14.4,22.8z"/>
<path class="st2" d="M35.2,39.5V29.4l-9.4-14.5h6l6.1,9.8l6.1-9.8h5.9l-9.4,14.5v10.1H35.2z"/>
<path class="st2" d="M63.3,39.5v-20h-7.2v-4.6h19.6v4.6h-7.2v20H63.3z"/>
<path class="st2" d="M131.4,39.5l-4.8-8.7h-3.8v8.7h-5.2V14.9H129c5.1,0,8.3,3.4,8.3,7.9c0,4.3-2.8,6.7-5.4,7.3l5.6,9.4H131.4z
M131.9,22.8c0-2-1.6-3.3-3.7-3.3h-5.5v6.6h5.5C130.3,26.1,131.9,24.9,131.9,22.8z"/>
<path class="st2" d="M145.6,27.2c0-7.6,5.7-12.7,13.1-12.7c5.4,0,8.5,2.9,10.3,6l-4.5,2.2c-1-2-3.2-3.6-5.8-3.6
c-4.5,0-7.7,3.4-7.7,8.1c0,4.6,3.2,8.1,7.7,8.1c2.5,0,4.7-1.6,5.8-3.6l4.5,2.2c-1.7,3.1-4.9,6-10.3,6
C151.3,39.9,145.6,34.7,145.6,27.2z"/>
<path class="st2" d="M194.5,39.5V29.1h-11.6v10.4h-5.2V14.9h5.2v9.7h11.6v-9.7h5.3v24.6H194.5z"/>
<g>
<path class="st0" d="M40.8,9.3l-2.1,2.1c3.5,3.5,3.5,9.2,0,12.7c-3.5,3.5-9.2,3.5-12.7,0c-3.5-3.5-3.5-9.2,0-12.7l0,0l5.6-5.6
L32.3,5l0,0V0.8l-8.5,8.5c-4.7,4.7-4.7,12.2,0,16.9s12.2,4.7,16.9,0C45.5,21.5,45.5,13.9,40.8,9.3z"/>
<circle class="st0" cx="36.6" cy="7.1" r="1.6"/>
</g>
<g>
<g>
<path class="st1" d="M62.6,20l-3.6,0v9.3h-2.7V2.9c0,0,6.3,0,6.6,0c7,0,10.3,3.4,10.3,8.3C73.2,17,69.1,19.9,62.6,20z M62.8,5.4
c-0.3,0-3.9,0-3.9,0v12.1l3.8-0.1c5-0.1,7.7-2.1,7.7-6.2C70.4,7.5,67.8,5.4,62.8,5.4z"/>
<path class="st1" d="M85.4,29.2l-1.6,4.2c-1.8,4.7-3.6,6.1-6.3,6.1c-1.5,0-2.6-0.4-3.8-0.9l0.8-2.4c0.9,0.5,1.9,0.8,3,0.8
c1.5,0,2.6-0.8,4-4.5l1.3-3.4L75.3,10h2.8l6.1,16l6-16h2.7L85.4,29.2z"/>
<path class="st1" d="M101.9,5.5v23.9h-2.7V5.5h-9.3V2.9h21.3v2.5H101.9z"/>
<path class="st1" d="M118.8,29.9c-5.4,0-9.4-4-9.4-10.2c0-6.2,4.1-10.3,9.6-10.3c5.4,0,9.3,4,9.3,10.2
C128.3,25.8,124.2,29.9,118.8,29.9z M118.9,11.8c-4.1,0-6.8,3.3-6.8,7.8c0,4.7,2.8,7.9,6.9,7.9s6.8-3.3,6.8-7.8
C125.8,15,123,11.8,118.9,11.8z"/>
<path class="st1" d="M135,29.4h-2.6V10l2.6-0.5v4.1c1.3-2.5,3.2-4.1,5.7-4.1c1.3,0,2.5,0.4,3.4,0.9l-0.7,2.5
c-0.8-0.5-1.9-0.8-3-0.8c-2,0-3.9,1.5-5.5,5V29.4z"/>
<path class="st1" d="M154.4,29.9c-5.8,0-9.5-4.2-9.5-10.2c0-6.1,4-10.3,9.5-10.3c2.4,0,4.4,0.6,6.1,1.7l-0.7,2.4
c-1.5-1-3.3-1.6-5.4-1.6c-4.2,0-6.8,3.1-6.8,7.7c0,4.7,2.8,7.8,6.9,7.8c1.9,0,3.9-0.6,5.4-1.6l0.5,2.4
C158.7,29.3,156.6,29.9,154.4,29.9z"/>
<path class="st1" d="M176.7,29.4V16.9c0-3.4-1.4-4.9-4.1-4.9c-2.2,0-4.4,1.1-6,2.8v14.7h-2.6V0.9l2.6-0.5c0,0,0,12.1,0,12.2
c2-2,4.6-3.1,6.7-3.1c3.8,0,6.1,2.4,6.1,6.6v13.3H176.7z"/>
</g>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:cc="http://creativecommons.org/ns#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg"
height="40.200001"
width="40.200001"
xml:space="preserve"
viewBox="0 0 40.200002 40.2"
y="0px"
x="0px"
id="Layer_1"
version="1.1"><metadata
id="metadata4717"><rdf:RDF><cc:Work
rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
id="defs4715" /><style
id="style4694"
type="text/css">
.st0{fill:#F05732;}
.st1{fill:#9E529F;}
.st2{fill:#333333;}
</style><path
style="fill:#f05732"
id="path4696"
d="m 26.975479,12.199999 c -1.3,-1 -1.8,3.9 -4.4,3.9 -3,0 -4,-12.9999998 -6.3,-12.9999998 -0.7,0 -0.8,-0.4 -7.9000003,21.2999998 -2.9000001,9 4.4000003,15.8 11.8000003,15.8 4.6,0 12.3,-3 12.3,-12.6 0,-7.1 -3.5,-13.9 -5.5,-15.4 z m -6.9,23.1 c -3.7,0 -6.7,-3.1 -6.7,-7 0,-3.9 3,-7 6.7,-7 3.7,0 6.7,3.1 6.7,7 0,3.8 -3,7 -6.7,7 z"
class="st0" /><path
style="fill:#9e529f"
id="path4698"
d="m 24.075479,-7.6293945e-7 c -0.5,0 -1.8,2.49999996293945 -1.8,3.59999996293945 0,1.5 1,2 1.8,2 0.8,0 1.8,-0.5 1.8,-2 -0.1,-1.1 -1.4,-3.59999996293945 -1.8,-3.59999996293945 z"
class="st1" /></svg>
\ No newline at end of file
......@@ -22,14 +22,13 @@
# sys.path.insert(0, os.path.abspath('.'))
import torch
import torchaudio
import sphinx_rtd_theme
import pytorch_sphinx_theme
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
needs_sphinx = '1.6'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
......@@ -41,16 +40,24 @@ extensions = [
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinxcontrib.googleanalytics',
'sphinxcontrib.katex',
]
napoleon_use_ivar = True
# katex options
#
#
katex_options = r'''
delimiters : [
{left: "$$", right: "$$", display: true},
{left: "\\(", right: "\\)", display: false},
{left: "\\[", right: "\\]", display: true}
]
'''
googleanalytics_id = 'UA-90545585-1'
googleanalytics_enabled = True
napoleon_use_ivar = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
......@@ -66,8 +73,8 @@ master_doc = 'index'
# General information about the project.
project = 'Torchaudio'
copyright = '2017, Torch Contributors'
author = 'Torch Contributors'
copyright = '2018, Torchaudio Contributors'
author = 'Torchaudio Contributors'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
......@@ -104,14 +111,15 @@ todo_include_todos = True
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
html_theme = 'pytorch_sphinx_theme'
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {
'pytorch_project': 'audio',
'collapse_navigation': False,
'display_version': True,
'logo_only': True,
......@@ -124,19 +132,25 @@ html_logo = '_static/img/pytorch-logo-dark.svg'
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# html_style_path = 'css/pytorch_theme.css'
html_context = {
'css_files': [
'https://fonts.googleapis.com/css?family=Lato',
'_static/css/pytorch_theme.css'
],
}
def setup(app):
# NOTE: in Sphinx 1.8+ `html_css_files` is an official configuration value
# and can be moved outside of this function (and the setup(app) function
# can be deleted).
html_css_files = [
'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css'
]
# In Sphinx 1.8 it was renamed to `add_css_file`, 1.7 and prior it is
# `add_stylesheet` (deprecated in 1.8).
add_css = getattr(app, 'add_css_file', getattr(app, 'add_stylesheet'))
for css_file in html_css_files:
add_css(css_file)
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'PyTorchdoc'
htmlhelp_basename = 'TorchAudiodoc'
# -- Options for LaTeX output ---------------------------------------------
......@@ -163,7 +177,7 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'pytorch.tex', 'torchaudio Documentation',
(master_doc, 'pytorch.tex', 'Torchaudio Documentation',
'Torch Contributors', 'manual'),
]
......@@ -173,7 +187,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'torchaudio', 'torchaudio Documentation',
(master_doc, 'Torchaudio', 'Torchaudio Documentation',
[author], 1)
]
......@@ -184,8 +198,8 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'torchaudio', 'torchaudio Documentation',
author, 'torchaudio', 'One line description of project.',
(master_doc, 'Torchaudio', 'Torchaudio Documentation',
author, 'Torchaudio', 'Load audio files into pytorch tensors.',
'Miscellaneous'),
]
......@@ -193,7 +207,7 @@ texinfo_documents = [
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'numpy': ('https://docs.scipy.org/doc/numpy/', None),
}
# -- A patch that prevents Sphinx from cross-referencing ivar tags -------
......@@ -246,5 +260,4 @@ def patched_make_field(self, types, domain, items, **kw):
fieldbody = nodes.field_body('', bodynode)
return nodes.field('', fieldname, fieldbody)
TypedField.make_field = patched_make_field
......@@ -18,7 +18,7 @@ class TORCHAUDIODS(Dataset):
self.si.precision = 16
self.E = torchaudio.sox_effects.SoxEffectsChain()
self.E.append_effect_to_chain("rate", [self.si.rate]) # resample to 16000hz
self.E.append_effect_to_chain("channels", [self.si.channels]) # mono singal
self.E.append_effect_to_chain("channels", [self.si.channels]) # mono signal
self.E.append_effect_to_chain("trim", [0, "16000s"]) # first 16000 samples of audio
def __getitem__(self, index):
......@@ -30,7 +30,7 @@ class TORCHAUDIODS(Dataset):
def __len__(self):
return len(self.data)
class Test_LoadSave(unittest.TestCase):
class Test_DataLoader(unittest.TestCase):
def test_1(self):
expected_size = (2, 1, 16000)
ds = TORCHAUDIODS()
......
......@@ -20,15 +20,13 @@ class Tester(unittest.TestCase):
audio_orig = self.sig.clone()
result = transforms.Scale()(audio_orig)
self.assertTrue(result.min() >= -1. and result.max() <= 1.,
print("min: {}, max: {}".format(result.min(), result.max())))
self.assertTrue(result.min() >= -1. and result.max() <= 1.)
maxminmax = np.abs(
[audio_orig.min(), audio_orig.max()]).max().astype(np.float)
result = transforms.Scale(factor=maxminmax)(audio_orig)
self.assertTrue((result.min() == -1. or result.max() == 1.) and
result.min() >= -1. and result.max() <= 1.,
print("min: {}, max: {}".format(result.min(), result.max())))
result.min() >= -1. and result.max() <= 1.)
repr_test = transforms.Scale()
repr_test.__repr__()
......@@ -39,21 +37,19 @@ class Tester(unittest.TestCase):
length_orig = audio_orig.size(0)
length_new = int(length_orig * 1.2)
result = transforms.PadTrim(max_len=length_new)(audio_orig)
result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
self.assertTrue(result.size(0) == length_new,
print("old size: {}, new size: {}".format(audio_orig.size(0), result.size(0))))
self.assertEqual(result.size(0), length_new)
audio_orig = self.sig.clone()
length_orig = audio_orig.size(0)
length_new = int(length_orig * 0.8)
result = transforms.PadTrim(max_len=length_new)(audio_orig)
result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
self.assertTrue(result.size(0) == length_new,
print("old size: {}, new size: {}".format(audio_orig.size(0), result.size(0))))
self.assertEqual(result.size(0), length_new)
repr_test = transforms.PadTrim(max_len=length_new)
repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
repr_test.__repr__()
def test_downmix_mono(self):
......@@ -67,11 +63,11 @@ class Tester(unittest.TestCase):
self.assertTrue(audio_Stereo.size(1) == 2)
result = transforms.DownmixMono()(audio_Stereo)
result = transforms.DownmixMono(channels_first=False)(audio_Stereo)
self.assertTrue(result.size(1) == 1)
repr_test = transforms.DownmixMono()
repr_test = transforms.DownmixMono(channels_first=False)
repr_test.__repr__()
def test_lc2cl(self):
......@@ -107,7 +103,7 @@ class Tester(unittest.TestCase):
[audio_orig.min(), audio_orig.max()]).max().astype(np.float)
tset = (transforms.Scale(factor=maxminmax),
transforms.PadTrim(max_len=length_new))
transforms.PadTrim(max_len=length_new, channels_first=False))
result = transforms.Compose(tset)(audio_orig)
self.assertTrue(np.abs([result.min(), result.max()]).max() == 1.)
......
......@@ -34,17 +34,18 @@ def load(filepath,
If `callable`, then the output is passed as a parameter
to the given function, then the output is divided by
the result.
channels_first (bool): Set channels first or length first in result. Default: ``True``
num_frames (int, optional): number of frames to load. 0 to load everything after the offset.
offset (int, optional): number of frames from the start of the file to begin data loading.
signalinfo (sox_signalinfo_t, optional): a sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determine
audio type cannot be automatically determined
encodinginfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames, C is the number of channels
- int: the sample-rate of the audio (as listed in the metadata of the file)
- int: the sample rate of the audio (as listed in the metadata of the file)
Example::
......@@ -113,8 +114,9 @@ def save_encinfo(filepath,
filepath (string): path to audio file
src (Tensor): an input 2D Tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels
channels_first (bool): Set channels first or length first in result. Default: ``True``
signalinfo (sox_signalinfo_t): a sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determine
audio type cannot be automatically determined
encodinginfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
......
......@@ -28,6 +28,48 @@ def SoxEffect():
class SoxEffectsChain(object):
"""SoX effects chain class.
Args:
normalization (bool, number, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
(assumes signed 32-bit audio), and normalizes to `[-1, 1]`.
If `number`, then output is divided by that number
If `callable`, then the output is passed as a parameter
to the given function, then the output is divided by
the result.
channels_first (bool, optional): Set channels first or length first in result. Default: ``True``
out_siginfo (sox_signalinfo_t, optional): a sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined
out_encinfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames, C is the number of channels
- int: the sample rate of the audio (as listed in the metadata of the file)
Example::
class MyDataset(Dataset):
def __init__(self, audiodir_path):
self.data = [fn for fn in os.listdir(audiodir_path)]
self.E = torchaudio.sox_effects.SoxEffectsChain()
self.E.append_effect_to_chain("rate", [16000]) # resample to 16000hz
self.E.append_effect_to_chain("channels", ["1"]) # mono signal
def __getitem__(self, index):
fn = self.data[index]
self.E.set_input_file(fn)
x, sr = self.E.sox_build_flow_effects()
return x, sr
def __len__(self):
return len(self.data)
>>> torchaudio.initialize_sox()
>>> ds = MyDataset(path_to_audio_files)
>>> for sig, sr in ds:
>>> [do something here]
>>> torchaudio.shutdown_sox()
"""
EFFECTS_AVAILABLE = set(effect_names())
......
......@@ -20,12 +20,6 @@ def _check_is_variable(tensor):
return tensor, is_variable
def _tlog10(x):
"""Pytorch Log10
"""
return torch.log(x) / torch.log(x.new([10]))
class Compose(object):
"""Composes several transforms together.
......@@ -92,29 +86,35 @@ class PadTrim(object):
"""Pad/Trim a 2d-Tensor (Signal or Labels)
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
max_len (int): Length to which the tensor will be padded
channels_first (bool): Pad for channels first tensors. Default: `True`
"""
def __init__(self, max_len, fill_value=0):
def __init__(self, max_len, fill_value=0, channels_first=True):
self.max_len = max_len
self.fill_value = fill_value
self.len_dim, self.ch_dim = int(channels_first), int(not channels_first)
def __call__(self, tensor):
"""
Returns:
Tensor: (max_len x Channels)
Tensor: (c x n) or (n x c)
"""
if self.max_len > tensor.size(0):
pad = torch.ones((self.max_len - tensor.size(0),
tensor.size(1))) * self.fill_value
pad = pad.type_as(tensor)
tensor = torch.cat((tensor, pad), dim=0)
elif self.max_len < tensor.size(0):
tensor = tensor[:self.max_len, :]
assert tensor.size(self.ch_dim) < 128, \
"Too many channels ({}) detected, look at channels_first param.".format(tensor.size(self.ch_dim))
if self.max_len > tensor.size(self.len_dim):
padding_size = [self.max_len - tensor.size(self.len_dim) if i == self.len_dim
else tensor.size(self.ch_dim)
for i in range(2)]
pad = torch.empty(padding_size, dtype=tensor.dtype).fill_(self.fill_value)
tensor = torch.cat((tensor, pad), dim=self.len_dim)
elif self.max_len < tensor.size(self.len_dim):
tensor = tensor.narrow(self.len_dim, 0, self.max_len)
return tensor
def __repr__(self):
......@@ -122,25 +122,26 @@ class PadTrim(object):
class DownmixMono(object):
"""Downmix any stereo signals to mono
"""Downmix any stereo signals to mono. Consider using a `SoxEffectsChain` with
the `channels` effect instead of this transformation.
Inputs:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
channels_first (bool): Downmix across channels dimension. Default: `True`
Returns:
tensor (Tensor) (Samples x 1):
"""
def __init__(self):
pass
def __init__(self, channels_first=None):
self.ch_dim = int(not channels_first)
def __call__(self, tensor):
if isinstance(tensor, (torch.LongTensor, torch.IntTensor)):
tensor = tensor.float()
if tensor.size(1) > 1:
tensor = torch.mean(tensor.float(), 1, True)
tensor = torch.mean(tensor, self.ch_dim, True)
return tensor
def __repr__(self):
......@@ -148,8 +149,7 @@ class DownmixMono(object):
class LC2CL(object):
"""Permute a 2d tensor from samples (Length) x Channels to Channels x
samples (Length)
"""Permute a 2d tensor from samples x channels (n x c) to channels x samples (c x n)
"""
def __call__(self, tensor):
......@@ -162,7 +162,6 @@ class LC2CL(object):
tensor (Tensor): Tensor of audio signal with shape (CxL)
"""
return tensor.transpose(0, 1).contiguous()
def __repr__(self):
......@@ -292,7 +291,7 @@ class SPEC2DB(object):
def __call__(self, spec):
spec, is_variable = _check_is_variable(spec)
spec_db = self.multiplier * _tlog10(spec / spec.max()) # power -> dB
spec_db = self.multiplier * torch.log10(spec / spec.max()) # power -> dB
if self.top_db is not None:
spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
return spec_db if is_variable else spec_db.data
......@@ -320,7 +319,6 @@ class MEL2(object):
Example:
>>> sig, sr = torchaudio.load("test.wav", normalization=True)
>>> sig = transforms.LC2CL()(sig) # (n, c) -> (c, n)
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
"""
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
......@@ -406,8 +404,8 @@ class MEL(object):
class BLC2CBL(object):
"""Permute a 3d tensor from Bands x samples (Length) x Channels to Channels x
Bands x samples (Length)
"""Permute a 3d tensor from Bands x Sample length x Channels to Channels x
Bands x Sample length
"""
def __call__(self, tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment