Commit 5859923a authored by Joao Gomes's avatar Joao Gomes Committed by Facebook GitHub Bot
Browse files

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
"""Test suites for checking numerical compatibility against Kaldi""" """Test suites for checking numerical compatibility against Kaldi"""
import torchaudio.compliance.kaldi import torchaudio.compliance.kaldi
from parameterized import parameterized from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin, TestBaseMixin,
TempDirMixin, TempDirMixin,
...@@ -21,35 +20,35 @@ class Kaldi(TempDirMixin, TestBaseMixin): ...@@ -21,35 +20,35 @@ class Kaldi(TempDirMixin, TestBaseMixin):
expected = expected.to(dtype=self.dtype, device=self.device) expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol) self.assertEqual(output, expected, rtol=rtol, atol=atol)
@parameterized.expand(load_params('kaldi_test_fbank_args.jsonl')) @parameterized.expand(load_params("kaldi_test_fbank_args.jsonl"))
@skipIfNoExec('compute-fbank-feats') @skipIfNoExec("compute-fbank-feats")
def test_fbank(self, kwargs): def test_fbank(self, kwargs):
"""fbank should be numerically compatible with compute-fbank-feats""" """fbank should be numerically compatible with compute-fbank-feats"""
wave_file = get_asset_path('kaldi_file.wav') wave_file = get_asset_path("kaldi_file.wav")
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs) result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
command = ['compute-fbank-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] command = ["compute-fbank-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
kaldi_result = run_kaldi(command, 'scp', wave_file) kaldi_result = run_kaldi(command, "scp", wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(load_params('kaldi_test_spectrogram_args.jsonl')) @parameterized.expand(load_params("kaldi_test_spectrogram_args.jsonl"))
@skipIfNoExec('compute-spectrogram-feats') @skipIfNoExec("compute-spectrogram-feats")
def test_spectrogram(self, kwargs): def test_spectrogram(self, kwargs):
"""spectrogram should be numerically compatible with compute-spectrogram-feats""" """spectrogram should be numerically compatible with compute-spectrogram-feats"""
wave_file = get_asset_path('kaldi_file.wav') wave_file = get_asset_path("kaldi_file.wav")
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs) result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs)
command = ['compute-spectrogram-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] command = ["compute-spectrogram-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
kaldi_result = run_kaldi(command, 'scp', wave_file) kaldi_result = run_kaldi(command, "scp", wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(load_params('kaldi_test_mfcc_args.jsonl')) @parameterized.expand(load_params("kaldi_test_mfcc_args.jsonl"))
@skipIfNoExec('compute-mfcc-feats') @skipIfNoExec("compute-mfcc-feats")
def test_mfcc(self, kwargs): def test_mfcc(self, kwargs):
"""mfcc should be numerically compatible with compute-mfcc-feats""" """mfcc should be numerically compatible with compute-mfcc-feats"""
wave_file = get_asset_path('kaldi_file.wav') wave_file = get_asset_path("kaldi_file.wav")
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs) result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs)
command = ['compute-mfcc-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] command = ["compute-mfcc-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
kaldi_result = run_kaldi(command, 'scp', wave_file) kaldi_result = run_kaldi(command, "scp", wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .librosa_compatibility_test_impl import TransformsTestBase from .librosa_compatibility_test_impl import TransformsTestBase
class TestTransforms(TransformsTestBase, PytorchTestCase): class TestTransforms(TransformsTestBase, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cpu') device = torch.device("cpu")
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .librosa_compatibility_test_impl import TransformsTestBase from .librosa_compatibility_test_impl import TransformsTestBase
@skipIfNoCuda @skipIfNoCuda
class TestTransforms(TransformsTestBase, PytorchTestCase): class TestTransforms(TransformsTestBase, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cuda') device = torch.device("cuda")
...@@ -2,9 +2,8 @@ import unittest ...@@ -2,9 +2,8 @@ import unittest
import torch import torch
import torchaudio.transforms as T import torchaudio.transforms as T
from torchaudio._internal.module_utils import is_module_available
from parameterized import param, parameterized from parameterized import param, parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin, TestBaseMixin,
get_whitenoise, get_whitenoise,
...@@ -13,7 +12,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -13,7 +12,7 @@ from torchaudio_unittest.common_utils import (
nested_params, nested_params,
) )
LIBROSA_AVAILABLE = is_module_available('librosa') LIBROSA_AVAILABLE = is_module_available("librosa")
if LIBROSA_AVAILABLE: if LIBROSA_AVAILABLE:
import librosa import librosa
...@@ -21,25 +20,28 @@ if LIBROSA_AVAILABLE: ...@@ -21,25 +20,28 @@ if LIBROSA_AVAILABLE:
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TransformsTestBase(TestBaseMixin): class TransformsTestBase(TestBaseMixin):
@parameterized.expand([ @parameterized.expand(
param(n_fft=400, hop_length=200, power=2.0), [
param(n_fft=600, hop_length=100, power=2.0), param(n_fft=400, hop_length=200, power=2.0),
param(n_fft=400, hop_length=200, power=3.0), param(n_fft=600, hop_length=100, power=2.0),
param(n_fft=200, hop_length=50, power=2.0), param(n_fft=400, hop_length=200, power=3.0),
]) param(n_fft=200, hop_length=50, power=2.0),
]
)
def test_Spectrogram(self, n_fft, hop_length, power): def test_Spectrogram(self, n_fft, hop_length, power):
sample_rate = 16000 sample_rate = 16000
waveform = get_whitenoise( waveform = get_whitenoise(
sample_rate=sample_rate, n_channels=1, sample_rate=sample_rate,
n_channels=1,
).to(self.device, self.dtype) ).to(self.device, self.dtype)
expected = librosa.core.spectrum._spectrogram( expected = librosa.core.spectrum._spectrogram(
y=waveform[0].cpu().numpy(), y=waveform[0].cpu().numpy(), n_fft=n_fft, hop_length=hop_length, power=power
n_fft=n_fft, hop_length=hop_length, power=power)[0] )[0]
result = T.Spectrogram( result = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power,).to(self.device, self.dtype)(
n_fft=n_fft, hop_length=hop_length, power=power, waveform
).to(self.device, self.dtype)(waveform)[0] )[0]
self.assertEqual(result, torch.from_numpy(expected), atol=1e-5, rtol=1e-5) self.assertEqual(result, torch.from_numpy(expected), atol=1e-5, rtol=1e-5)
def test_Spectrogram_complex(self): def test_Spectrogram_complex(self):
...@@ -47,16 +49,17 @@ class TransformsTestBase(TestBaseMixin): ...@@ -47,16 +49,17 @@ class TransformsTestBase(TestBaseMixin):
hop_length = 200 hop_length = 200
sample_rate = 16000 sample_rate = 16000
waveform = get_whitenoise( waveform = get_whitenoise(
sample_rate=sample_rate, n_channels=1, sample_rate=sample_rate,
n_channels=1,
).to(self.device, self.dtype) ).to(self.device, self.dtype)
expected = librosa.core.spectrum._spectrogram( expected = librosa.core.spectrum._spectrogram(
y=waveform[0].cpu().numpy(), y=waveform[0].cpu().numpy(), n_fft=n_fft, hop_length=hop_length, power=1
n_fft=n_fft, hop_length=hop_length, power=1)[0] )[0]
result = T.Spectrogram( result = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None, return_complex=True,).to(
n_fft=n_fft, hop_length=hop_length, power=None, return_complex=True, self.device, self.dtype
).to(self.device, self.dtype)(waveform)[0] )(waveform)[0]
self.assertEqual(result.abs(), torch.from_numpy(expected), atol=1e-5, rtol=1e-5) self.assertEqual(result.abs(), torch.from_numpy(expected), atol=1e-5, rtol=1e-5)
@nested_params( @nested_params(
...@@ -65,77 +68,95 @@ class TransformsTestBase(TestBaseMixin): ...@@ -65,77 +68,95 @@ class TransformsTestBase(TestBaseMixin):
param(n_fft=600, hop_length=100, n_mels=128), param(n_fft=600, hop_length=100, n_mels=128),
param(n_fft=200, hop_length=50, n_mels=32), param(n_fft=200, hop_length=50, n_mels=32),
], ],
[param(norm=norm) for norm in [None, 'slaney']], [param(norm=norm) for norm in [None, "slaney"]],
[param(mel_scale=mel_scale) for mel_scale in ['htk', 'slaney']], [param(mel_scale=mel_scale) for mel_scale in ["htk", "slaney"]],
) )
def test_MelSpectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale): def test_MelSpectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale):
sample_rate = 16000 sample_rate = 16000
waveform = get_sinusoid( waveform = get_sinusoid(
sample_rate=sample_rate, n_channels=1, sample_rate=sample_rate,
n_channels=1,
).to(self.device, self.dtype) ).to(self.device, self.dtype)
expected = librosa.feature.melspectrogram( expected = librosa.feature.melspectrogram(
y=waveform[0].cpu().numpy(), y=waveform[0].cpu().numpy(),
sr=sample_rate, n_fft=n_fft, sr=sample_rate,
hop_length=hop_length, n_mels=n_mels, norm=norm, n_fft=n_fft,
htk=mel_scale == "htk") hop_length=hop_length,
n_mels=n_mels,
norm=norm,
htk=mel_scale == "htk",
)
result = T.MelSpectrogram( result = T.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window, sample_rate=sample_rate,
hop_length=hop_length, n_mels=n_mels, window_fn=torch.hann_window,
n_fft=n_fft, norm=norm, mel_scale=mel_scale, hop_length=hop_length,
n_mels=n_mels,
n_fft=n_fft,
norm=norm,
mel_scale=mel_scale,
).to(self.device, self.dtype)(waveform)[0] ).to(self.device, self.dtype)(waveform)[0]
self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5) self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5)
def test_magnitude_to_db(self): def test_magnitude_to_db(self):
spectrogram = get_spectrogram( spectrogram = get_spectrogram(get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype)
get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype) result = T.AmplitudeToDB("magnitude", 80.0).to(self.device, self.dtype)(spectrogram)[0]
result = T.AmplitudeToDB('magnitude', 80.).to(self.device, self.dtype)(spectrogram)[0]
expected = librosa.core.spectrum.amplitude_to_db(spectrogram[0].cpu().numpy()) expected = librosa.core.spectrum.amplitude_to_db(spectrogram[0].cpu().numpy())
self.assertEqual(result, torch.from_numpy(expected)) self.assertEqual(result, torch.from_numpy(expected))
def test_power_to_db(self): def test_power_to_db(self):
spectrogram = get_spectrogram( spectrogram = get_spectrogram(get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype)
get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype) result = T.AmplitudeToDB("power", 80.0).to(self.device, self.dtype)(spectrogram)[0]
result = T.AmplitudeToDB('power', 80.).to(self.device, self.dtype)(spectrogram)[0]
expected = librosa.core.spectrum.power_to_db(spectrogram[0].cpu().numpy()) expected = librosa.core.spectrum.power_to_db(spectrogram[0].cpu().numpy())
self.assertEqual(result, torch.from_numpy(expected)) self.assertEqual(result, torch.from_numpy(expected))
@nested_params([ @nested_params(
param(n_fft=400, hop_length=200, n_mels=64, n_mfcc=40), [
param(n_fft=600, hop_length=100, n_mels=128, n_mfcc=20), param(n_fft=400, hop_length=200, n_mels=64, n_mfcc=40),
param(n_fft=200, hop_length=50, n_mels=32, n_mfcc=25), param(n_fft=600, hop_length=100, n_mels=128, n_mfcc=20),
]) param(n_fft=200, hop_length=50, n_mels=32, n_mfcc=25),
]
)
def test_mfcc(self, n_fft, hop_length, n_mels, n_mfcc): def test_mfcc(self, n_fft, hop_length, n_mels, n_mfcc):
sample_rate = 16000 sample_rate = 16000
waveform = get_whitenoise( waveform = get_whitenoise(sample_rate=sample_rate, n_channels=1).to(self.device, self.dtype)
sample_rate=sample_rate, n_channels=1).to(self.device, self.dtype)
result = T.MFCC( result = T.MFCC(
sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', sample_rate=sample_rate,
melkwargs={'hop_length': hop_length, 'n_fft': n_fft, 'n_mels': n_mels}, n_mfcc=n_mfcc,
norm="ortho",
melkwargs={"hop_length": hop_length, "n_fft": n_fft, "n_mels": n_mels},
).to(self.device, self.dtype)(waveform)[0] ).to(self.device, self.dtype)(waveform)[0]
melspec = librosa.feature.melspectrogram( melspec = librosa.feature.melspectrogram(
y=waveform[0].cpu().numpy(), sr=sample_rate, n_fft=n_fft, y=waveform[0].cpu().numpy(),
win_length=n_fft, hop_length=hop_length, sr=sample_rate,
n_mels=n_mels, htk=True, norm=None) n_fft=n_fft,
win_length=n_fft,
hop_length=hop_length,
n_mels=n_mels,
htk=True,
norm=None,
)
expected = librosa.feature.mfcc( expected = librosa.feature.mfcc(
S=librosa.core.spectrum.power_to_db(melspec), S=librosa.core.spectrum.power_to_db(melspec), n_mfcc=n_mfcc, dct_type=2, norm="ortho"
n_mfcc=n_mfcc, dct_type=2, norm='ortho') )
self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5) self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5)
@parameterized.expand([ @parameterized.expand(
param(n_fft=400, hop_length=200), [
param(n_fft=600, hop_length=100), param(n_fft=400, hop_length=200),
param(n_fft=200, hop_length=50), param(n_fft=600, hop_length=100),
]) param(n_fft=200, hop_length=50),
]
)
def test_spectral_centroid(self, n_fft, hop_length): def test_spectral_centroid(self, n_fft, hop_length):
sample_rate = 16000 sample_rate = 16000
waveform = get_whitenoise( waveform = get_whitenoise(sample_rate=sample_rate, n_channels=1).to(self.device, self.dtype)
sample_rate=sample_rate, n_channels=1).to(self.device, self.dtype)
result = T.SpectralCentroid( result = T.SpectralCentroid(sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,).to(
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, self.device, self.dtype
).to(self.device, self.dtype)(waveform) )(waveform)
expected = librosa.feature.spectral_centroid( expected = librosa.feature.spectral_centroid(
y=waveform[0].cpu().numpy(), sr=sample_rate, n_fft=n_fft, hop_length=hop_length) y=waveform[0].cpu().numpy(), sr=sample_rate, n_fft=n_fft, hop_length=hop_length
)
self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5) self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5)
...@@ -3,7 +3,6 @@ import warnings ...@@ -3,7 +3,6 @@ import warnings
import torch import torch
import torchaudio.transforms as T import torchaudio.transforms as T
from parameterized import parameterized from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
skipIfNoSox, skipIfNoSox,
skipIfNoExec, skipIfNoExec,
...@@ -18,10 +17,10 @@ from torchaudio_unittest.common_utils import ( ...@@ -18,10 +17,10 @@ from torchaudio_unittest.common_utils import (
@skipIfNoSox @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec("sox")
class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
def run_sox_effect(self, input_file, effect): def run_sox_effect(self, input_file, effect):
output_file = self.get_temp_path('expected.wav') output_file = self.get_temp_path("expected.wav")
sox_utils.run_sox_effect(input_file, output_file, [str(e) for e in effect]) sox_utils.run_sox_effect(input_file, output_file, [str(e) for e in effect])
return load_wav(output_file) return load_wav(output_file)
...@@ -31,39 +30,45 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -31,39 +30,45 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
def get_whitenoise(self, sample_rate=8000): def get_whitenoise(self, sample_rate=8000):
noise = get_whitenoise( noise = get_whitenoise(
sample_rate=sample_rate, duration=3, scale_factor=0.9, sample_rate=sample_rate,
duration=3,
scale_factor=0.9,
) )
path = self.get_temp_path("whitenoise.wav") path = self.get_temp_path("whitenoise.wav")
save_wav(path, noise, sample_rate) save_wav(path, noise, sample_rate)
return noise, path return noise, path
@parameterized.expand([ @parameterized.expand(
('q', 'quarter_sine'), [
('h', 'half_sine'), ("q", "quarter_sine"),
('t', 'linear'), ("h", "half_sine"),
]) ("t", "linear"),
]
)
def test_fade(self, fade_shape_sox, fade_shape): def test_fade(self, fade_shape_sox, fade_shape):
fade_in_len, fade_out_len = 44100, 44100 fade_in_len, fade_out_len = 44100, 44100
data, path = self.get_whitenoise(sample_rate=44100) data, path = self.get_whitenoise(sample_rate=44100)
result = T.Fade(fade_in_len, fade_out_len, fade_shape)(data) result = T.Fade(fade_in_len, fade_out_len, fade_shape)(data)
self.assert_sox_effect(result, path, ['fade', fade_shape_sox, '1', '0', '1']) self.assert_sox_effect(result, path, ["fade", fade_shape_sox, "1", "0", "1"])
@parameterized.expand([ @parameterized.expand(
('amplitude', 1.1), [
('db', 2), ("amplitude", 1.1),
('power', 2), ("db", 2),
]) ("power", 2),
]
)
def test_vol(self, gain_type, gain): def test_vol(self, gain_type, gain):
data, path = self.get_whitenoise() data, path = self.get_whitenoise()
result = T.Vol(gain, gain_type)(data) result = T.Vol(gain, gain_type)(data)
self.assert_sox_effect(result, path, ['vol', f'{gain}', gain_type]) self.assert_sox_effect(result, path, ["vol", f"{gain}", gain_type])
@parameterized.expand(['vad-go-stereo-44100.wav', 'vad-go-mono-32000.wav']) @parameterized.expand(["vad-go-stereo-44100.wav", "vad-go-mono-32000.wav"])
def test_vad(self, filename): def test_vad(self, filename):
path = get_asset_path(filename) path = get_asset_path(filename)
data, sample_rate = load_wav(path) data, sample_rate = load_wav(path)
result = T.Vad(sample_rate)(data) result = T.Vad(sample_rate)(data)
self.assert_sox_effect(result, path, ['vad']) self.assert_sox_effect(result, path, ["vad"])
def test_vad_warning(self): def test_vad_warning(self):
"""vad should throw a warning if input dimension is greater than 2""" """vad should throw a warning if input dimension is greater than 2"""
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only from .torchscript_consistency_impl import Transforms, TransformsFloat32Only
class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase): class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device("cpu")
class TestTransformsFloat64(Transforms, PytorchTestCase): class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cpu') device = torch.device("cpu")
import torch import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only from .torchscript_consistency_impl import Transforms, TransformsFloat32Only
@skipIfNoCuda @skipIfNoCuda
class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase): class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cuda') device = torch.device("cuda")
@skipIfNoCuda @skipIfNoCuda
class TestTransformsFloat64(Transforms, PytorchTestCase): class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cuda') device = torch.device("cuda")
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import torch import torch
import torchaudio.transforms as T import torchaudio.transforms as T
from parameterized import parameterized from parameterized import parameterized
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
skipIfRocm, skipIfRocm,
...@@ -14,6 +13,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -14,6 +13,7 @@ from torchaudio_unittest.common_utils import (
class Transforms(TestBaseMixin): class Transforms(TestBaseMixin):
"""Implements test for Transforms that are performed for different devices""" """Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor, *args): def _assert_consistency(self, transform, tensor, *args):
tensor = tensor.to(device=self.device, dtype=self.dtype) tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype) transform = transform.to(device=self.device, dtype=self.dtype)
...@@ -139,10 +139,7 @@ class Transforms(TestBaseMixin): ...@@ -139,10 +139,7 @@ class Transforms(TestBaseMixin):
sample_rate = 8000 sample_rate = 8000
n_steps = 4 n_steps = 4
waveform = common_utils.get_whitenoise(sample_rate=sample_rate) waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency( self._assert_consistency(T.PitchShift(sample_rate=sample_rate, n_steps=n_steps), waveform)
T.PitchShift(sample_rate=sample_rate, n_steps=n_steps),
waveform
)
def test_PSD(self): def test_PSD(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4) tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
...@@ -157,33 +154,34 @@ class Transforms(TestBaseMixin): ...@@ -157,33 +154,34 @@ class Transforms(TestBaseMixin):
mask = torch.rand(spectrogram.shape[-2:], device=self.device) mask = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.PSD(), spectrogram, mask) self._assert_consistency_complex(T.PSD(), spectrogram, mask)
@parameterized.expand([ @parameterized.expand(
["ref_channel", True], [
["stv_evd", True], ["ref_channel", True],
["stv_power", True], ["stv_evd", True],
["ref_channel", False], ["stv_power", True],
["stv_evd", False], ["ref_channel", False],
["stv_power", False], ["stv_evd", False],
]) ["stv_power", False],
]
)
def test_MVDR(self, solution, online): def test_MVDR(self, solution, online):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4) tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100) spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
mask_s = torch.rand(spectrogram.shape[-2:], device=self.device) mask_s = torch.rand(spectrogram.shape[-2:], device=self.device)
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device) mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex( self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n)
T.MVDR(solution=solution, online=online),
spectrogram, mask_s, mask_n
)
class TransformsFloat32Only(TestBaseMixin): class TransformsFloat32Only(TestBaseMixin):
def test_rnnt_loss(self): def test_rnnt_loss(self):
logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], logits = torch.tensor(
[0.1, 0.1, 0.6, 0.1, 0.1], [
[0.1, 0.1, 0.2, 0.8, 0.1]], [
[[0.1, 0.6, 0.1, 0.1, 0.1], [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1], [0.1, 0.1, 0.2, 0.8, 0.1]],
[0.1, 0.1, 0.2, 0.1, 0.1], [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]],
[0.7, 0.1, 0.2, 0.1, 0.1]]]]) ]
]
)
tensor = logits.to(device=self.device, dtype=torch.float32) tensor = logits.to(device=self.device, dtype=torch.float32)
targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32) targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32) logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from . transforms_test_impl import TransformsTestBase
from .transforms_test_impl import TransformsTestBase
class TransformsCPUFloat32Test(TransformsTestBase, PytorchTestCase): class TransformsCPUFloat32Test(TransformsTestBase, PytorchTestCase):
device = 'cpu' device = "cpu"
dtype = torch.float32 dtype = torch.float32
class TransformsCPUFloat64Test(TransformsTestBase, PytorchTestCase): class TransformsCPUFloat64Test(TransformsTestBase, PytorchTestCase):
device = 'cpu' device = "cpu"
dtype = torch.float64 dtype = torch.float64
import torch import torch
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
PytorchTestCase, PytorchTestCase,
skipIfNoCuda, skipIfNoCuda,
) )
from . transforms_test_impl import TransformsTestBase
from .transforms_test_impl import TransformsTestBase
@skipIfNoCuda @skipIfNoCuda
class TransformsCUDAFloat32Test(TransformsTestBase, PytorchTestCase): class TransformsCUDAFloat32Test(TransformsTestBase, PytorchTestCase):
device = 'cuda' device = "cuda"
dtype = torch.float32 dtype = torch.float32
@skipIfNoCuda @skipIfNoCuda
class TransformsCUDAFloat64Test(TransformsTestBase, PytorchTestCase): class TransformsCUDAFloat64Test(TransformsTestBase, PytorchTestCase):
device = 'cuda' device = "cuda"
dtype = torch.float64 dtype = torch.float64
...@@ -2,24 +2,23 @@ import math ...@@ -2,24 +2,23 @@ import math
import torch import torch
import torchaudio import torchaudio
import torchaudio.transforms as transforms
import torchaudio.functional as F import torchaudio.functional as F
import torchaudio.transforms as transforms
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
class Tester(common_utils.TorchaudioTestCase): class Tester(common_utils.TorchaudioTestCase):
backend = 'default' backend = "default"
# create a sinewave signal for testing # create a sinewave signal for testing
sample_rate = 16000 sample_rate = 16000
freq = 440 freq = 440
volume = .3 volume = 0.3
waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate)) waveform = torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate)
waveform.unsqueeze_(0) # (1, 64000) waveform.unsqueeze_(0) # (1, 64000)
waveform = (waveform * volume * 2**31).long() waveform = (waveform * volume * 2 ** 31).long()
def scale(self, waveform, factor=2.0**31): def scale(self, waveform, factor=2.0 ** 31):
# scales a waveform by a factor # scales a waveform by a factor
if not waveform.is_floating_point(): if not waveform.is_floating_point():
waveform = waveform.to(torch.get_default_dtype()) waveform = waveform.to(torch.get_default_dtype())
...@@ -34,20 +33,20 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -34,20 +33,20 @@ class Tester(common_utils.TorchaudioTestCase):
waveform = waveform.to(torch.get_default_dtype()) waveform = waveform.to(torch.get_default_dtype())
waveform /= torch.abs(waveform).max() waveform /= torch.abs(waveform).max()
self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.) self.assertTrue(waveform.min() >= -1.0 and waveform.max() <= 1.0)
waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform) waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels) self.assertTrue(waveform_mu.min() >= 0.0 and waveform_mu.max() <= quantization_channels)
waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu) waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.) self.assertTrue(waveform_exp.min() >= -1.0 and waveform_exp.max() <= 1.0)
def test_AmplitudeToDB(self): def test_AmplitudeToDB(self):
filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') filepath = common_utils.get_asset_path("steam-train-whistle-daniel_simon.wav")
waveform = common_utils.load_wav(filepath)[0] waveform = common_utils.load_wav(filepath)[0]
mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.) mag_to_db_transform = transforms.AmplitudeToDB("magnitude", 80.0)
power_to_db_transform = transforms.AmplitudeToDB('power', 80.) power_to_db_transform = transforms.AmplitudeToDB("power", 80.0)
mag_to_db_torch = mag_to_db_transform(torch.abs(waveform)) mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2)) power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))
...@@ -88,8 +87,8 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -88,8 +87,8 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertEqual(fb, fb_copy) self.assertEqual(fb, fb_copy)
def test_mel2(self): def test_mel2(self):
top_db = 80. top_db = 80.0
s2db = transforms.AmplitudeToDB('power', top_db) s2db = transforms.AmplitudeToDB("power", top_db)
waveform = self.waveform.clone() # (1, 16000) waveform = self.waveform.clone() # (1, 16000)
waveform_scaled = self.scale(waveform) # (1, 16000) waveform_scaled = self.scale(waveform) # (1, 16000)
...@@ -100,20 +99,26 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -100,20 +99,26 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels) self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
# check correctness of filterbank conversion matrix # check correctness of filterbank conversion matrix
self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all()) self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.0).all())
self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all()) self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.0).all())
# check options # check options
kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500, kwargs = {
'hop_length': 125, 'n_fft': 800, 'n_mels': 50} "window_fn": torch.hamming_window,
"pad": 10,
"win_length": 500,
"hop_length": 125,
"n_fft": 800,
"n_mels": 50,
}
mel_transform2 = transforms.MelSpectrogram(**kwargs) mel_transform2 = transforms.MelSpectrogram(**kwargs)
spectrogram2_torch = s2db(mel_transform2(waveform_scaled)) # (1, 50, 513) spectrogram2_torch = s2db(mel_transform2(waveform_scaled)) # (1, 50, 513)
self.assertTrue(spectrogram2_torch.dim() == 3) self.assertTrue(spectrogram2_torch.dim() == 3)
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels) self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all()) self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.0).all())
self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all()) self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.0).all())
# check on multi-channel audio # check on multi-channel audio
filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') filepath = common_utils.get_asset_path("steam-train-whistle-daniel_simon.wav")
x_stereo = common_utils.load_wav(filepath)[0] # (2, 278756), 44100 x_stereo = common_utils.load_wav(filepath)[0] # (2, 278756), 44100
spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394) spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394)
self.assertTrue(spectrogram_stereo.dim() == 3) self.assertTrue(spectrogram_stereo.dim() == 3)
...@@ -121,57 +126,46 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -121,57 +126,46 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels) self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
# check filterbank matrix creation # check filterbank matrix creation
fb_matrix_transform = transforms.MelScale( fb_matrix_transform = transforms.MelScale(n_mels=100, sample_rate=16000, f_min=0.0, f_max=None, n_stft=400)
n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400) self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.0).all())
self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all()) self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.0).all())
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
def test_mfcc_defaults(self): def test_mfcc_defaults(self):
"""Check the default configuration of the MFCC transform. """Check the default configuration of the MFCC transform."""
"""
sample_rate = 16000 sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate) audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_mfcc = 40 n_mfcc = 40
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, norm="ortho")
n_mfcc=n_mfcc,
norm='ortho')
torch_mfcc = mfcc_transform(audio) # (1, 40, 81) torch_mfcc = mfcc_transform(audio) # (1, 40, 81)
self.assertEqual(torch_mfcc.dim(), 3) self.assertEqual(torch_mfcc.dim(), 3)
self.assertEqual(torch_mfcc.shape[1], n_mfcc) self.assertEqual(torch_mfcc.shape[1], n_mfcc)
self.assertEqual(torch_mfcc.shape[2], 81) self.assertEqual(torch_mfcc.shape[2], 81)
def test_mfcc_kwargs_passthrough(self): def test_mfcc_kwargs_passthrough(self):
"""Check kwargs get correctly passed to the MelSpectrogram transform. """Check kwargs get correctly passed to the MelSpectrogram transform."""
"""
sample_rate = 16000 sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate) audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_mfcc = 40 n_mfcc = 40
melkwargs = {'win_length': 200} melkwargs = {"win_length": 200}
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, mfcc_transform = torchaudio.transforms.MFCC(
n_mfcc=n_mfcc, sample_rate=sample_rate, n_mfcc=n_mfcc, norm="ortho", melkwargs=melkwargs
norm='ortho', )
melkwargs=melkwargs)
torch_mfcc = mfcc_transform(audio) # (1, 40, 161) torch_mfcc = mfcc_transform(audio) # (1, 40, 161)
self.assertEqual(torch_mfcc.shape[2], 161) self.assertEqual(torch_mfcc.shape[2], 161)
def test_mfcc_norms(self): def test_mfcc_norms(self):
"""Check if MFCC-DCT norms work correctly. """Check if MFCC-DCT norms work correctly."""
"""
sample_rate = 16000 sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate) audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_mfcc = 40 n_mfcc = 40
n_mels = 128 n_mels = 128
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, norm="ortho")
n_mfcc=n_mfcc,
norm='ortho')
# check norms work correctly # check norms work correctly
mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate, mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, norm=None)
n_mfcc=n_mfcc,
norm=None)
torch_mfcc_norm_none = mfcc_transform_norm_none(audio) # (1, 40, 81) torch_mfcc_norm_none = mfcc_transform_norm_none(audio) # (1, 40, 81)
norm_check = mfcc_transform(audio) norm_check = mfcc_transform(audio)
...@@ -181,56 +175,48 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -181,56 +175,48 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertEqual(torch_mfcc_norm_none, norm_check) self.assertEqual(torch_mfcc_norm_none, norm_check)
def test_lfcc_defaults(self): def test_lfcc_defaults(self):
"""Check default settings for LFCC transform. """Check default settings for LFCC transform."""
"""
sample_rate = 16000 sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate) audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_lfcc = 40 n_lfcc = 40
n_filter = 128 n_filter = 128
lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate, lfcc_transform = torchaudio.transforms.LFCC(
n_filter=n_filter, sample_rate=sample_rate, n_filter=n_filter, n_lfcc=n_lfcc, norm="ortho"
n_lfcc=n_lfcc, )
norm='ortho')
torch_lfcc = lfcc_transform(audio) # (1, 40, 81) torch_lfcc = lfcc_transform(audio) # (1, 40, 81)
self.assertEqual(torch_lfcc.dim(), 3) self.assertEqual(torch_lfcc.dim(), 3)
self.assertEqual(torch_lfcc.shape[1], n_lfcc) self.assertEqual(torch_lfcc.shape[1], n_lfcc)
self.assertEqual(torch_lfcc.shape[2], 81) self.assertEqual(torch_lfcc.shape[2], 81)
def test_lfcc_arg_passthrough(self): def test_lfcc_arg_passthrough(self):
"""Check if kwargs get correctly passed to the underlying Spectrogram transform. """Check if kwargs get correctly passed to the underlying Spectrogram transform."""
"""
sample_rate = 16000 sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate) audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_lfcc = 40 n_lfcc = 40
n_filter = 128 n_filter = 128
speckwargs = {'win_length': 200} speckwargs = {"win_length": 200}
lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate, lfcc_transform = torchaudio.transforms.LFCC(
n_filter=n_filter, sample_rate=sample_rate, n_filter=n_filter, n_lfcc=n_lfcc, norm="ortho", speckwargs=speckwargs
n_lfcc=n_lfcc, )
norm='ortho',
speckwargs=speckwargs)
torch_lfcc = lfcc_transform(audio) # (1, 40, 161) torch_lfcc = lfcc_transform(audio) # (1, 40, 161)
self.assertEqual(torch_lfcc.shape[2], 161) self.assertEqual(torch_lfcc.shape[2], 161)
def test_lfcc_norms(self): def test_lfcc_norms(self):
"""Check if LFCC-DCT norm works correctly. """Check if LFCC-DCT norm works correctly."""
"""
sample_rate = 16000 sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate) audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_lfcc = 40 n_lfcc = 40
n_filter = 128 n_filter = 128
lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate, lfcc_transform = torchaudio.transforms.LFCC(
n_filter=n_filter, sample_rate=sample_rate, n_filter=n_filter, n_lfcc=n_lfcc, norm="ortho"
n_lfcc=n_lfcc, )
norm='ortho')
lfcc_transform_norm_none = torchaudio.transforms.LFCC(
lfcc_transform_norm_none = torchaudio.transforms.LFCC(sample_rate=sample_rate, sample_rate=sample_rate, n_filter=n_filter, n_lfcc=n_lfcc, norm=None
n_filter=n_filter, )
n_lfcc=n_lfcc,
norm=None)
torch_lfcc_norm_none = lfcc_transform_norm_none(audio) # (1, 40, 161) torch_lfcc_norm_none = lfcc_transform_norm_none(audio) # (1, 40, 161)
norm_check = lfcc_transform(audio) # (1, 40, 161) norm_check = lfcc_transform(audio) # (1, 40, 161)
...@@ -240,26 +226,27 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -240,26 +226,27 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertEqual(torch_lfcc_norm_none, norm_check) self.assertEqual(torch_lfcc_norm_none, norm_check)
def test_resample_size(self): def test_resample_size(self):
input_path = common_utils.get_asset_path('sinewave.wav') input_path = common_utils.get_asset_path("sinewave.wav")
waveform, sample_rate = common_utils.load_wav(input_path) waveform, sample_rate = common_utils.load_wav(input_path)
upsample_rate = sample_rate * 2 upsample_rate = sample_rate * 2
downsample_rate = sample_rate // 2 downsample_rate = sample_rate // 2
invalid_resampling_method = 'foo' invalid_resampling_method = "foo"
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
torchaudio.transforms.Resample(sample_rate, upsample_rate, torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method=invalid_resampling_method)
resampling_method=invalid_resampling_method)
upsample_resample = torchaudio.transforms.Resample( upsample_resample = torchaudio.transforms.Resample(
sample_rate, upsample_rate, resampling_method='sinc_interpolation') sample_rate, upsample_rate, resampling_method="sinc_interpolation"
)
up_sampled = upsample_resample(waveform) up_sampled = upsample_resample(waveform)
# we expect the upsampled signal to have twice as many samples # we expect the upsampled signal to have twice as many samples
self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2) self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)
downsample_resample = torchaudio.transforms.Resample( downsample_resample = torchaudio.transforms.Resample(
sample_rate, downsample_rate, resampling_method='sinc_interpolation') sample_rate, downsample_rate, resampling_method="sinc_interpolation"
)
down_sampled = downsample_resample(waveform) down_sampled = downsample_resample(waveform)
# we expect the downsampled signal to have half as many samples # we expect the downsampled signal to have half as many samples
...@@ -289,9 +276,8 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -289,9 +276,8 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertEqual(computed_functional, computed_transform, atol=atol, rtol=rtol) self.assertEqual(computed_functional, computed_transform, atol=atol, rtol=rtol)
def test_compute_deltas_twochannel(self): def test_compute_deltas_twochannel(self):
specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1) specgram = torch.tensor([1.0, 2.0, 3.0, 4.0]).repeat(1, 2, 1)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], [0.5, 1.0, 1.0, 0.5]]])
[0.5, 1.0, 1.0, 0.5]]])
transform = transforms.ComputeDeltas(win_length=3) transform = transforms.ComputeDeltas(win_length=3)
computed = transform(specgram) computed = transform(specgram)
assert computed.shape == expected.shape, (computed.shape, expected.shape) assert computed.shape == expected.shape, (computed.shape, expected.shape)
...@@ -299,7 +285,6 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -299,7 +285,6 @@ class Tester(common_utils.TorchaudioTestCase):
class SmokeTest(common_utils.TorchaudioTestCase): class SmokeTest(common_utils.TorchaudioTestCase):
def test_spectrogram(self): def test_spectrogram(self):
specgram = transforms.Spectrogram(center=False, pad_mode="reflect", onesided=False) specgram = transforms.Spectrogram(center=False, pad_mode="reflect", onesided=False)
self.assertEqual(specgram.center, False) self.assertEqual(specgram.center, False)
......
...@@ -38,15 +38,12 @@ class TransformsTestBase(TestBaseMixin): ...@@ -38,15 +38,12 @@ class TransformsTestBase(TestBaseMixin):
# Generate reference spectrogram and input mel-scaled spectrogram # Generate reference spectrogram and input mel-scaled spectrogram
expected = get_spectrogram( expected = get_spectrogram(
get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2), get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2), n_fft=n_fft, power=power
n_fft=n_fft, power=power).to(self.device, self.dtype) ).to(self.device, self.dtype)
input = T.MelScale( input = T.MelScale(n_mels=n_mels, sample_rate=sample_rate, n_stft=n_stft).to(self.device, self.dtype)(expected)
n_mels=n_mels, sample_rate=sample_rate, n_stft=n_stft
).to(self.device, self.dtype)(expected)
# Run transform # Run transform
transform = T.InverseMelScale( transform = T.InverseMelScale(n_stft, n_mels=n_mels, sample_rate=sample_rate).to(self.device, self.dtype)
n_stft, n_mels=n_mels, sample_rate=sample_rate).to(self.device, self.dtype)
torch.random.manual_seed(0) torch.random.manual_seed(0)
result = transform(input) result = transform(input)
...@@ -55,9 +52,7 @@ class TransformsTestBase(TestBaseMixin): ...@@ -55,9 +52,7 @@ class TransformsTestBase(TestBaseMixin):
relative_diff = torch.abs((result - expected) / (expected + epsilon)) relative_diff = torch.abs((result - expected) / (expected + epsilon))
for tol in [1e-1, 1e-3, 1e-5, 1e-10]: for tol in [1e-1, 1e-3, 1e-5, 1e-10]:
print( print(f"Ratio of relative diff smaller than {tol:e} is " f"{_get_ratio(relative_diff < tol)}")
f"Ratio of relative diff smaller than {tol:e} is "
f"{_get_ratio(relative_diff < tol)}")
assert _get_ratio(relative_diff < 1e-1) > 0.2 assert _get_ratio(relative_diff < 1e-1) > 0.2
assert _get_ratio(relative_diff < 1e-3) > 5e-3 assert _get_ratio(relative_diff < 1e-3) > 5e-3
assert _get_ratio(relative_diff < 1e-5) > 1e-5 assert _get_ratio(relative_diff < 1e-5) > 1e-5
...@@ -84,21 +79,23 @@ class TransformsTestBase(TestBaseMixin): ...@@ -84,21 +79,23 @@ class TransformsTestBase(TestBaseMixin):
assert transform.kernel.dtype == dtype if dtype is not None else torch.float32 assert transform.kernel.dtype == dtype if dtype is not None else torch.float32
@parameterized.expand([ @parameterized.expand(
param(n_fft=300, center=True, onesided=True), [
param(n_fft=400, center=True, onesided=False), param(n_fft=300, center=True, onesided=True),
param(n_fft=400, center=True, onesided=False), param(n_fft=400, center=True, onesided=False),
param(n_fft=300, center=True, onesided=False), param(n_fft=400, center=True, onesided=False),
param(n_fft=400, hop_length=10), param(n_fft=300, center=True, onesided=False),
param(n_fft=800, win_length=400, hop_length=20), param(n_fft=400, hop_length=10),
param(n_fft=800, win_length=400, hop_length=20, normalized=True), param(n_fft=800, win_length=400, hop_length=20),
param(), param(n_fft=800, win_length=400, hop_length=20, normalized=True),
param(n_fft=400, pad=32), param(),
# These tests do not work - cause runtime error param(n_fft=400, pad=32),
# See https://github.com/pytorch/pytorch/issues/62323 # These tests do not work - cause runtime error
# param(n_fft=400, center=False, onesided=True), # See https://github.com/pytorch/pytorch/issues/62323
# param(n_fft=400, center=False, onesided=False), # param(n_fft=400, center=False, onesided=True),
]) # param(n_fft=400, center=False, onesided=False),
]
)
def test_roundtrip_spectrogram(self, **args): def test_roundtrip_spectrogram(self, **args):
"""Test the spectrogram + inverse spectrogram results in approximate identity.""" """Test the spectrogram + inverse spectrogram results in approximate identity."""
...@@ -110,12 +107,14 @@ class TransformsTestBase(TestBaseMixin): ...@@ -110,12 +107,14 @@ class TransformsTestBase(TestBaseMixin):
restored = inv_s.forward(transformed, length=waveform.shape[-1]) restored = inv_s.forward(transformed, length=waveform.shape[-1])
self.assertEqual(waveform, restored, atol=1e-6, rtol=1e-6) self.assertEqual(waveform, restored, atol=1e-6, rtol=1e-6)
@parameterized.expand([ @parameterized.expand(
param(0.5, 1, True, False), [
param(0.5, 1, None, False), param(0.5, 1, True, False),
param(1, 4, True, True), param(0.5, 1, None, False),
param(1, 6, None, True), param(1, 4, True, True),
]) param(1, 6, None, True),
]
)
def test_psd(self, duration, channel, mask, multi_mask): def test_psd(self, duration, channel, mask, multi_mask):
"""Providing dtype changes the kernel cache dtype""" """Providing dtype changes the kernel cache dtype"""
transform = T.PSD(multi_mask) transform = T.PSD(multi_mask)
......
from torchaudio.utils import sox_utils from torchaudio.utils import sox_utils
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
PytorchTestCase, PytorchTestCase,
skipIfNoSox, skipIfNoSox,
...@@ -9,6 +8,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -9,6 +8,7 @@ from torchaudio_unittest.common_utils import (
@skipIfNoSox @skipIfNoSox
class TestSoxUtils(PytorchTestCase): class TestSoxUtils(PytorchTestCase):
"""Smoke tests for sox_util module""" """Smoke tests for sox_util module"""
def test_set_seed(self): def test_set_seed(self):
"""`set_seed` does not crush""" """`set_seed` does not crush"""
sox_utils.set_seed(0) sox_utils.set_seed(0)
...@@ -34,16 +34,16 @@ class TestSoxUtils(PytorchTestCase): ...@@ -34,16 +34,16 @@ class TestSoxUtils(PytorchTestCase):
"""`list_effects` returns the list of available effects""" """`list_effects` returns the list of available effects"""
effects = sox_utils.list_effects() effects = sox_utils.list_effects()
# We cannot infer what effects are available, so only check some of them. # We cannot infer what effects are available, so only check some of them.
assert 'highpass' in effects assert "highpass" in effects
assert 'phaser' in effects assert "phaser" in effects
assert 'gain' in effects assert "gain" in effects
def test_list_read_formats(self): def test_list_read_formats(self):
"""`list_read_formats` returns the list of supported formats""" """`list_read_formats` returns the list of supported formats"""
formats = sox_utils.list_read_formats() formats = sox_utils.list_read_formats()
assert 'wav' in formats assert "wav" in formats
def test_list_write_formats(self): def test_list_write_formats(self):
"""`list_write_formats` returns the list of supported formats""" """`list_write_formats` returns the list of supported formats"""
formats = sox_utils.list_write_formats() formats = sox_utils.list_write_formats()
assert 'opus' not in formats assert "opus" not in formats
...@@ -35,21 +35,15 @@ def _parse_args(): ...@@ -35,21 +35,15 @@ def _parse_args():
description=__doc__, description=__doc__,
formatter_class=argparse.RawTextHelpFormatter, formatter_class=argparse.RawTextHelpFormatter,
) )
parser.add_argument("--input-file", required=True, help="Input model file.")
parser.add_argument("--output-file", required=False, help="Output model file.")
parser.add_argument( parser.add_argument(
'--input-file', required=True, "--dict-dir",
help='Input model file.'
)
parser.add_argument(
'--output-file', required=False,
help='Output model file.'
)
parser.add_argument(
'--dict-dir',
help=( help=(
'Directory where letter vocabulary file, `dict.ltr.txt`, is found. ' "Directory where letter vocabulary file, `dict.ltr.txt`, is found. "
'Required when loading wav2vec2 model. ' "Required when loading wav2vec2 model. "
'https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt' "https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt"
) ),
) )
return parser.parse_args() return parser.parse_args()
...@@ -57,9 +51,10 @@ def _parse_args(): ...@@ -57,9 +51,10 @@ def _parse_args():
def _load_model(input_file, dict_dir): def _load_model(input_file, dict_dir):
import fairseq import fairseq
overrides = {} if dict_dir is None else {'data': dict_dir} overrides = {} if dict_dir is None else {"data": dict_dir}
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[input_file], arg_overrides=overrides, [input_file],
arg_overrides=overrides,
) )
return models[0] return models[0]
...@@ -67,7 +62,7 @@ def _load_model(input_file, dict_dir): ...@@ -67,7 +62,7 @@ def _load_model(input_file, dict_dir):
def _import_model(model): def _import_model(model):
from torchaudio.models.wav2vec2.utils import import_fairseq_model from torchaudio.models.wav2vec2.utils import import_fairseq_model
if model.__class__.__name__ in ['HubertCtc', 'Wav2VecCtc']: if model.__class__.__name__ in ["HubertCtc", "Wav2VecCtc"]:
model = model.w2v_encoder model = model.w2v_encoder
model = import_fairseq_model(model) model = import_fairseq_model(model)
return model return model
...@@ -75,10 +70,11 @@ def _import_model(model): ...@@ -75,10 +70,11 @@ def _import_model(model):
def _main(args): def _main(args):
import torch import torch
model = _load_model(args.input_file, args.dict_dir) model = _load_model(args.input_file, args.dict_dir)
model = _import_model(model) model = _import_model(model)
torch.save(model.state_dict(), args.output_file) torch.save(model.state_dict(), args.output_file)
if __name__ == '__main__': if __name__ == "__main__":
_main(_parse_args()) _main(_parse_args())
...@@ -19,24 +19,19 @@ python convert_voxpopuli_models.py \ ...@@ -19,24 +19,19 @@ python convert_voxpopuli_models.py \
def _parse_args(): def _parse_args():
import argparse import argparse
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=__doc__, description=__doc__,
formatter_class=argparse.RawTextHelpFormatter, formatter_class=argparse.RawTextHelpFormatter,
) )
parser.add_argument( parser.add_argument("--input-file", required=True, help="Input checkpoint file.")
'--input-file', required=True, parser.add_argument("--output-file", required=False, help="Output model file.")
help='Input checkpoint file.'
)
parser.add_argument(
'--output-file', required=False,
help='Output model file.'
)
return parser.parse_args() return parser.parse_args()
def _removeprefix(s, prefix): def _removeprefix(s, prefix):
if s.startswith(prefix): if s.startswith(prefix):
return s[len(prefix):] return s[len(prefix) :]
return s return s
...@@ -45,13 +40,13 @@ def _load(input_file): ...@@ -45,13 +40,13 @@ def _load(input_file):
from omegaconf import OmegaConf from omegaconf import OmegaConf
data = torch.load(input_file) data = torch.load(input_file)
cfg = OmegaConf.to_container(data['cfg']) cfg = OmegaConf.to_container(data["cfg"])
for key in list(cfg.keys()): for key in list(cfg.keys()):
if key != 'model': if key != "model":
del cfg[key] del cfg[key]
if 'w2v_args' in cfg['model']: if "w2v_args" in cfg["model"]:
del cfg['model']['w2v_args'][key] del cfg["model"]["w2v_args"][key]
state_dict = {_removeprefix(k, 'w2v_encoder.'): v for k, v in data['model'].items()} state_dict = {_removeprefix(k, "w2v_encoder."): v for k, v in data["model"].items()}
return cfg, state_dict return cfg, state_dict
...@@ -75,9 +70,9 @@ def _parse_model_param(cfg, state_dict): ...@@ -75,9 +70,9 @@ def _parse_model_param(cfg, state_dict):
"encoder_layerdrop": "encoder_layer_drop", "encoder_layerdrop": "encoder_layer_drop",
} }
params = {} params = {}
src_dicts = [cfg['model']] src_dicts = [cfg["model"]]
if 'w2v_args' in cfg['model']: if "w2v_args" in cfg["model"]:
src_dicts.append(cfg['model']['w2v_args']['model']) src_dicts.append(cfg["model"]["w2v_args"]["model"])
for src, tgt in key_mapping.items(): for src, tgt in key_mapping.items():
for model_cfg in src_dicts: for model_cfg in src_dicts:
...@@ -89,12 +84,13 @@ def _parse_model_param(cfg, state_dict): ...@@ -89,12 +84,13 @@ def _parse_model_param(cfg, state_dict):
# the following line is commented out to resolve lint warning; uncomment before running script # the following line is commented out to resolve lint warning; uncomment before running script
# params["extractor_conv_layer_config"] = eval(params["extractor_conv_layer_config"]) # params["extractor_conv_layer_config"] = eval(params["extractor_conv_layer_config"])
assert len(params) == 15 assert len(params) == 15
params['aux_num_out'] = state_dict['proj.bias'].numel() if 'proj.bias' in state_dict else None params["aux_num_out"] = state_dict["proj.bias"].numel() if "proj.bias" in state_dict else None
return params return params
def _main(args): def _main(args):
import json import json
import torch import torch
import torchaudio import torchaudio
from torchaudio.models.wav2vec2.utils.import_fairseq import _convert_state_dict as _convert from torchaudio.models.wav2vec2.utils.import_fairseq import _convert_state_dict as _convert
...@@ -107,5 +103,5 @@ def _main(args): ...@@ -107,5 +103,5 @@ def _main(args):
torch.save(model.state_dict(), args.output_file) torch.save(model.state_dict(), args.output_file)
if __name__ == '__main__': if __name__ == "__main__":
_main(_parse_args()) _main(_parse_args())
...@@ -24,11 +24,11 @@ Features = namedtuple( ...@@ -24,11 +24,11 @@ Features = namedtuple(
def _run_cmd(cmd): def _run_cmd(cmd):
return subprocess.check_output(cmd).decode('utf-8').strip() return subprocess.check_output(cmd).decode("utf-8").strip()
def commit_title(commit_hash): def commit_title(commit_hash):
cmd = ['git', 'log', '-n', '1', '--pretty=format:%s', f'{commit_hash}'] cmd = ["git", "log", "-n", "1", "--pretty=format:%s", f"{commit_hash}"]
return _run_cmd(cmd) return _run_cmd(cmd)
...@@ -95,12 +95,12 @@ def get_features(commit_hash): ...@@ -95,12 +95,12 @@ def get_features(commit_hash):
def get_commits_between(base_version, new_version): def get_commits_between(base_version, new_version):
cmd = ['git', 'merge-base', f'{base_version}', f'{new_version}'] cmd = ["git", "merge-base", f"{base_version}", f"{new_version}"]
merge_base = _run_cmd(cmd) merge_base = _run_cmd(cmd)
# Returns a list of items in the form # Returns a list of items in the form
# a7854f33 Add HuBERT model architectures (#1769) # a7854f33 Add HuBERT model architectures (#1769)
cmd = ['git', 'log', '--reverse', '--oneline', f'{merge_base}..{new_version}'] cmd = ["git", "log", "--reverse", "--oneline", f"{merge_base}..{new_version}"]
commits = _run_cmd(cmd) commits = _run_cmd(cmd)
log_lines = commits.split("\n") log_lines = commits.split("\n")
......
import distutils.sysconfig
import os import os
import platform import platform
import subprocess import subprocess
from pathlib import Path from pathlib import Path
import distutils.sysconfig
import torch
from setuptools import Extension from setuptools import Extension
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
import torch
__all__ = [ __all__ = [
'get_ext_modules', "get_ext_modules",
'CMakeBuild', "CMakeBuild",
] ]
_THIS_DIR = Path(__file__).parent.resolve() _THIS_DIR = Path(__file__).parent.resolve()
_ROOT_DIR = _THIS_DIR.parent.parent.resolve() _ROOT_DIR = _THIS_DIR.parent.parent.resolve()
_TORCHAUDIO_DIR = _ROOT_DIR / 'torchaudio' _TORCHAUDIO_DIR = _ROOT_DIR / "torchaudio"
def _get_build(var, default=False): def _get_build(var, default=False):
if var not in os.environ: if var not in os.environ:
return default return default
val = os.environ.get(var, '0') val = os.environ.get(var, "0")
trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES'] trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"]
falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO'] falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"]
if val in trues: if val in trues:
return True return True
if val not in falses: if val not in falses:
print( print(f"WARNING: Unexpected environment variable value `{var}={val}`. " f"Expected one of {trues + falses}")
f'WARNING: Unexpected environment variable value `{var}={val}`. '
f'Expected one of {trues + falses}')
return False return False
_BUILD_SOX = False if platform.system() == 'Windows' else _get_build("BUILD_SOX", True) _BUILD_SOX = False if platform.system() == "Windows" else _get_build("BUILD_SOX", True)
_BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KALDI", True) _BUILD_KALDI = False if platform.system() == "Windows" else _get_build("BUILD_KALDI", True)
_BUILD_RNNT = _get_build("BUILD_RNNT", True) _BUILD_RNNT = _get_build("BUILD_RNNT", True)
_BUILD_CTC_DECODER = False if platform.system() == 'Windows' else _get_build("BUILD_CTC_DECODER", True) _BUILD_CTC_DECODER = False if platform.system() == "Windows" else _get_build("BUILD_CTC_DECODER", True)
_USE_ROCM = _get_build("USE_ROCM", torch.cuda.is_available() and torch.version.hip is not None) _USE_ROCM = _get_build("USE_ROCM", torch.cuda.is_available() and torch.version.hip is not None)
_USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available() and torch.version.hip is None) _USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available() and torch.version.hip is None)
_USE_OPENMP = _get_build("USE_OPENMP", True) and \ _USE_OPENMP = _get_build("USE_OPENMP", True) and "ATen parallel backend: OpenMP" in torch.__config__.parallel_info()
'ATen parallel backend: OpenMP' in torch.__config__.parallel_info() _TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
_TORCH_CUDA_ARCH_LIST = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
def get_ext_modules(): def get_ext_modules():
modules = [ modules = [
Extension(name='torchaudio.lib.libtorchaudio', sources=[]), Extension(name="torchaudio.lib.libtorchaudio", sources=[]),
Extension(name='torchaudio._torchaudio', sources=[]), Extension(name="torchaudio._torchaudio", sources=[]),
] ]
if _BUILD_CTC_DECODER: if _BUILD_CTC_DECODER:
modules.extend([ modules.extend(
Extension(name='torchaudio.lib.libtorchaudio_decoder', sources=[]), [
Extension(name='torchaudio._torchaudio_decoder', sources=[]), Extension(name="torchaudio.lib.libtorchaudio_decoder", sources=[]),
]) Extension(name="torchaudio._torchaudio_decoder", sources=[]),
]
)
return modules return modules
...@@ -63,7 +62,7 @@ def get_ext_modules(): ...@@ -63,7 +62,7 @@ def get_ext_modules():
class CMakeBuild(build_ext): class CMakeBuild(build_ext):
def run(self): def run(self):
try: try:
subprocess.check_output(['cmake', '--version']) subprocess.check_output(["cmake", "--version"])
except OSError: except OSError:
raise RuntimeError("CMake is not available.") from None raise RuntimeError("CMake is not available.") from None
super().run() super().run()
...@@ -75,11 +74,10 @@ class CMakeBuild(build_ext): ...@@ -75,11 +74,10 @@ class CMakeBuild(build_ext):
# However, the following `cmake` command will build all of them at the same time, # However, the following `cmake` command will build all of them at the same time,
# so, we do not need to perform `cmake` twice. # so, we do not need to perform `cmake` twice.
# Therefore we call `cmake` only for `torchaudio._torchaudio`. # Therefore we call `cmake` only for `torchaudio._torchaudio`.
if ext.name != 'torchaudio._torchaudio': if ext.name != "torchaudio._torchaudio":
return return
extdir = os.path.abspath( extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
os.path.dirname(self.get_ext_fullpath(ext.name)))
# required for auto-detection of auxiliary "native" libs # required for auto-detection of auxiliary "native" libs
if not extdir.endswith(os.path.sep): if not extdir.endswith(os.path.sep):
...@@ -91,7 +89,7 @@ class CMakeBuild(build_ext): ...@@ -91,7 +89,7 @@ class CMakeBuild(build_ext):
f"-DCMAKE_BUILD_TYPE={cfg}", f"-DCMAKE_BUILD_TYPE={cfg}",
f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}", f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}",
f"-DCMAKE_INSTALL_PREFIX={extdir}", f"-DCMAKE_INSTALL_PREFIX={extdir}",
'-DCMAKE_VERBOSE_MAKEFILE=ON', "-DCMAKE_VERBOSE_MAKEFILE=ON",
f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}", f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}",
f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}", f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}", f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}",
...@@ -102,22 +100,21 @@ class CMakeBuild(build_ext): ...@@ -102,22 +100,21 @@ class CMakeBuild(build_ext):
f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}", f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}",
f"-DUSE_OPENMP:BOOL={'ON' if _USE_OPENMP else 'OFF'}", f"-DUSE_OPENMP:BOOL={'ON' if _USE_OPENMP else 'OFF'}",
] ]
build_args = [ build_args = ["--target", "install"]
'--target', 'install'
]
# Pass CUDA architecture to cmake # Pass CUDA architecture to cmake
if _TORCH_CUDA_ARCH_LIST is not None: if _TORCH_CUDA_ARCH_LIST is not None:
# Convert MAJOR.MINOR[+PTX] list to new style one # Convert MAJOR.MINOR[+PTX] list to new style one
# defined at https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html # defined at https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html
_arches = _TORCH_CUDA_ARCH_LIST.replace('.', '').replace(' ', ';').split(";") _arches = _TORCH_CUDA_ARCH_LIST.replace(".", "").replace(" ", ";").split(";")
_arches = [arch[:-4] if arch.endswith("+PTX") else f"{arch}-real" for arch in _arches] _arches = [arch[:-4] if arch.endswith("+PTX") else f"{arch}-real" for arch in _arches]
cmake_args += [f"-DCMAKE_CUDA_ARCHITECTURES={';'.join(_arches)}"] cmake_args += [f"-DCMAKE_CUDA_ARCHITECTURES={';'.join(_arches)}"]
# Default to Ninja # Default to Ninja
if 'CMAKE_GENERATOR' not in os.environ or platform.system() == 'Windows': if "CMAKE_GENERATOR" not in os.environ or platform.system() == "Windows":
cmake_args += ["-GNinja"] cmake_args += ["-GNinja"]
if platform.system() == 'Windows': if platform.system() == "Windows":
import sys import sys
python_version = sys.version_info python_version = sys.version_info
cmake_args += [ cmake_args += [
"-DCMAKE_C_COMPILER=cl", "-DCMAKE_C_COMPILER=cl",
...@@ -137,14 +134,12 @@ class CMakeBuild(build_ext): ...@@ -137,14 +134,12 @@ class CMakeBuild(build_ext):
if not os.path.exists(self.build_temp): if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp) os.makedirs(self.build_temp)
subprocess.check_call( subprocess.check_call(["cmake", str(_ROOT_DIR)] + cmake_args, cwd=self.build_temp)
["cmake", str(_ROOT_DIR)] + cmake_args, cwd=self.build_temp) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
subprocess.check_call(
["cmake", "--build", "."] + build_args, cwd=self.build_temp)
def get_ext_filename(self, fullname): def get_ext_filename(self, fullname):
ext_filename = super().get_ext_filename(fullname) ext_filename = super().get_ext_filename(fullname)
ext_filename_parts = ext_filename.split('.') ext_filename_parts = ext_filename.split(".")
without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:] without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:]
ext_filename = '.'.join(without_abi) ext_filename = ".".join(without_abi)
return ext_filename return ext_filename
...@@ -10,7 +10,6 @@ from torchaudio import ( ...@@ -10,7 +10,6 @@ from torchaudio import (
sox_effects, sox_effects,
transforms, transforms,
) )
from torchaudio.backend import ( from torchaudio.backend import (
list_audio_backends, list_audio_backends,
get_audio_backend, get_audio_backend,
...@@ -23,16 +22,16 @@ except ImportError: ...@@ -23,16 +22,16 @@ except ImportError:
pass pass
# Public names re-exported from the ``torchaudio`` package.
__all__ = (
    "compliance datasets functional models pipelines kaldi_io utils "
    "sox_effects transforms list_audio_backends get_audio_backend set_audio_backend"
).split()
...@@ -5,12 +5,12 @@ from pathlib import Path ...@@ -5,12 +5,12 @@ from pathlib import Path
import torch import torch
from torchaudio._internal import module_utils as _mod_utils # noqa: F401 from torchaudio._internal import module_utils as _mod_utils # noqa: F401
_LIB_DIR = Path(__file__).parent / 'lib' _LIB_DIR = Path(__file__).parent / "lib"
def _get_lib_path(lib: str): def _get_lib_path(lib: str):
suffix = 'pyd' if os.name == 'nt' else 'so' suffix = "pyd" if os.name == "nt" else "so"
path = _LIB_DIR / f'{lib}.{suffix}' path = _LIB_DIR / f"{lib}.{suffix}"
return path return path
...@@ -26,11 +26,11 @@ def _load_lib(lib: str): ...@@ -26,11 +26,11 @@ def _load_lib(lib: str):
def _init_extension():
    """Load the torchaudio native extension, warning (not raising) when absent."""
    if _mod_utils.is_module_available("torchaudio._torchaudio"):
        _load_lib("libtorchaudio")
        # Importing the module initializes the methods registered via PyBind11;
        # it must happen after the base library above is loaded.
        from torchaudio import _torchaudio  # noqa
    else:
        warnings.warn("torchaudio C++ extension is not available.")
......
import warnings
import importlib.util import importlib.util
from typing import Optional import warnings
from functools import wraps from functools import wraps
from typing import Optional
import torch import torch
...@@ -28,14 +28,17 @@ def requires_module(*modules: str): ...@@ -28,14 +28,17 @@ def requires_module(*modules: str):
# fall through. If all the modules are available, no need to decorate # fall through. If all the modules are available, no need to decorate
def decorator(func): def decorator(func):
return func return func
else: else:
req = f'module: {missing[0]}' if len(missing) == 1 else f'modules: {missing}' req = f"module: {missing[0]}" if len(missing) == 1 else f"modules: {missing}"
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires {req}') raise RuntimeError(f"{func.__module__}.{func.__name__} requires {req}")
return wrapped return wrapped
return decorator return decorator
...@@ -46,42 +49,51 @@ def deprecated(direction: str, version: Optional[str] = None): ...@@ -46,42 +49,51 @@ def deprecated(direction: str, version: Optional[str] = None):
direction (str): Migration steps to be given to users. direction (str): Migration steps to be given to users.
version (str or int): The version when the object will be removed version (str or int): The version when the object will be removed
""" """
def decorator(func):
def decorator(func):
@wraps(func) @wraps(func)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
message = ( message = (
f'{func.__module__}.{func.__name__} has been deprecated ' f"{func.__module__}.{func.__name__} has been deprecated "
f'and will be removed from {"future" if version is None else version} release. ' f'and will be removed from {"future" if version is None else version} release. '
f'{direction}') f"{direction}"
)
warnings.warn(message, stacklevel=2) warnings.warn(message, stacklevel=2)
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapped return wrapped
return decorator return decorator
def is_kaldi_available():
    """Return True when the C++ extension is importable and reports Kaldi support."""
    if not is_module_available("torchaudio._torchaudio"):
        return False
    return torch.ops.torchaudio.is_kaldi_available()
def requires_kaldi():
    """Decorator factory gating a function on Kaldi support.

    With Kaldi available the decorator is a pass-through; otherwise the
    wrapped function raises ``RuntimeError`` when called. Availability is
    checked once, at decoration time.
    """
    if not is_kaldi_available():
        def decorator(func):
            @wraps(func)
            def wrapped(*args, **kwargs):
                raise RuntimeError(f"{func.__module__}.{func.__name__} requires kaldi")

            return wrapped

        return decorator

    def decorator(func):
        return func

    return decorator
def _check_soundfile_importable(): def _check_soundfile_importable():
if not is_module_available('soundfile'): if not is_module_available("soundfile"):
return False return False
try: try:
import soundfile # noqa: F401 import soundfile # noqa: F401
return True return True
except Exception: except Exception:
warnings.warn("Failed to import soundfile. 'soundfile' backend is not available.") warnings.warn("Failed to import soundfile. 'soundfile' backend is not available.")
...@@ -97,29 +109,39 @@ def is_soundfile_available(): ...@@ -97,29 +109,39 @@ def is_soundfile_available():
def requires_soundfile():
    """Decorator factory gating a function on the soundfile backend.

    Pass-through when soundfile is available; otherwise the wrapped function
    raises ``RuntimeError`` when called. Availability is checked once, at
    decoration time.
    """
    if not is_soundfile_available():
        def decorator(func):
            @wraps(func)
            def wrapped(*args, **kwargs):
                raise RuntimeError(f"{func.__module__}.{func.__name__} requires soundfile")

            return wrapped

        return decorator

    def decorator(func):
        return func

    return decorator
def is_sox_available():
    """Return True when the C++ extension is importable and reports sox support."""
    if not is_module_available("torchaudio._torchaudio"):
        return False
    return torch.ops.torchaudio.is_sox_available()
def requires_sox():
    """Decorator factory gating a function on sox support.

    Pass-through when sox is available; otherwise the wrapped function raises
    ``RuntimeError`` when called. Availability is checked once, at decoration
    time.
    """
    if not is_sox_available():
        def decorator(func):
            @wraps(func)
            def wrapped(*args, **kwargs):
                raise RuntimeError(f"{func.__module__}.{func.__name__} requires sox")

            return wrapped

        return decorator

    def decorator(func):
        return func

    return decorator
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment