Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`
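The `arc lint` pass applies Meta-internal Python linters. Judging from the changes below (double-quoted strings, magic trailing commas, sorted imports), the result matches what black plus usort produce; a rough, hypothetical stand-in outside that toolchain might look like the following sketch (the tool choice and line length are assumptions, not stated in this PR):

```
# Hypothetical approximation of the arc lint run above; the tools and the
# line-length value are assumptions, not part of this PR.
pip install black usort
usort format .               # sort and group imports in place
black --line-length 120 .    # reformat every Python file in place
```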

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
"""Test suites for checking numerical compatibility against Kaldi"""
import torchaudio.compliance.kaldi
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TestBaseMixin,
TempDirMixin,
@@ -21,35 +20,35 @@ class Kaldi(TempDirMixin, TestBaseMixin):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)
@parameterized.expand(load_params("kaldi_test_fbank_args.jsonl"))
@skipIfNoExec("compute-fbank-feats")
def test_fbank(self, kwargs):
"""fbank should be numerically compatible with compute-fbank-feats"""
wave_file = get_asset_path("kaldi_file.wav")
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
command = ["compute-fbank-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
kaldi_result = run_kaldi(command, "scp", wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(load_params("kaldi_test_spectrogram_args.jsonl"))
@skipIfNoExec("compute-spectrogram-feats")
def test_spectrogram(self, kwargs):
"""spectrogram should be numerically compatible with compute-spectrogram-feats"""
wave_file = get_asset_path("kaldi_file.wav")
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs)
command = ["compute-spectrogram-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
kaldi_result = run_kaldi(command, "scp", wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(load_params("kaldi_test_mfcc_args.jsonl"))
@skipIfNoExec("compute-mfcc-feats")
def test_mfcc(self, kwargs):
"""mfcc should be numerically compatible with compute-mfcc-feats"""
wave_file = get_asset_path("kaldi_file.wav")
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs)
command = ["compute-mfcc-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
kaldi_result = run_kaldi(command, "scp", wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .librosa_compatibility_test_impl import TransformsTestBase
class TestTransforms(TransformsTestBase, PytorchTestCase):
dtype = torch.float64
device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .librosa_compatibility_test_impl import TransformsTestBase
@skipIfNoCuda
class TestTransforms(TransformsTestBase, PytorchTestCase):
dtype = torch.float64
device = torch.device("cuda")
@@ -2,9 +2,8 @@ import unittest
import torch
import torchaudio.transforms as T
from parameterized import param, parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
@@ -13,7 +12,7 @@ from torchaudio_unittest.common_utils import (
nested_params,
)
LIBROSA_AVAILABLE = is_module_available("librosa")
if LIBROSA_AVAILABLE:
import librosa
@@ -21,25 +20,28 @@ if LIBROSA_AVAILABLE:
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TransformsTestBase(TestBaseMixin):
@parameterized.expand(
[
param(n_fft=400, hop_length=200, power=2.0),
param(n_fft=600, hop_length=100, power=2.0),
param(n_fft=400, hop_length=200, power=3.0),
param(n_fft=200, hop_length=50, power=2.0),
]
)
def test_Spectrogram(self, n_fft, hop_length, power):
sample_rate = 16000
waveform = get_whitenoise(
sample_rate=sample_rate,
n_channels=1,
).to(self.device, self.dtype)
expected = librosa.core.spectrum._spectrogram(
y=waveform[0].cpu().numpy(), n_fft=n_fft, hop_length=hop_length, power=power
)[0]
result = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power,).to(self.device, self.dtype)(
waveform
)[0]
self.assertEqual(result, torch.from_numpy(expected), atol=1e-5, rtol=1e-5)
def test_Spectrogram_complex(self):
@@ -47,16 +49,17 @@ class TransformsTestBase(TestBaseMixin):
hop_length = 200
sample_rate = 16000
waveform = get_whitenoise(
sample_rate=sample_rate,
n_channels=1,
).to(self.device, self.dtype)
expected = librosa.core.spectrum._spectrogram(
y=waveform[0].cpu().numpy(), n_fft=n_fft, hop_length=hop_length, power=1
)[0]
result = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None, return_complex=True,).to(
self.device, self.dtype
)(waveform)[0]
self.assertEqual(result.abs(), torch.from_numpy(expected), atol=1e-5, rtol=1e-5)
@nested_params(
@@ -65,77 +68,95 @@ class TransformsTestBase(TestBaseMixin):
param(n_fft=600, hop_length=100, n_mels=128),
param(n_fft=200, hop_length=50, n_mels=32),
],
[param(norm=norm) for norm in [None, "slaney"]],
[param(mel_scale=mel_scale) for mel_scale in ["htk", "slaney"]],
)
def test_MelSpectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale):
sample_rate = 16000
waveform = get_sinusoid(
sample_rate=sample_rate,
n_channels=1,
).to(self.device, self.dtype)
expected = librosa.feature.melspectrogram(
y=waveform[0].cpu().numpy(),
sr=sample_rate,
n_fft=n_fft,
hop_length=hop_length,
n_mels=n_mels,
norm=norm,
htk=mel_scale == "htk",
)
result = T.MelSpectrogram(
sample_rate=sample_rate,
window_fn=torch.hann_window,
hop_length=hop_length,
n_mels=n_mels,
n_fft=n_fft,
norm=norm,
mel_scale=mel_scale,
).to(self.device, self.dtype)(waveform)[0]
self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5)
def test_magnitude_to_db(self):
spectrogram = get_spectrogram(get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype)
result = T.AmplitudeToDB("magnitude", 80.0).to(self.device, self.dtype)(spectrogram)[0]
expected = librosa.core.spectrum.amplitude_to_db(spectrogram[0].cpu().numpy())
self.assertEqual(result, torch.from_numpy(expected))
def test_power_to_db(self):
spectrogram = get_spectrogram(get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype)
result = T.AmplitudeToDB("power", 80.0).to(self.device, self.dtype)(spectrogram)[0]
expected = librosa.core.spectrum.power_to_db(spectrogram[0].cpu().numpy())
self.assertEqual(result, torch.from_numpy(expected))
@nested_params(
[
param(n_fft=400, hop_length=200, n_mels=64, n_mfcc=40),
param(n_fft=600, hop_length=100, n_mels=128, n_mfcc=20),
param(n_fft=200, hop_length=50, n_mels=32, n_mfcc=25),
]
)
def test_mfcc(self, n_fft, hop_length, n_mels, n_mfcc):
sample_rate = 16000
waveform = get_whitenoise(sample_rate=sample_rate, n_channels=1).to(self.device, self.dtype)
result = T.MFCC(
sample_rate=sample_rate,
n_mfcc=n_mfcc,
norm="ortho",
melkwargs={"hop_length": hop_length, "n_fft": n_fft, "n_mels": n_mels},
).to(self.device, self.dtype)(waveform)[0]
melspec = librosa.feature.melspectrogram(
y=waveform[0].cpu().numpy(),
sr=sample_rate,
n_fft=n_fft,
win_length=n_fft,
hop_length=hop_length,
n_mels=n_mels,
htk=True,
norm=None,
)
expected = librosa.feature.mfcc(
S=librosa.core.spectrum.power_to_db(melspec), n_mfcc=n_mfcc, dct_type=2, norm="ortho"
)
self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5)
@parameterized.expand(
[
param(n_fft=400, hop_length=200),
param(n_fft=600, hop_length=100),
param(n_fft=200, hop_length=50),
]
)
def test_spectral_centroid(self, n_fft, hop_length):
sample_rate = 16000
waveform = get_whitenoise(sample_rate=sample_rate, n_channels=1).to(self.device, self.dtype)
result = T.SpectralCentroid(sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,).to(
self.device, self.dtype
)(waveform)
expected = librosa.feature.spectral_centroid(
y=waveform[0].cpu().numpy(), sr=sample_rate, n_fft=n_fft, hop_length=hop_length
)
self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5)
@@ -3,7 +3,6 @@ import warnings
import torch
import torchaudio.transforms as T
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
skipIfNoSox,
skipIfNoExec,
@@ -18,10 +17,10 @@ from torchaudio_unittest.common_utils import (
@skipIfNoSox
@skipIfNoExec("sox")
class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
def run_sox_effect(self, input_file, effect):
output_file = self.get_temp_path("expected.wav")
sox_utils.run_sox_effect(input_file, output_file, [str(e) for e in effect])
return load_wav(output_file)
@@ -31,39 +30,45 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
def get_whitenoise(self, sample_rate=8000):
noise = get_whitenoise(
sample_rate=sample_rate,
duration=3,
scale_factor=0.9,
)
path = self.get_temp_path("whitenoise.wav")
save_wav(path, noise, sample_rate)
return noise, path
@parameterized.expand(
[
("q", "quarter_sine"),
("h", "half_sine"),
("t", "linear"),
]
)
def test_fade(self, fade_shape_sox, fade_shape):
fade_in_len, fade_out_len = 44100, 44100
data, path = self.get_whitenoise(sample_rate=44100)
result = T.Fade(fade_in_len, fade_out_len, fade_shape)(data)
self.assert_sox_effect(result, path, ["fade", fade_shape_sox, "1", "0", "1"])
@parameterized.expand(
[
("amplitude", 1.1),
("db", 2),
("power", 2),
]
)
def test_vol(self, gain_type, gain):
data, path = self.get_whitenoise()
result = T.Vol(gain, gain_type)(data)
self.assert_sox_effect(result, path, ["vol", f"{gain}", gain_type])
@parameterized.expand(["vad-go-stereo-44100.wav", "vad-go-mono-32000.wav"])
def test_vad(self, filename):
path = get_asset_path(filename)
data, sample_rate = load_wav(path)
result = T.Vad(sample_rate)(data)
self.assert_sox_effect(result, path, ["vad"])
def test_vad_warning(self):
"""vad should throw a warning if input dimension is greater than 2"""
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only
class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
dtype = torch.float32
device = torch.device("cpu")
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only
@skipIfNoCuda
class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
dtype = torch.float32
device = torch.device("cuda")
@skipIfNoCuda
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device("cuda")
@@ -3,7 +3,6 @@
import torch
import torchaudio.transforms as T
from parameterized import parameterized
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
skipIfRocm,
@@ -14,6 +13,7 @@ from torchaudio_unittest.common_utils import (
class Transforms(TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
def _assert_consistency(self, transform, tensor, *args):
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)
@@ -139,10 +139,7 @@ class Transforms(TestBaseMixin):
sample_rate = 8000
n_steps = 4
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.PitchShift(sample_rate=sample_rate, n_steps=n_steps), waveform)
def test_PSD(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
@@ -157,33 +154,34 @@ class Transforms(TestBaseMixin):
mask = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.PSD(), spectrogram, mask)
@parameterized.expand(
[
["ref_channel", True],
["stv_evd", True],
["stv_power", True],
["ref_channel", False],
["stv_evd", False],
["stv_power", False],
]
)
def test_MVDR(self, solution, online):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
mask_s = torch.rand(spectrogram.shape[-2:], device=self.device)
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n)
class TransformsFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
logits = torch.tensor(
[
[
[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1], [0.1, 0.1, 0.2, 0.8, 0.1]],
[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]],
]
]
)
tensor = logits.to(device=self.device, dtype=torch.float32)
targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .transforms_test_impl import TransformsTestBase
class TransformsCPUFloat32Test(TransformsTestBase, PytorchTestCase):
device = "cpu"
dtype = torch.float32
class TransformsCPUFloat64Test(TransformsTestBase, PytorchTestCase):
device = "cpu"
dtype = torch.float64
import torch
from torchaudio_unittest.common_utils import (
PytorchTestCase,
skipIfNoCuda,
)
from .transforms_test_impl import TransformsTestBase
@skipIfNoCuda
class TransformsCUDAFloat32Test(TransformsTestBase, PytorchTestCase):
device = "cuda"
dtype = torch.float32
@skipIfNoCuda
class TransformsCUDAFloat64Test(TransformsTestBase, PytorchTestCase):
device = "cuda"
dtype = torch.float64
@@ -2,24 +2,23 @@ import math
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as transforms
from torchaudio_unittest import common_utils
class Tester(common_utils.TorchaudioTestCase):
backend = "default"
# create a sinewave signal for testing
sample_rate = 16000
freq = 440
volume = 0.3
waveform = torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate)
waveform.unsqueeze_(0) # (1, 64000)
waveform = (waveform * volume * 2 ** 31).long()
def scale(self, waveform, factor=2.0 ** 31):
# scales a waveform by a factor
if not waveform.is_floating_point():
waveform = waveform.to(torch.get_default_dtype())
@@ -34,20 +33,20 @@ class Tester(common_utils.TorchaudioTestCase):
waveform = waveform.to(torch.get_default_dtype())
waveform /= torch.abs(waveform).max()
self.assertTrue(waveform.min() >= -1.0 and waveform.max() <= 1.0)
waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
self.assertTrue(waveform_mu.min() >= 0.0 and waveform_mu.max() <= quantization_channels)
waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
self.assertTrue(waveform_exp.min() >= -1.0 and waveform_exp.max() <= 1.0)
def test_AmplitudeToDB(self):
filepath = common_utils.get_asset_path("steam-train-whistle-daniel_simon.wav")
waveform = common_utils.load_wav(filepath)[0]
mag_to_db_transform = transforms.AmplitudeToDB("magnitude", 80.0)
power_to_db_transform = transforms.AmplitudeToDB("power", 80.0)
mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))
@@ -88,8 +87,8 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertEqual(fb, fb_copy)
def test_mel2(self):
top_db = 80.0
s2db = transforms.AmplitudeToDB("power", top_db)
waveform = self.waveform.clone() # (1, 16000)
waveform_scaled = self.scale(waveform) # (1, 16000)
@@ -100,20 +99,26 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
# check correctness of filterbank conversion matrix
self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.0).all())
self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.0).all())
# check options
kwargs = {
"window_fn": torch.hamming_window,
"pad": 10,
"win_length": 500,
"hop_length": 125,
"n_fft": 800,
"n_mels": 50,
}
mel_transform2 = transforms.MelSpectrogram(**kwargs)
spectrogram2_torch = s2db(mel_transform2(waveform_scaled)) # (1, 50, 513)
self.assertTrue(spectrogram2_torch.dim() == 3)
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.0).all())
self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.0).all())
# check on multi-channel audio
filepath = common_utils.get_asset_path("steam-train-whistle-daniel_simon.wav")
x_stereo = common_utils.load_wav(filepath)[0] # (2, 278756), 44100
spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394)
self.assertTrue(spectrogram_stereo.dim() == 3)
@@ -121,57 +126,46 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
# check filterbank matrix creation
fb_matrix_transform = transforms.MelScale(n_mels=100, sample_rate=16000, f_min=0.0, f_max=None, n_stft=400)
self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.0).all())
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.0).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
def test_mfcc_defaults(self):
"""Check the default configuration of the MFCC transform.
"""
"""Check the default configuration of the MFCC transform."""
sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_mfcc = 40
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, norm="ortho")
torch_mfcc = mfcc_transform(audio) # (1, 40, 81)
self.assertEqual(torch_mfcc.dim(), 3)
self.assertEqual(torch_mfcc.shape[1], n_mfcc)
self.assertEqual(torch_mfcc.shape[2], 81)
def test_mfcc_kwargs_passthrough(self):
"""Check kwargs get correctly passed to the MelSpectrogram transform.
"""
"""Check kwargs get correctly passed to the MelSpectrogram transform."""
sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_mfcc = 40
melkwargs = {"win_length": 200}
mfcc_transform = torchaudio.transforms.MFCC(
sample_rate=sample_rate, n_mfcc=n_mfcc, norm="ortho", melkwargs=melkwargs
)
torch_mfcc = mfcc_transform(audio) # (1, 40, 161)
self.assertEqual(torch_mfcc.shape[2], 161)
def test_mfcc_norms(self):
"""Check if MFCC-DCT norms work correctly.
"""
"""Check if MFCC-DCT norms work correctly."""
sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_mfcc = 40
n_mels = 128
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, norm="ortho")
# check norms work correctly
mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, norm=None)
torch_mfcc_norm_none = mfcc_transform_norm_none(audio) # (1, 40, 81)
norm_check = mfcc_transform(audio)
@@ -181,56 +175,48 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertEqual(torch_mfcc_norm_none, norm_check)
def test_lfcc_defaults(self):
"""Check default settings for LFCC transform.
"""
"""Check default settings for LFCC transform."""
sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_lfcc = 40
n_filter = 128
lfcc_transform = torchaudio.transforms.LFCC(
sample_rate=sample_rate, n_filter=n_filter, n_lfcc=n_lfcc, norm="ortho"
)
torch_lfcc = lfcc_transform(audio) # (1, 40, 81)
self.assertEqual(torch_lfcc.dim(), 3)
self.assertEqual(torch_lfcc.shape[1], n_lfcc)
self.assertEqual(torch_lfcc.shape[2], 81)
def test_lfcc_arg_passthrough(self):
"""Check if kwargs get correctly passed to the underlying Spectrogram transform.
"""
"""Check if kwargs get correctly passed to the underlying Spectrogram transform."""
sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_lfcc = 40
n_filter = 128
speckwargs = {"win_length": 200}
lfcc_transform = torchaudio.transforms.LFCC(
sample_rate=sample_rate, n_filter=n_filter, n_lfcc=n_lfcc, norm="ortho", speckwargs=speckwargs
)
torch_lfcc = lfcc_transform(audio) # (1, 40, 161)
self.assertEqual(torch_lfcc.shape[2], 161)
def test_lfcc_norms(self):
"""Check if LFCC-DCT norm works correctly.
"""
"""Check if LFCC-DCT norm works correctly."""
sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_lfcc = 40
n_filter = 128
lfcc_transform = torchaudio.transforms.LFCC(
sample_rate=sample_rate, n_filter=n_filter, n_lfcc=n_lfcc, norm="ortho"
)
lfcc_transform_norm_none = torchaudio.transforms.LFCC(
sample_rate=sample_rate, n_filter=n_filter, n_lfcc=n_lfcc, norm=None
)
torch_lfcc_norm_none = lfcc_transform_norm_none(audio) # (1, 40, 161)
norm_check = lfcc_transform(audio) # (1, 40, 161)
@@ -240,26 +226,27 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertEqual(torch_lfcc_norm_none, norm_check)
def test_resample_size(self):
input_path = common_utils.get_asset_path("sinewave.wav")
waveform, sample_rate = common_utils.load_wav(input_path)
upsample_rate = sample_rate * 2
downsample_rate = sample_rate // 2
invalid_resampling_method = "foo"
with self.assertRaises(ValueError):
torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method=invalid_resampling_method)
upsample_resample = torchaudio.transforms.Resample(
sample_rate, upsample_rate, resampling_method="sinc_interpolation"
)
up_sampled = upsample_resample(waveform)
# we expect the upsampled signal to have twice as many samples
self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)
downsample_resample = torchaudio.transforms.Resample(
sample_rate, downsample_rate, resampling_method="sinc_interpolation"
)
down_sampled = downsample_resample(waveform)
# we expect the downsampled signal to have half as many samples
@@ -289,9 +276,8 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertEqual(computed_functional, computed_transform, atol=atol, rtol=rtol)
def test_compute_deltas_twochannel(self):
specgram = torch.tensor([1.0, 2.0, 3.0, 4.0]).repeat(1, 2, 1)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], [0.5, 1.0, 1.0, 0.5]]])
transform = transforms.ComputeDeltas(win_length=3)
computed = transform(specgram)
assert computed.shape == expected.shape, (computed.shape, expected.shape)
@@ -299,7 +285,6 @@
class SmokeTest(common_utils.TorchaudioTestCase):
def test_spectrogram(self):
specgram = transforms.Spectrogram(center=False, pad_mode="reflect", onesided=False)
self.assertEqual(specgram.center, False)
@@ -38,15 +38,12 @@ class TransformsTestBase(TestBaseMixin):
# Generate reference spectrogram and input mel-scaled spectrogram
expected = get_spectrogram(
get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2), n_fft=n_fft, power=power
).to(self.device, self.dtype)
input = T.MelScale(n_mels=n_mels, sample_rate=sample_rate, n_stft=n_stft).to(self.device, self.dtype)(expected)
# Run transform
transform = T.InverseMelScale(n_stft, n_mels=n_mels, sample_rate=sample_rate).to(self.device, self.dtype)
torch.random.manual_seed(0)
result = transform(input)
@@ -55,9 +52,7 @@ class TransformsTestBase(TestBaseMixin):
relative_diff = torch.abs((result - expected) / (expected + epsilon))
for tol in [1e-1, 1e-3, 1e-5, 1e-10]:
print(f"Ratio of relative diff smaller than {tol:e} is " f"{_get_ratio(relative_diff < tol)}")
assert _get_ratio(relative_diff < 1e-1) > 0.2
assert _get_ratio(relative_diff < 1e-3) > 5e-3
assert _get_ratio(relative_diff < 1e-5) > 1e-5
@@ -84,21 +79,23 @@ class TransformsTestBase(TestBaseMixin):
assert transform.kernel.dtype == dtype if dtype is not None else torch.float32
@parameterized.expand(
[
param(n_fft=300, center=True, onesided=True),
param(n_fft=400, center=True, onesided=False),
param(n_fft=400, center=True, onesided=False),
param(n_fft=300, center=True, onesided=False),
param(n_fft=400, hop_length=10),
param(n_fft=800, win_length=400, hop_length=20),
param(n_fft=800, win_length=400, hop_length=20, normalized=True),
param(),
param(n_fft=400, pad=32),
# These tests do not work - cause runtime error
# See https://github.com/pytorch/pytorch/issues/62323
# param(n_fft=400, center=False, onesided=True),
# param(n_fft=400, center=False, onesided=False),
]
)
def test_roundtrip_spectrogram(self, **args):
"""Test the spectrogram + inverse spectrogram results in approximate identity."""
@@ -110,12 +107,14 @@ class TransformsTestBase(TestBaseMixin):
restored = inv_s.forward(transformed, length=waveform.shape[-1])
self.assertEqual(waveform, restored, atol=1e-6, rtol=1e-6)
@parameterized.expand(
[
param(0.5, 1, True, False),
param(0.5, 1, None, False),
param(1, 4, True, True),
param(1, 6, None, True),
]
)
def test_psd(self, duration, channel, mask, multi_mask):
"""Providing dtype changes the kernel cache dtype"""
transform = T.PSD(multi_mask)
from torchaudio.utils import sox_utils
from torchaudio_unittest.common_utils import (
PytorchTestCase,
skipIfNoSox,
@@ -9,6 +8,7 @@ from torchaudio_unittest.common_utils import (
@skipIfNoSox
class TestSoxUtils(PytorchTestCase):
"""Smoke tests for sox_util module"""
def test_set_seed(self):
"""`set_seed` does not crush"""
sox_utils.set_seed(0)
@@ -34,16 +34,16 @@ class TestSoxUtils(PytorchTestCase):
"""`list_effects` returns the list of available effects"""
effects = sox_utils.list_effects()
# We cannot infer what effects are available, so only check some of them.
assert "highpass" in effects
assert "phaser" in effects
assert "gain" in effects
def test_list_read_formats(self):
"""`list_read_formats` returns the list of supported formats"""
formats = sox_utils.list_read_formats()
assert "wav" in formats
def test_list_write_formats(self):
"""`list_write_formats` returns the list of supported formats"""
formats = sox_utils.list_write_formats()
assert "opus" not in formats
@@ -35,21 +35,15 @@ def _parse_args():
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--input-file", required=True, help="Input model file.")
parser.add_argument("--output-file", required=False, help="Output model file.")
parser.add_argument(
"--dict-dir",
help=(
"Directory where letter vocabulary file, `dict.ltr.txt`, is found. "
"Required when loading wav2vec2 model. "
"https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt"
),
)
return parser.parse_args()
@@ -57,9 +51,10 @@ def _parse_args():
def _load_model(input_file, dict_dir):
import fairseq
overrides = {} if dict_dir is None else {"data": dict_dir}
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[input_file],
arg_overrides=overrides,
)
return models[0]
@@ -67,7 +62,7 @@ def _load_model(input_file, dict_dir):
def _import_model(model):
from torchaudio.models.wav2vec2.utils import import_fairseq_model
if model.__class__.__name__ in ["HubertCtc", "Wav2VecCtc"]:
model = model.w2v_encoder
model = import_fairseq_model(model)
return model
@@ -75,10 +70,11 @@ def _import_model(model):
def _main(args):
import torch
model = _load_model(args.input_file, args.dict_dir)
model = _import_model(model)
torch.save(model.state_dict(), args.output_file)
if __name__ == "__main__":
_main(_parse_args())
@@ -19,24 +19,19 @@ python convert_voxpopuli_models.py \
def _parse_args():
import argparse
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--input-file", required=True, help="Input checkpoint file.")
parser.add_argument("--output-file", required=False, help="Output model file.")
return parser.parse_args()
def _removeprefix(s, prefix):
if s.startswith(prefix):
return s[len(prefix) :]
return s
@@ -45,13 +40,13 @@ def _load(input_file):
from omegaconf import OmegaConf
data = torch.load(input_file)
cfg = OmegaConf.to_container(data["cfg"])
for key in list(cfg.keys()):
if key != "model":
del cfg[key]
if "w2v_args" in cfg["model"]:
del cfg["model"]["w2v_args"][key]
state_dict = {_removeprefix(k, "w2v_encoder."): v for k, v in data["model"].items()}
return cfg, state_dict
@@ -75,9 +70,9 @@ def _parse_model_param(cfg, state_dict):
"encoder_layerdrop": "encoder_layer_drop",
}
params = {}
src_dicts = [cfg["model"]]
if "w2v_args" in cfg["model"]:
src_dicts.append(cfg["model"]["w2v_args"]["model"])
for src, tgt in key_mapping.items():
for model_cfg in src_dicts:
@@ -89,12 +84,13 @@ def _parse_model_param(cfg, state_dict):
# the following line is commented out to resolve lint warning; uncomment before running script
# params["extractor_conv_layer_config"] = eval(params["extractor_conv_layer_config"])
assert len(params) == 15
params["aux_num_out"] = state_dict["proj.bias"].numel() if "proj.bias" in state_dict else None
return params
def _main(args):
import json
import torch
import torchaudio
from torchaudio.models.wav2vec2.utils.import_fairseq import _convert_state_dict as _convert
@@ -107,5 +103,5 @@ def _main(args):
torch.save(model.state_dict(), args.output_file)
if __name__ == "__main__":
_main(_parse_args())
@@ -24,11 +24,11 @@ Features = namedtuple(
def _run_cmd(cmd):
return subprocess.check_output(cmd).decode("utf-8").strip()
def commit_title(commit_hash):
cmd = ["git", "log", "-n", "1", "--pretty=format:%s", f"{commit_hash}"]
return _run_cmd(cmd)
@@ -95,12 +95,12 @@ def get_features(commit_hash):
def get_commits_between(base_version, new_version):
cmd = ["git", "merge-base", f"{base_version}", f"{new_version}"]
merge_base = _run_cmd(cmd)
# Returns a list of items in the form
# a7854f33 Add HuBERT model architectures (#1769)
cmd = ["git", "log", "--reverse", "--oneline", f"{merge_base}..{new_version}"]
commits = _run_cmd(cmd)
log_lines = commits.split("\n")
import distutils.sysconfig
import os
import platform
import subprocess
from pathlib import Path
import torch
from setuptools import Extension
from setuptools.command.build_ext import build_ext
__all__ = [
"get_ext_modules",
"CMakeBuild",
]
_THIS_DIR = Path(__file__).parent.resolve()
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
_TORCHAUDIO_DIR = _ROOT_DIR / "torchaudio"
def _get_build(var, default=False):
if var not in os.environ:
return default
val = os.environ.get(var, "0")
trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"]
falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"]
if val in trues:
return True
if val not in falses:
print(f"WARNING: Unexpected environment variable value `{var}={val}`. " f"Expected one of {trues + falses}")
return False
_BUILD_SOX = False if platform.system() == "Windows" else _get_build("BUILD_SOX", True)
_BUILD_KALDI = False if platform.system() == "Windows" else _get_build("BUILD_KALDI", True)
_BUILD_RNNT = _get_build("BUILD_RNNT", True)
_BUILD_CTC_DECODER = False if platform.system() == "Windows" else _get_build("BUILD_CTC_DECODER", True)
_USE_ROCM = _get_build("USE_ROCM", torch.cuda.is_available() and torch.version.hip is not None)
_USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available() and torch.version.hip is None)
_USE_OPENMP = _get_build("USE_OPENMP", True) and \
'ATen parallel backend: OpenMP' in torch.__config__.parallel_info()
_TORCH_CUDA_ARCH_LIST = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
_USE_OPENMP = _get_build("USE_OPENMP", True) and "ATen parallel backend: OpenMP" in torch.__config__.parallel_info()
_TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
def get_ext_modules():
modules = [
Extension(name="torchaudio.lib.libtorchaudio", sources=[]),
Extension(name="torchaudio._torchaudio", sources=[]),
]
if _BUILD_CTC_DECODER:
modules.extend(
[
Extension(name="torchaudio.lib.libtorchaudio_decoder", sources=[]),
Extension(name="torchaudio._torchaudio_decoder", sources=[]),
]
)
return modules
@@ -63,7 +62,7 @@ def get_ext_modules():
class CMakeBuild(build_ext):
def run(self):
try:
subprocess.check_output(["cmake", "--version"])
except OSError:
raise RuntimeError("CMake is not available.") from None
super().run()
@@ -75,11 +74,10 @@ class CMakeBuild(build_ext):
# However, the following `cmake` command will build all of them at the same time,
# so, we do not need to perform `cmake` twice.
# Therefore we call `cmake` only for `torchaudio._torchaudio`.
if ext.name != "torchaudio._torchaudio":
return
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
# required for auto-detection of auxiliary "native" libs
if not extdir.endswith(os.path.sep):
@@ -91,7 +89,7 @@ class CMakeBuild(build_ext):
f"-DCMAKE_BUILD_TYPE={cfg}",
f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}",
f"-DCMAKE_INSTALL_PREFIX={extdir}",
"-DCMAKE_VERBOSE_MAKEFILE=ON",
f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}",
f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}",
@@ -102,22 +100,21 @@ class CMakeBuild(build_ext):
f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}",
f"-DUSE_OPENMP:BOOL={'ON' if _USE_OPENMP else 'OFF'}",
]
build_args = ["--target", "install"]
# Pass CUDA architecture to cmake
if _TORCH_CUDA_ARCH_LIST is not None:
# Convert MAJOR.MINOR[+PTX] list to new style one
# defined at https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html
_arches = _TORCH_CUDA_ARCH_LIST.replace(".", "").replace(" ", ";").split(";")
_arches = [arch[:-4] if arch.endswith("+PTX") else f"{arch}-real" for arch in _arches]
cmake_args += [f"-DCMAKE_CUDA_ARCHITECTURES={';'.join(_arches)}"]
# Default to Ninja
if "CMAKE_GENERATOR" not in os.environ or platform.system() == "Windows":
cmake_args += ["-GNinja"]
if platform.system() == "Windows":
import sys
python_version = sys.version_info
cmake_args += [
"-DCMAKE_C_COMPILER=cl",
@@ -137,14 +134,12 @@ class CMakeBuild(build_ext):
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
subprocess.check_call(["cmake", str(_ROOT_DIR)] + cmake_args, cwd=self.build_temp)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
def get_ext_filename(self, fullname):
ext_filename = super().get_ext_filename(fullname)
ext_filename_parts = ext_filename.split(".")
without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:]
ext_filename = ".".join(without_abi)
return ext_filename
@@ -10,7 +10,6 @@ from torchaudio import (
sox_effects,
transforms,
)
from torchaudio.backend import (
list_audio_backends,
get_audio_backend,
@@ -23,16 +22,16 @@ except ImportError:
pass
__all__ = [
"compliance",
"datasets",
"functional",
"models",
"pipelines",
"kaldi_io",
"utils",
"sox_effects",
"transforms",
"list_audio_backends",
"get_audio_backend",
"set_audio_backend",
]
@@ -5,12 +5,12 @@ from pathlib import Path
import torch
from torchaudio._internal import module_utils as _mod_utils # noqa: F401
_LIB_DIR = Path(__file__).parent / "lib"
def _get_lib_path(lib: str):
suffix = "pyd" if os.name == "nt" else "so"
path = _LIB_DIR / f"{lib}.{suffix}"
return path
@@ -26,11 +26,11 @@ def _load_lib(lib: str):
def _init_extension():
if not _mod_utils.is_module_available("torchaudio._torchaudio"):
warnings.warn("torchaudio C++ extension is not available.")
return
_load_lib("libtorchaudio")
# This import is for initializing the methods registered via PyBind11
# This has to happen after the base library is loaded
from torchaudio import _torchaudio # noqa
import importlib.util
import warnings
from functools import wraps
from typing import Optional
import torch
@@ -28,14 +28,17 @@ def requires_module(*modules: str):
# fall through. If all the modules are available, no need to decorate
def decorator(func):
return func
else:
req = f"module: {missing[0]}" if len(missing) == 1 else f"modules: {missing}"
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f"{func.__module__}.{func.__name__} requires {req}")
return wrapped
return decorator
@@ -46,42 +49,51 @@ def deprecated(direction: str, version: Optional[str] = None):
direction (str): Migration steps to be given to users.
version (str or int): The version when the object will be removed
"""
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
message = (
f"{func.__module__}.{func.__name__} has been deprecated "
f'and will be removed from {"future" if version is None else version} release. '
f"{direction}"
)
warnings.warn(message, stacklevel=2)
return func(*args, **kwargs)
return wrapped
return decorator
def is_kaldi_available():
return is_module_available("torchaudio._torchaudio") and torch.ops.torchaudio.is_kaldi_available()
def requires_kaldi():
if is_kaldi_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f"{func.__module__}.{func.__name__} requires kaldi")
return wrapped
return decorator
def _check_soundfile_importable():
if not is_module_available("soundfile"):
return False
try:
import soundfile # noqa: F401
return True
except Exception:
warnings.warn("Failed to import soundfile. 'soundfile' backend is not available.")
@@ -97,29 +109,39 @@ def is_soundfile_available():
def requires_soundfile():
if is_soundfile_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f"{func.__module__}.{func.__name__} requires soundfile")
return wrapped
return decorator
def is_sox_available():
return is_module_available("torchaudio._torchaudio") and torch.ops.torchaudio.is_sox_available()
def requires_sox():
if is_sox_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f"{func.__module__}.{func.__name__} requires sox")
return wrapped
return decorator