Commit ffeba11a authored by mayp777's avatar mayp777
Browse files

UPDATE

parent 29deb085
import io
import itertools
import tarfile
from pathlib import Path
from parameterized import parameterized
from torchaudio import sox_effects
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.common_utils import (
get_sinusoid,
get_wav_data,
HttpServerMixin,
load_wav,
PytorchTestCase,
save_wav,
skipIfNoExec,
skipIfNoModule,
skipIfNoSox,
sox_utils,
TempDirMixin,
......@@ -23,10 +17,6 @@ from torchaudio_unittest.common_utils import (
from .common import load_params, name_func
if _mod_utils.is_module_available("requests"):
import requests
@skipIfNoSox
class TestSoxEffects(PytorchTestCase):
def test_init(self):
......@@ -241,136 +231,3 @@ class TestFileFormats(TempDirMixin, PytorchTestCase):
assert sr == expected_sr
self.assertEqual(found, expected)
@skipIfNoExec("sox")
@skipIfNoSox
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Verify that ``sox_effects.apply_effects_file`` accepts file-like objects.

    Each test generates an input file and a reference output with the ``sox``
    CLI, then applies the same effect chain through
    ``apply_effects_file`` on a file-like object and compares the results.
    """

    @parameterized.expand(
        [
            ("wav", None),
            ("flac", 0),
            ("flac", 5),
            ("flac", 8),
            ("vorbis", -1),
            ("vorbis", 10),
            ("amb", None),
        ]
    )
    def test_fileobj(self, ext, compression):
        """Applying effects via file object works"""
        sample_rate = 16000
        channels_first = True
        effects = [["band", "300", "10"]]
        input_path = self.get_temp_path(f"input.{ext}")
        reference_path = self.get_temp_path("reference.wav")
        # Build the input and the reference output with the sox CLI.
        sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression)
        sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
        expected, expected_sr = load_wav(reference_path)
        with open(input_path, "rb") as fileobj:
            found, sr = sox_effects.apply_effects_file(fileobj, effects, channels_first=channels_first)
        # Saved to the temp dir to allow manual inspection on failure.
        save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first)
        assert sr == expected_sr
        self.assertEqual(found, expected)

    @parameterized.expand(
        [
            ("wav", None),
            ("flac", 0),
            ("flac", 5),
            ("flac", 8),
            ("vorbis", -1),
            ("vorbis", 10),
            ("amb", None),
        ]
    )
    def test_bytesio(self, ext, compression):
        """Applying effects via BytesIO object works"""
        sample_rate = 16000
        channels_first = True
        effects = [["band", "300", "10"]]
        input_path = self.get_temp_path(f"input.{ext}")
        reference_path = self.get_temp_path("reference.wav")
        sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression)
        sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
        expected, expected_sr = load_wav(reference_path)
        # Load the whole file into an in-memory buffer first.
        with open(input_path, "rb") as file_:
            fileobj = io.BytesIO(file_.read())
        found, sr = sox_effects.apply_effects_file(fileobj, effects, channels_first=channels_first)
        save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first)
        assert sr == expected_sr
        self.assertEqual(found, expected)

    @parameterized.expand(
        [
            ("wav", None),
            ("flac", 0),
            ("flac", 5),
            ("flac", 8),
            ("vorbis", -1),
            ("vorbis", 10),
            ("amb", None),
        ]
    )
    def test_tarfile(self, ext, compression):
        """Applying effects to compressed audio via file-like file works"""
        sample_rate = 16000
        channels_first = True
        effects = [["band", "300", "10"]]
        audio_file = f"input.{ext}"
        input_path = self.get_temp_path(audio_file)
        reference_path = self.get_temp_path("reference.wav")
        archive_path = self.get_temp_path("archive.tar.gz")
        sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression)
        sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
        expected, expected_sr = load_wav(reference_path)
        # Pack the audio into a tar archive, then stream it back out
        # through the member file object returned by extractfile.
        with tarfile.TarFile(archive_path, "w") as tarobj:
            tarobj.add(input_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, "r") as tarobj:
            fileobj = tarobj.extractfile(audio_file)
            found, sr = sox_effects.apply_effects_file(fileobj, effects, channels_first=channels_first)
        save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first)
        assert sr == expected_sr
        self.assertEqual(found, expected)
@skipIfNoSox
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
    """Verify ``apply_effects_file`` works on a streaming HTTP response body."""

    @parameterized.expand(
        [
            ("wav", None),
            ("flac", 0),
            ("flac", 5),
            ("flac", 8),
            ("vorbis", -1),
            ("vorbis", 10),
            ("amb", None),
        ]
    )
    def test_requests(self, ext, compression):
        """Effects applied over an HTTP stream match the sox CLI reference."""
        sample_rate = 16000
        channels_first = True
        effects = [["band", "300", "10"]]
        audio_file = f"input.{ext}"
        input_path = self.get_temp_path(audio_file)
        reference_path = self.get_temp_path("reference.wav")
        sox_utils.gen_audio_file(input_path, sample_rate, num_channels=2, compression=compression)
        sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32)
        expected, expected_sr = load_wav(reference_path)
        # Serve the generated file via the HttpServerMixin-managed server
        # and feed the raw (undecoded) response stream to the effect chain.
        url = self.get_url(audio_file)
        with requests.get(url, stream=True) as resp:
            found, sr = sox_effects.apply_effects_file(resp.raw, effects, channels_first=channels_first)
        save_wav(self.get_temp_path("result.wav"), found, sr, channels_first=channels_first)
        assert sr == expected_sr
        self.assertEqual(found, expected)
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfkmeMark
from .autograd_test_impl import AutogradTestFloat32, AutogradTestMixin
@skipIfNoCuda
@skipIfkmeMark
class AutogradCUDATest(AutogradTestMixin, PytorchTestCase):
    """Run the shared autograd test suite on a CUDA device."""

    device = "cuda"
......
......@@ -28,6 +28,7 @@ class AutogradTestMixin(TestBaseMixin):
inputs: List[torch.Tensor],
*,
nondet_tol: float = 0.0,
enable_all_grad: bool = True,
):
transform = transform.to(dtype=torch.float64, device=self.device)
......@@ -37,7 +38,8 @@ class AutogradTestMixin(TestBaseMixin):
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=torch.cdouble if i.is_complex() else torch.double, device=self.device)
i.requires_grad = True
if enable_all_grad:
i.requires_grad = True
inputs_.append(i)
assert gradcheck(transform, inputs_)
assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)
......@@ -110,14 +112,14 @@ class AutogradTestMixin(TestBaseMixin):
sample_rate = 8000
transform = T.MFCC(sample_rate=sample_rate, log_mels=log_mels)
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform])
self.assert_grad(transform, [waveform], nondet_tol=1e-10)
@parameterized.expand([(False,), (True,)])
def test_lfcc(self, log_lf):
sample_rate = 8000
transform = T.LFCC(sample_rate=sample_rate, log_lf=log_lf)
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform])
self.assert_grad(transform, [waveform], nondet_tol=1e-10)
def test_compute_deltas(self):
transform = T.ComputeDeltas()
......@@ -187,8 +189,9 @@ class AutogradTestMixin(TestBaseMixin):
def test_melscale(self):
sample_rate = 8000
n_fft = 400
n_mels = n_fft // 2 + 1
transform = T.MelScale(sample_rate=sample_rate, n_mels=n_mels)
n_stft = n_fft // 2 + 1
n_mels = 128
transform = T.MelScale(sample_rate=sample_rate, n_mels=n_mels, n_stft=n_stft)
spec = get_spectrogram(
get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), n_fft=n_fft, power=1
)
......@@ -317,6 +320,61 @@ class AutogradTestMixin(TestBaseMixin):
reference_channel = 0
self.assert_grad(transform, [specgram, psd_s, psd_n, reference_channel])
@nested_params(
    ["Convolve", "FFTConvolve"],
    ["full", "valid", "same"],
)
def test_convolve(self, cls, mode):
    """Gradients flow through Convolve/FFTConvolve for every padding mode."""
    batch_shape = (4, 3, 2)
    x_len, y_len = 23, 40
    x = torch.rand(*batch_shape, x_len)
    y = torch.rand(*batch_shape, y_len)
    transform = getattr(T, cls)(mode=mode)
    self.assert_grad(transform, [x, y])
def test_speed(self):
    """Speed is differentiable with respect to the waveform input."""
    batch_shape = (3, 2)
    num_frames = 200
    waveform = torch.rand(*batch_shape, num_frames, requires_grad=True)
    lengths = torch.randint(1, num_frames, batch_shape)
    transform = T.Speed(1000, 1.1)
    # lengths is integer-valued, so only pre-flagged tensors get gradients.
    self.assert_grad(transform, (waveform, lengths), enable_all_grad=False)
def test_speed_perturbation(self):
    """SpeedPerturbation is differentiable with respect to the waveform."""
    batch_shape = (3, 2)
    num_frames = 200
    waveform = torch.rand(*batch_shape, num_frames, requires_grad=True)
    lengths = torch.randint(1, num_frames, batch_shape)
    transform = T.SpeedPerturbation(1000, [0.9])
    # lengths is integer-valued, so only pre-flagged tensors get gradients.
    self.assert_grad(transform, (waveform, lengths), enable_all_grad=False)
@nested_params([True, False])
def test_add_noise(self, use_lengths):
    """AddNoise is differentiable, with and without a lengths tensor."""
    batch_shape = (2, 3)
    num_frames = 31
    waveform = torch.rand(*batch_shape, num_frames)
    noise = torch.rand(*batch_shape, num_frames)
    lengths = torch.rand(*batch_shape) if use_lengths else None
    snr = torch.rand(*batch_shape)
    transform = T.AddNoise()
    self.assert_grad(transform, (waveform, noise, snr, lengths))
def test_preemphasis(self):
    """Preemphasis is differentiable with respect to its input."""
    signal_in = torch.rand(3, 4, 10)
    transform = T.Preemphasis(coeff=0.97)
    self.assert_grad(transform, (signal_in,))
def test_deemphasis(self):
    """Deemphasis is differentiable with respect to its input."""
    signal_in = torch.rand(3, 4, 10)
    transform = T.Deemphasis(coeff=0.97)
    self.assert_grad(transform, (signal_in,))
class AutogradTestFloat32(TestBaseMixin):
def assert_grad(
......
......@@ -52,11 +52,9 @@ class TestTransforms(common_utils.TorchaudioTestCase):
n_mels = 32
n_stft = 5
mel_spec = torch.randn(3, 2, n_mels, 32) ** 2
transform = T.InverseMelScale(n_stft, n_mels)
transform = T.InverseMelScale(n_stft, n_mels, driver="gelsd")
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here.
self.assert_batch_consistency(transform, mel_spec, atol=1.0, rtol=1e-5)
self.assert_batch_consistency(transform, mel_spec)
def test_batch_compute_deltas(self):
specgram = torch.randn(3, 2, 31, 2786)
......@@ -257,3 +255,122 @@ class TestTransforms(common_utils.TorchaudioTestCase):
computed = transform(specgram, psd_s, psd_n, reference_channel)
self.assertEqual(computed, expected)
@common_utils.nested_params(
    ["Convolve", "FFTConvolve"],
    ["full", "valid", "same"],
)
def test_convolve(self, cls, mode):
    """Batched (FFT)Convolve matches convolving each signal individually."""
    leading_dims = (2, 3)
    L_x, L_y = 89, 43
    x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
    y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
    convolve = getattr(T, cls)(mode=mode)
    actual = convolve(x, y)
    rows = []
    for i in range(leading_dims[0]):
        per_signal = [
            convolve(x[i, j].unsqueeze(0), y[i, j].unsqueeze(0)).squeeze(0) for j in range(leading_dims[1])
        ]
        rows.append(torch.stack(per_signal))
    expected = torch.stack(rows)
    self.assertEqual(expected, actual)
def test_speed(self):
    """Speed on a padded batch matches applying it per utterance."""
    num_utterances = 5
    orig_freq = 100
    factor = 0.8
    input_lengths = torch.randint(1, 1000, (num_utterances,), dtype=torch.int32)
    speed = T.Speed(orig_freq, factor)
    unbatched_input = [torch.ones((int(length),)) * 1.0 for length in input_lengths]
    batched_input = torch.nn.utils.rnn.pad_sequence(unbatched_input, batch_first=True)
    output, output_lengths = speed(batched_input, input_lengths)
    unbatched_output = []
    unbatched_output_lengths = []
    for wav, length in zip(unbatched_input, input_lengths):
        out_wav, out_len = speed(wav, length)
        unbatched_output.append(out_wav)
        unbatched_output_lengths.append(out_len)
    self.assertEqual(output_lengths, torch.stack(unbatched_output_lengths))
    # Only the valid (unpadded) prefix of each batched row is compared.
    for idx, expected in enumerate(unbatched_output):
        valid = output_lengths[idx]
        self.assertEqual(expected, output[idx][:valid])
def test_speed_perturbation(self):
    """SpeedPerturbation on a padded batch matches per-utterance application."""
    num_utterances = 5
    orig_freq = 100
    factor = 0.8
    input_lengths = torch.randint(1, 1000, (num_utterances,), dtype=torch.int32)
    speed = T.SpeedPerturbation(orig_freq, [factor])
    unbatched_input = [torch.ones((int(length),)) * 1.0 for length in input_lengths]
    batched_input = torch.nn.utils.rnn.pad_sequence(unbatched_input, batch_first=True)
    output, output_lengths = speed(batched_input, input_lengths)
    unbatched_output = []
    unbatched_output_lengths = []
    for wav, length in zip(unbatched_input, input_lengths):
        out_wav, out_len = speed(wav, length)
        unbatched_output.append(out_wav)
        unbatched_output_lengths.append(out_len)
    self.assertEqual(output_lengths, torch.stack(unbatched_output_lengths))
    # Only the valid (unpadded) prefix of each batched row is compared.
    for idx, expected in enumerate(unbatched_output):
        valid = output_lengths[idx]
        self.assertEqual(expected, output[idx][:valid])
def test_add_noise(self):
    """Batched AddNoise agrees with element-wise application."""
    leading_dims = (5, 2, 3)
    L = 51
    waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
    noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
    lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
    snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
    add_noise = T.AddNoise()
    actual = add_noise(waveform, noise, snr, lengths)
    # Flatten the leading dims and run every signal through the transform
    # individually; the batched result must match row for row.
    expected = [
        add_noise(waveform[i][j][k], noise[i][j][k], snr[i][j][k], lengths[i][j][k])
        for i in range(leading_dims[0])
        for j in range(leading_dims[1])
        for k in range(leading_dims[2])
    ]
    self.assertEqual(torch.stack(expected), actual.reshape(-1, L))
def test_preemphasis(self):
    """Batched Preemphasis agrees with per-channel application."""
    waveform = torch.rand((3, 5, 2, 100), dtype=self.dtype, device=self.device)
    preemphasis = T.Preemphasis(coeff=0.97)
    actual = preemphasis(waveform)
    per_channel = [
        preemphasis(waveform[i][j][k])
        for i in range(waveform.size(0))
        for j in range(waveform.size(1))
        for k in range(waveform.size(2))
    ]
    self.assertEqual(torch.stack(per_channel), actual.reshape(-1, waveform.size(-1)))
def test_deemphasis(self):
    """Batched Deemphasis agrees with per-channel application."""
    waveform = torch.rand((3, 5, 2, 100), dtype=self.dtype, device=self.device)
    deemphasis = T.Deemphasis(coeff=0.97)
    actual = deemphasis(waveform)
    per_channel = [
        deemphasis(waveform[i][j][k])
        for i in range(waveform.size(0))
        for j in range(waveform.size(1))
        for k in range(waveform.size(2))
    ]
    self.assertEqual(torch.stack(per_channel), actual.reshape(-1, waveform.size(-1)))
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfkmeMark
from .librosa_compatibility_test_impl import TransformsTestBase
@skipIfNoCuda
@skipIfkmeMark
class TestTransforms(TransformsTestBase, PytorchTestCase):
    """Run the librosa-compatibility transform tests on CUDA in float64."""

    dtype = torch.float64
    device = torch.device("cuda")
......@@ -36,7 +36,7 @@ class TransformsTestBase(TestBaseMixin):
result = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power,).to(self.device, self.dtype)(
waveform
)[0]
self.assertEqual(result, torch.from_numpy(expected), atol=1e-5, rtol=1e-5)
self.assertEqual(result, torch.from_numpy(expected), atol=1e-4, rtol=1e-4)
def test_Spectrogram_complex(self):
n_fft = 400
......@@ -54,7 +54,7 @@ class TransformsTestBase(TestBaseMixin):
result = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None, return_complex=True,).to(
self.device, self.dtype
)(waveform)[0]
self.assertEqual(result.abs(), torch.from_numpy(expected), atol=1e-5, rtol=1e-5)
self.assertEqual(result.abs(), torch.from_numpy(expected), atol=1e-4, rtol=1e-4)
@nested_params(
[
......
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfkmeMark
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only
......@@ -11,6 +11,7 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
@skipIfNoCuda
@skipIfkmeMark
class TestTransformsFloat64(Transforms, PytorchTestCase):
    """Run the TorchScript-consistency transform tests on CUDA in float64."""

    dtype = torch.float64
    device = torch.device("cuda")
......@@ -4,7 +4,7 @@ import torch
import torchaudio.transforms as T
from parameterized import parameterized
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import skipIfRocm, TestBaseMixin, torch_script
from torchaudio_unittest.common_utils import skipIfRocm, TestBaseMixin, torch_script, skipIfkmeMark
class Transforms(TestBaseMixin):
......@@ -38,7 +38,8 @@ class Transforms(TestBaseMixin):
def test_Spectrogram_return_complex(self):
    """Spectrogram with power=None / return_complex=True is TorchScript-consistent."""
    tensor = torch.rand((1, 1000))
    self._assert_consistency(T.Spectrogram(power=None, return_complex=True), tensor)
@skipIfkmeMark
def test_InverseSpectrogram(self):
tensor = common_utils.get_whitenoise(sample_rate=8000)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
......@@ -118,6 +119,7 @@ class Transforms(TestBaseMixin):
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)
@skipIfkmeMark
def test_TimeStretch(self):
n_fft = 1025
n_freq = n_fft // 2 + 1
......@@ -144,12 +146,14 @@ class Transforms(TestBaseMixin):
pitch_shift(waveform)
self._assert_consistency(pitch_shift, waveform)
@skipIfkmeMark
def test_PSD(self):
    """PSD is TorchScript-consistent on a multi-channel complex spectrogram."""
    tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
    spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
    spectrogram = spectrogram.to(self.device)
    self._assert_consistency_complex(T.PSD(), spectrogram)
@skipIfkmeMark
def test_PSD_with_mask(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
......@@ -167,6 +171,7 @@ class Transforms(TestBaseMixin):
["stv_power", False],
]
)
@skipIfkmeMark
def test_MVDR(self, solution, online):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
......@@ -174,6 +179,7 @@ class Transforms(TestBaseMixin):
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n)
@skipIfkmeMark
def test_rtf_mvdr(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
......@@ -182,7 +188,8 @@ class Transforms(TestBaseMixin):
psd_n = torch.rand(freq, channel, channel, dtype=self.complex_dtype, device=self.device)
reference_channel = 0
self._assert_consistency_complex(T.RTFMVDR(), specgram, rtf, psd_n, reference_channel)
@skipIfkmeMark
def test_souden_mvdr(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
......@@ -192,6 +199,85 @@ class Transforms(TestBaseMixin):
reference_channel = 0
self._assert_consistency_complex(T.SoudenMVDR(), specgram, psd_s, psd_n, reference_channel)
@common_utils.nested_params(
    ["Convolve", "FFTConvolve"],
    ["full", "valid", "same"],
)
def test_convolve(self, cls, mode):
    """TorchScripted (FFT)Convolve matches eager execution."""
    batch_shape = (2, 3, 2)
    x_len, y_len = 32, 55
    x = torch.rand(*batch_shape, x_len, dtype=self.dtype, device=self.device)
    y = torch.rand(*batch_shape, y_len, dtype=self.dtype, device=self.device)
    transform = getattr(T, cls)(mode=mode).to(device=self.device, dtype=self.dtype)
    eager_out = transform(x, y)
    scripted_out = torch_script(transform)(x, y)
    self.assertEqual(scripted_out, eager_out)
@common_utils.nested_params([True, False])
def test_speed(self, use_lengths):
    """TorchScripted Speed matches eager execution, with and without lengths."""
    batch_shape = (3, 2)
    num_frames = 200
    waveform = torch.rand(*batch_shape, num_frames, dtype=self.dtype, device=self.device, requires_grad=True)
    if use_lengths:
        lengths = torch.randint(1, num_frames, batch_shape, dtype=self.dtype, device=self.device)
    else:
        lengths = None
    transform = T.Speed(1000, 0.9).to(self.device, self.dtype)
    eager_out = transform(waveform, lengths)
    scripted_out = torch_script(transform)(waveform, lengths)
    self.assertEqual(scripted_out, eager_out)
@common_utils.nested_params([True, False])
def test_speed_perturbation(self, use_lengths):
    """TorchScripted SpeedPerturbation matches eager execution."""
    batch_shape = (3, 2)
    num_frames = 200
    waveform = torch.rand(*batch_shape, num_frames, dtype=self.dtype, device=self.device, requires_grad=True)
    if use_lengths:
        lengths = torch.randint(1, num_frames, batch_shape, dtype=self.dtype, device=self.device)
    else:
        lengths = None
    transform = T.SpeedPerturbation(1000, [0.9]).to(self.device, self.dtype)
    eager_out = transform(waveform, lengths)
    scripted_out = torch_script(transform)(waveform, lengths)
    self.assertEqual(scripted_out, eager_out)
@common_utils.nested_params([True, False])
def test_add_noise(self, use_lengths):
    """TorchScripted AddNoise matches eager execution."""
    batch_shape = (2, 3)
    num_frames = 31
    waveform = torch.rand(*batch_shape, num_frames, dtype=self.dtype, device=self.device, requires_grad=True)
    noise = torch.rand(*batch_shape, num_frames, dtype=self.dtype, device=self.device, requires_grad=True)
    if use_lengths:
        lengths = torch.rand(*batch_shape, dtype=self.dtype, device=self.device, requires_grad=True)
    else:
        lengths = None
    snr = torch.rand(*batch_shape, dtype=self.dtype, device=self.device, requires_grad=True) * 10
    transform = T.AddNoise().to(self.device, self.dtype)
    eager_out = transform(waveform, noise, snr, lengths)
    scripted_out = torch_script(transform)(waveform, noise, snr, lengths)
    self.assertEqual(scripted_out, eager_out)
def test_preemphasis(self):
    """TorchScripted Preemphasis matches eager execution."""
    signal_in = torch.rand(3, 4, 10, dtype=self.dtype, device=self.device)
    transform = T.Preemphasis(coeff=0.97).to(dtype=self.dtype, device=self.device)
    eager_out = transform(signal_in)
    scripted_out = torch_script(transform)(signal_in)
    self.assertEqual(scripted_out, eager_out)
def test_deemphasis(self):
    """TorchScripted Deemphasis matches eager execution."""
    signal_in = torch.rand(3, 4, 10, dtype=self.dtype, device=self.device)
    transform = T.Deemphasis(coeff=0.97).to(dtype=self.dtype, device=self.device)
    eager_out = transform(signal_in)
    scripted_out = torch_script(transform)(signal_in)
    self.assertEqual(scripted_out, eager_out)
class TransformsFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
......
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfkmeMark
from .transforms_test_impl import TransformsTestBase
......@@ -11,6 +11,7 @@ class TransformsCUDAFloat32Test(TransformsTestBase, PytorchTestCase):
@skipIfNoCuda
@skipIfkmeMark
class TransformsCUDAFloat64Test(TransformsTestBase, PytorchTestCase):
device = "cuda"
dtype = torch.float64
......@@ -25,7 +25,6 @@ class Tester(common_utils.TorchaudioTestCase):
return waveform / factor
def test_mu_law_companding(self):
quantization_channels = 256
waveform = self.waveform.clone()
......@@ -237,7 +236,7 @@ class Tester(common_utils.TorchaudioTestCase):
torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method=invalid_resampling_method)
upsample_resample = torchaudio.transforms.Resample(
sample_rate, upsample_rate, resampling_method="sinc_interpolation"
sample_rate, upsample_rate, resampling_method="sinc_interp_hann"
)
up_sampled = upsample_resample(waveform)
......@@ -245,7 +244,7 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)
downsample_resample = torchaudio.transforms.Resample(
sample_rate, downsample_rate, resampling_method="sinc_interpolation"
sample_rate, downsample_rate, resampling_method="sinc_interp_hann"
)
down_sampled = downsample_resample(waveform)
......@@ -292,8 +291,7 @@ class SmokeTest(common_utils.TorchaudioTestCase):
self.assertEqual(specgram.onesided, False)
def test_melspectrogram(self):
melspecgram = transforms.MelSpectrogram(center=True, pad_mode="reflect", onesided=False)
melspecgram = transforms.MelSpectrogram(center=True, pad_mode="reflect")
specgram = melspecgram.spectrogram
self.assertEqual(specgram.center, True)
self.assertEqual(specgram.pad_mode, "reflect")
self.assertEqual(specgram.onesided, False)
import math
import random
from unittest.mock import patch
import numpy as np
import torch
import torchaudio.transforms as T
from parameterized import param, parameterized
from scipy import signal
from torchaudio.functional import lfilter, preemphasis
from torchaudio.functional.functional import _get_sinc_resample_kernel
from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, nested_params, TestBaseMixin
from torchaudio_unittest.common_utils.psd_utils import psd_numpy
......@@ -11,11 +18,11 @@ def _get_ratio(mat):
class TransformsTestBase(TestBaseMixin):
def test_InverseMelScale(self):
def test_inverse_melscale(self):
"""Gauge the quality of InverseMelScale transform.
As InverseMelScale is currently implemented with
random initialization + iterative optimization,
sub-optimal solution (compute matrix inverse + relu),
it is not practically possible to assert the difference between
the estimated spectrogram and the original spectrogram as a whole.
Estimated spectrogram has very huge descrepency locally.
......@@ -53,7 +60,7 @@ class TransformsTestBase(TestBaseMixin):
assert _get_ratio(relative_diff < 1e-5) > 1e-5
@nested_params(
["sinc_interpolation", "kaiser_window"],
["sinc_interp_hann", "sinc_interp_kaiser"],
[16000, 44100],
)
def test_resample_identity(self, resampling_method, sample_rate):
......@@ -65,7 +72,7 @@ class TransformsTestBase(TestBaseMixin):
self.assertEqual(waveform, resampled)
@nested_params(
["sinc_interpolation", "kaiser_window"],
["sinc_interp_hann", "sinc_interp_kaiser"],
[None, torch.float64],
)
def test_resample_cache_dtype(self, resampling_method, dtype):
......@@ -158,3 +165,316 @@ class TransformsTestBase(TestBaseMixin):
trans.orig_freq, sample_rate, trans.gcd, device=self.device, dtype=self.dtype
)
self.assertEqual(trans.kernel, expected)
@nested_params(
    [(10, 4), (4, 3, 1, 2), (2,), ()],
    [(100, 43), (21, 45)],
    ["full", "valid", "same"],
)
def test_convolve(self, leading_dims, lengths, mode):
    """Check that Convolve returns values identical to those that SciPy produces."""
    L_x, L_y = lengths
    x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
    y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)
    convolve = T.Convolve(mode=mode).to(self.device)
    actual = convolve(x, y)
    # scipy.signal.convolve works on one pair of 1-D signals at a time, so
    # flatten the leading dimensions and convolve each pair individually.
    num_signals = torch.tensor(leading_dims).prod() if leading_dims else 1
    x_reshaped = x.reshape((num_signals, L_x))
    y_reshaped = y.reshape((num_signals, L_y))
    expected = [
        signal.convolve(x_reshaped[i].detach().cpu().numpy(), y_reshaped[i].detach().cpu().numpy(), mode=mode)
        for i in range(num_signals)
    ]
    expected = torch.tensor(np.array(expected))
    # Restore the original leading dimensions before comparing.
    expected = expected.reshape(leading_dims + (-1,))
    self.assertEqual(expected, actual)
@nested_params(
    [(10, 4), (4, 3, 1, 2), (2,), ()],
    [(100, 43), (21, 45)],
    ["full", "valid", "same"],
)
def test_fftconvolve(self, leading_dims, lengths, mode):
    """Check that FFTConvolve returns values identical to those that SciPy produces."""
    L_x, L_y = lengths
    x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
    y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)
    convolve = T.FFTConvolve(mode=mode).to(self.device)
    actual = convolve(x, y)
    # scipy.signal.fftconvolve handles batched input directly via axes=-1,
    # so no per-signal reshaping is needed here.
    expected = signal.fftconvolve(x.detach().cpu().numpy(), y.detach().cpu().numpy(), axes=-1, mode=mode)
    expected = torch.tensor(expected)
    self.assertEqual(expected, actual)
def test_speed_identity(self):
    """speed of 1.0 does not alter input waveform and length"""
    batch_shape = (5, 4, 2)
    num_frames = 1000
    waveform = torch.rand(*batch_shape, num_frames)
    lengths = torch.randint(1, 1000, batch_shape)
    transform = T.Speed(1000, 1.0)
    out_waveform, out_lengths = transform(waveform, lengths)
    self.assertEqual(waveform, out_waveform)
    self.assertEqual(lengths, out_lengths)
@nested_params([0.8, 1.1, 1.2], [True, False])
def test_speed_accuracy(self, factor, use_lengths):
    """sinusoidal waveform is properly compressed by factor"""
    # Edge samples are excluded from the comparison to avoid resampling
    # boundary artifacts.
    n_to_trim = 20
    sample_rate = 1000
    freq = 2
    times = torch.arange(0, 5, 1.0 / sample_rate)
    waveform = torch.cos(2 * math.pi * freq * times).unsqueeze(0).to(self.device, self.dtype)
    if use_lengths:
        lengths = torch.tensor([waveform.size(1)])
    else:
        lengths = None
    speed = T.Speed(sample_rate, factor).to(self.device, self.dtype)
    output, output_lengths = speed(waveform, lengths)
    if use_lengths:
        self.assertEqual(output.size(1), output_lengths[0])
    else:
        self.assertEqual(None, output_lengths)
    # Speeding up by `factor` should equal sampling a sinusoid whose
    # frequency is scaled by `factor` over a duration shrunk by `factor`.
    new_times = torch.arange(0, 5 / factor, 1.0 / sample_rate)
    expected_waveform = torch.cos(2 * math.pi * freq * factor * new_times).unsqueeze(0).to(self.device, self.dtype)
    self.assertEqual(
        expected_waveform[..., n_to_trim:-n_to_trim], output[..., n_to_trim:-n_to_trim], atol=1e-1, rtol=1e-4
    )
def test_speed_perturbation(self):
    """sinusoidal waveform is properly compressed by sampled factors"""
    # Edge samples are excluded from the comparison to avoid resampling
    # boundary artifacts.
    n_to_trim = 20
    sample_rate = 1000
    freq = 2
    times = torch.arange(0, 5, 1.0 / sample_rate)
    waveform = torch.cos(2 * math.pi * freq * times).unsqueeze(0).to(self.device, self.dtype)
    lengths = torch.tensor([waveform.size(1)])
    factors = [0.8, 1.1, 1.0]
    indices = random.choices(range(len(factors)), k=5)
    speed_perturb = T.SpeedPerturbation(sample_rate, factors).to(self.device, self.dtype)
    # Patch torch.randint so the transform's internal factor sampling is
    # deterministic and we know which factor each call used.
    with patch("torch.randint", side_effect=indices):
        for idx in indices:
            output, output_lengths = speed_perturb(waveform, lengths)
            self.assertEqual(output.size(1), output_lengths[0])
            factor = factors[idx]
            new_times = torch.arange(0, 5 / factor, 1.0 / sample_rate)
            expected_waveform = (
                torch.cos(2 * math.pi * freq * factor * new_times).unsqueeze(0).to(self.device, self.dtype)
            )
            self.assertEqual(
                expected_waveform[..., n_to_trim:-n_to_trim],
                output[..., n_to_trim:-n_to_trim],
                atol=1e-1,
                rtol=1e-4,
            )
def test_add_noise_broadcast(self):
    """Check that AddNoise produces correct outputs when broadcasting input dimensions."""
    leading_dims = (5, 2, 3)
    L = 51
    waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
    # noise / lengths / snr each use size-1 dims that must broadcast
    # against the waveform's (5, 2, 3) leading shape.
    noise = torch.rand(5, 1, 1, L, dtype=self.dtype, device=self.device)
    lengths = torch.rand(5, 1, 3, dtype=self.dtype, device=self.device)
    snr = torch.rand(1, 1, 1, dtype=self.dtype, device=self.device) * 10
    add_noise = T.AddNoise()
    actual = add_noise(waveform, noise, snr, lengths)
    # Reference: expand everything explicitly and run the transform again.
    noise_expanded = noise.expand(*leading_dims, L)
    snr_expanded = snr.expand(*leading_dims)
    lengths_expanded = lengths.expand(*leading_dims)
    expected = add_noise(waveform, noise_expanded, snr_expanded, lengths_expanded)
    self.assertEqual(expected, actual)
@parameterized.expand(
    [((5, 2, 3), (2, 1, 1), (5, 2), (5, 2, 3)), ((2, 1), (5,), (5,), (5,)), ((3,), (5, 2, 3), (2, 1, 1), (5, 2))]
)
def test_add_noise_leading_dim_check(self, waveform_dims, noise_dims, lengths_dims, snr_dims):
    """Check that AddNoise properly rejects inputs with different leading dimension lengths."""
    L = 51
    waveform = torch.rand(*waveform_dims, L, dtype=self.dtype, device=self.device)
    noise = torch.rand(*noise_dims, L, dtype=self.dtype, device=self.device)
    lengths = torch.rand(*lengths_dims, dtype=self.dtype, device=self.device)
    snr = torch.rand(*snr_dims, dtype=self.dtype, device=self.device) * 10
    add_noise = T.AddNoise()
    # The mismatched leading shapes above are non-broadcastable and must
    # raise a ValueError mentioning the leading dimensions.
    with self.assertRaisesRegex(ValueError, "Input leading dimensions"):
        add_noise(waveform, noise, snr, lengths)
def test_add_noise_length_check(self):
    """Check that add_noise properly rejects inputs that have inconsistent length dimensions."""
    leading_dims = (5, 2, 3)
    L = 51
    waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
    # Noise is deliberately one frame shorter (50 vs 51) than the waveform.
    noise = torch.rand(*leading_dims, 50, dtype=self.dtype, device=self.device)
    lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
    snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
    add_noise = T.AddNoise()
    with self.assertRaisesRegex(ValueError, "Length dimensions"):
        add_noise(waveform, noise, snr, lengths)
@nested_params(
    [(2, 1, 31)],
    [0.97, 0.72],
)
def test_preemphasis(self, input_shape, coeff):
    """Preemphasis matches the equivalent first-order FIR filter."""
    waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
    preemphasis = T.Preemphasis(coeff=coeff).to(dtype=self.dtype, device=self.device)
    actual = preemphasis(waveform)
    # y[n] = x[n] - coeff * x[n-1]: an lfilter with no feedback
    # (a = [1, 0]) and numerator b = [1, -coeff].
    a_coeffs = torch.tensor([1.0, 0.0], device=self.device, dtype=self.dtype)
    b_coeffs = torch.tensor([1.0, -coeff], device=self.device, dtype=self.dtype)
    expected = lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
    self.assertEqual(actual, expected)
@nested_params(
    [(2, 1, 31)],
    [0.97, 0.72],
)
def test_deemphasis(self, input_shape, coeff):
    """Deemphasis inverts the functional preemphasis (round-trip identity)."""
    waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
    # Apply functional preemphasis first, then undo it with the transform.
    preemphasized = preemphasis(waveform, coeff=coeff)
    deemphasis = T.Deemphasis(coeff=coeff).to(dtype=self.dtype, device=self.device)
    deemphasized = deemphasis(preemphasized)
    self.assertEqual(deemphasized, waveform)
@nested_params(
    [(100, 200), (5, 10, 20), (50, 50, 100, 200)],
)
def test_time_masking(self, input_shape):
    """TimeMasking zeroes a contiguous span along the last axis only."""
    transform = T.TimeMasking(time_mask_param=5)
    # Generate a specgram tensor containing 1's only, for the ease of testing.
    specgram = torch.ones(*input_shape)
    masked = transform(specgram)
    dim = len(input_shape)
    # Across the axis (dim-1) where we apply masking,
    # the mean tensor should contain equal elements,
    # and the value should be between 0 and 1.
    m_masked = torch.mean(masked, dim - 1)
    self.assertEqual(torch.var(m_masked), 0)
    self.assertTrue(torch.mean(m_masked) > 0)
    self.assertTrue(torch.mean(m_masked) < 1)
    # Across all other dimensions, the mean tensor should contain at least
    # one zero element, and all non-zero elements should be 1.
    for axis in range(dim - 1):
        unmasked_axis_mean = torch.mean(masked, axis)
        self.assertTrue(0 in unmasked_axis_mean)
        self.assertFalse(False in torch.eq(unmasked_axis_mean[unmasked_axis_mean != 0], 1))
@nested_params(
    [(100, 200), (5, 10, 20), (50, 50, 100, 200)],
)
def test_freq_masking(self, input_shape):
    """FrequencyMasking zeroes a band along the second-to-last axis only."""
    transform = T.FrequencyMasking(freq_mask_param=5)
    # An all-ones spectrogram makes masked entries trivially detectable.
    specgram = torch.ones(*input_shape)
    masked = transform(specgram)
    ndim = len(input_shape)
    freq_axis = ndim - 2
    # Averaged over the masked (frequency) axis, the per-slice means must all
    # be identical (zero variance) and lie strictly between 0 and 1.
    freq_axis_mean = torch.mean(masked, freq_axis)
    self.assertEqual(torch.var(freq_axis_mean), 0)
    self.assertTrue(torch.mean(freq_axis_mean) > 0)
    self.assertTrue(torch.mean(freq_axis_mean) < 1)
    # Averaged over every other axis, the mean must contain at least one zero
    # (the masked band), and every non-zero entry must still be exactly 1.
    for other_axis in range(ndim):
        if other_axis == freq_axis:
            continue
        mean_along_axis = torch.mean(masked, other_axis)
        self.assertTrue(0 in mean_along_axis)
        nonzero_entries = mean_along_axis[mean_along_axis != 0]
        self.assertFalse(False in torch.eq(nonzero_entries, 1))
@parameterized.expand(
    [
        param(10, 20, 10, 20, False),
        param(0, 20, 10, 20, False),
        param(10, 20, 0, 20, False),
        param(10, 20, 10, 20, True),
        param(0, 20, 10, 20, True),
        param(10, 20, 0, 20, True),
    ]
)
def test_specaugment(self, n_time_masks, time_mask_param, n_freq_masks, freq_mask_param, iid_masks):
    """Make sure SpecAug masking works as expected.

    Checks, on an all-ones (2, 200, 200) spectrogram, that masking is applied
    only along the requested axes and that ``iid_masks`` controls whether the
    two batch entries receive different masks.
    """
    spec = torch.ones(2, 200, 200)
    transform = T.SpecAugment(
        n_time_masks=n_time_masks,
        time_mask_param=time_mask_param,
        n_freq_masks=n_freq_masks,
        freq_mask_param=freq_mask_param,
        iid_masks=iid_masks,
        zero_masking=True,
    )
    spec_masked = transform(spec)
    f_axis_mean = torch.mean(spec_masked, 1)
    t_axis_mean = torch.mean(spec_masked, 2)
    if n_time_masks == 0 and n_freq_masks == 0:
        # No masking requested: output must be untouched.
        self.assertEqual(spec, spec_masked)
    elif n_time_masks > 0 and n_freq_masks > 0:
        # Across both time and frequency dimensions, the mean tensor should contain
        # at least one zero element, and all non-zero elements should be less than 1.
        self.assertTrue(0 in t_axis_mean)
        self.assertFalse(False in torch.lt(t_axis_mean[t_axis_mean != 0], 1))
        self.assertTrue(0 in f_axis_mean)
        self.assertFalse(False in torch.lt(f_axis_mean[f_axis_mean != 0], 1))
    elif n_freq_masks > 0:
        # Across the frequency axis where we apply masking,
        # the mean tensor should contain equal elements,
        # and the value should be between 0 and 1.
        self.assertFalse(False in torch.eq(f_axis_mean[0], f_axis_mean[0][0]))
        self.assertFalse(False in torch.eq(f_axis_mean[1], f_axis_mean[1][0]))
        self.assertTrue(f_axis_mean[0][0] < 1)
        self.assertTrue(f_axis_mean[1][0] > 0)
        # Across the time axis where we don't mask, the mean tensor should contain at
        # least one zero element, and all non-zero elements should be 1.
        self.assertTrue(0 in t_axis_mean)
        self.assertFalse(False in torch.eq(t_axis_mean[t_axis_mean != 0], 1))
    else:
        # Across the time axis where we apply masking,
        # the mean tensor should contain equal elements,
        # and the value should be between 0 and 1.
        self.assertFalse(False in torch.eq(t_axis_mean[0], t_axis_mean[0][0]))
        self.assertFalse(False in torch.eq(t_axis_mean[1], t_axis_mean[1][0]))
        self.assertTrue(t_axis_mean[0][0] < 1)
        self.assertTrue(t_axis_mean[1][0] > 0)
        # Across the frequency axis where we don't mask, the mean tensor should contain at
        # least one zero element, and all non-zero elements should be 1.
        self.assertTrue(0 in f_axis_mean)
        self.assertFalse(False in torch.eq(f_axis_mean[f_axis_mean != 0], 1))
    # Test if iid_masks gives different masking results for different spectrograms across the 0th dimension.
    # (Removed a leftover debug print of `diff` that polluted the test output.)
    diff = torch.linalg.vector_norm(spec_masked[0] - spec_masked[1]).item()
    if iid_masks is True:
        self.assertTrue(diff > 0)
    else:
        self.assertTrue(diff == 0)
......@@ -25,3 +25,17 @@ class TestFFmpegUtils(PytorchTestCase):
"""`get_versions` does not crash"""
versions = ffmpeg_utils.get_versions()
assert set(versions.keys()) == {"libavutil", "libavcodec", "libavformat", "libavfilter", "libavdevice"}
def test_available_stuff(self):
    """Querying encoders/decoders/muxers/demuxers/devices/protocols does not segfault."""
    probes = (
        ffmpeg_utils.get_demuxers,
        ffmpeg_utils.get_muxers,
        ffmpeg_utils.get_audio_decoders,
        ffmpeg_utils.get_audio_encoders,
        ffmpeg_utils.get_video_decoders,
        ffmpeg_utils.get_video_encoders,
        ffmpeg_utils.get_input_devices,
        ffmpeg_utils.get_output_devices,
        ffmpeg_utils.get_input_protocols,
        ffmpeg_utils.get_output_protocols,
    )
    # Merely invoking each query is the test: success means no crash.
    for probe in probes:
        probe()
The Torchaudio repository and source distributions bundle several libraries that are
compatibly licensed. We list some here.
Name: cuctc
License: BSD-2-Clause (files without specific notices)
BSD-3-Clause file:
  torchaudio/csrc/cuctc/src/ctc_fast_divmod.cuh
Apache-2.0 files:
  torchaudio/csrc/cuctc/src/bitonic_topk
For details, see cuctc/LICENSE and
torchaudio/csrc/cuctc/src/bitonic_topk/LICENSE.
################################################################################
# This file defines the following FFmpeg libraries using pre-built binaries.
add_library(ffmpeg4 INTERFACE)
add_library(ffmpeg5 INTERFACE)
add_library(ffmpeg6 INTERFACE)
################################################################################

include(FetchContent)

# All archives are hosted under this PyTorch S3 prefix; each platform branch
# below selects the matching prebuilt FFmpeg 4/5/6 bundle by URL + SHA256.
set(base_url https://pytorch.s3.amazonaws.com/torchaudio/ffmpeg)

if (APPLE)
  if ("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "arm64")
    FetchContent_Declare(
      f4
      URL ${base_url}/2023-08-14/macos_arm64/4.4.4.tar.gz
      URL_HASH SHA256=9a593eb241eb8b23bc557856ee6db5d9aecd2d8895c614a949f3a1ad9799c1a1
    )
    FetchContent_Declare(
      f5
      URL ${base_url}/2023-07-06/macos_arm64/5.0.3.tar.gz
      URL_HASH SHA256=316fe8378afadcf63089acf3ad53a626fd3c26cc558b96ce1dc94d2a78f4deb4
    )
    FetchContent_Declare(
      f6
      URL ${base_url}/2023-07-06/macos_arm64/6.0.tar.gz
      URL_HASH SHA256=5d1da9626f8cb817d6c558a2c61085a3d39a8d9f725a6f69f4658bea8efa9389
    )
  elseif ("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64")
    FetchContent_Declare(
      f4
      URL ${base_url}/2023-08-14/macos_x86_64/4.4.4.tar.gz
      URL_HASH SHA256=0935e359c0864969987d908397f9208d6dc4dc0ef8bfe2ec730bb2c44eae89fc
    )
    FetchContent_Declare(
      f5
      URL ${base_url}/2023-07-06/macos_x86_64/5.0.3.tar.gz
      URL_HASH SHA256=d0b49575d3b174cfcca53b3049641855e48028cf22dd32f3334bbec4ca94f43e
    )
    FetchContent_Declare(
      f6
      URL ${base_url}/2023-07-06/macos_x86_64/6.0.tar.gz
      URL_HASH SHA256=eabc01eb7d9e714e484d5e1b27bf7d921e87c1f3c00334abd1729e158d6db862
    )
  else ()
    # No prebuilt bundles for other Apple architectures.
    message(
      FATAL_ERROR
      "${CMAKE_SYSTEM_PROCESSOR} is not supported for FFmpeg multi-version integration. "
      "If you have FFmpeg libraries installed in the system,"
      " setting FFMPEG_ROOT environment variable to the root directory of FFmpeg installation"
      " (the directory where `include` and `lib` sub directories with corresponding headers"
      " and library files are present) will invoke the FFmpeg single-version integration. "
      "If you do not need the FFmpeg integration, setting USE_FFMPEG=0 will bypass the issue.")
  endif()
elseif (UNIX)
  if ("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "aarch64")
    FetchContent_Declare(
      f4
      URL ${base_url}/2023-08-14/linux_aarch64/4.4.4.tar.gz
      URL_HASH SHA256=6f00437d13a3b3812ebe81c6e6f3a84a58f260d946a1995df87ba09aae234504
    )
    FetchContent_Declare(
      f5
      URL ${base_url}/2023-07-06/linux_aarch64/5.0.3.tar.gz
      URL_HASH SHA256=65c663206982ee3f0ff88436d8869d191b46061e01e753518c77ecc13ea0236d
    )
    FetchContent_Declare(
      f6
      URL ${base_url}/2023-07-06/linux_aarch64/6.0.tar.gz
      URL_HASH SHA256=ec762fd41ea7b8d9ad4f810f53fd78a565f2bc6f680afe56d555c80f3d35adef
    )
  elseif ("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64")
    FetchContent_Declare(
      f4
      URL ${base_url}/2023-08-14/linux_x86_64/4.4.4.tar.gz
      URL_HASH SHA256=9b87eeba9b6975e25f28ba12163bd713228ed84f4c2b3721bc5ebe92055edb51
    )
    FetchContent_Declare(
      f5
      URL ${base_url}/2023-07-06/linux_x86_64/5.0.3.tar.gz
      URL_HASH SHA256=de3c75c99b9ce33de7efdc144566804ae5880457ce71e185b3f592dc452edce7
    )
    FetchContent_Declare(
      f6
      URL ${base_url}/2023-07-06/linux_x86_64/6.0.tar.gz
      URL_HASH SHA256=04d3916404bab5efadd29f68361b7d13ea71e6242c6473edcb747a41a9fb97a6
    )
  else ()
    # Possible case ppc64le (though it's not officially supported.)
    message(
      FATAL_ERROR
      "${CMAKE_SYSTEM_PROCESSOR} is not supported for FFmpeg multi-version integration. "
      "If you have FFmpeg libraries installed in the system,"
      " setting FFMPEG_ROOT environment variable to the root directory of FFmpeg installation"
      " (the directory where `include` and `lib` sub directories with corresponding headers"
      " and library files are present) will invoke the FFmpeg single-version integration. "
      "If you do not need the FFmpeg integration, setting USE_FFMPEG=0 will bypass the issue.")
  endif()
elseif(MSVC)
  FetchContent_Declare(
    f4
    URL ${base_url}/2023-08-14/windows/4.4.4.tar.gz
    URL_HASH SHA256=9f9a65cf03a3e164edca601ba18180a504e44e03fae48ce706ca3120b55a4db5
  )
  FetchContent_Declare(
    f5
    URL ${base_url}/2023-07-06/windows/5.0.3.tar.gz
    URL_HASH SHA256=e2daa10799909e366cb1b4b91a217d35f6749290dcfeea40ecae3d5b05a46cb3
  )
  FetchContent_Declare(
    f6
    URL ${base_url}/2023-07-06/windows/6.0.tar.gz
    URL_HASH SHA256=098347eca8cddb5aaa61e9ecc1a00548c645fc59b4f7346b3d91414aa00a9cf6
  )
endif()
# Download/extract the three archives, then attach their headers and shared
# libraries to the corresponding INTERFACE targets declared above.
FetchContent_MakeAvailable(f4 f5 f6)
target_include_directories(ffmpeg4 INTERFACE ${f4_SOURCE_DIR}/include)
target_include_directories(ffmpeg5 INTERFACE ${f5_SOURCE_DIR}/include)
target_include_directories(ffmpeg6 INTERFACE ${f6_SOURCE_DIR}/include)

# Platform-specific library file layout: versioned .dylib on macOS,
# versioned .so on Linux, import .lib files (under bin/) on Windows.
if(APPLE)
  target_link_libraries(
    ffmpeg4
    INTERFACE
    ${f4_SOURCE_DIR}/lib/libavutil.56.dylib
    ${f4_SOURCE_DIR}/lib/libavcodec.58.dylib
    ${f4_SOURCE_DIR}/lib/libavformat.58.dylib
    ${f4_SOURCE_DIR}/lib/libavdevice.58.dylib
    ${f4_SOURCE_DIR}/lib/libavfilter.7.dylib
    )
  target_link_libraries(
    ffmpeg5
    INTERFACE
    ${f5_SOURCE_DIR}/lib/libavutil.57.dylib
    ${f5_SOURCE_DIR}/lib/libavcodec.59.dylib
    ${f5_SOURCE_DIR}/lib/libavformat.59.dylib
    ${f5_SOURCE_DIR}/lib/libavdevice.59.dylib
    ${f5_SOURCE_DIR}/lib/libavfilter.8.dylib
    )
  target_link_libraries(
    ffmpeg6
    INTERFACE
    ${f6_SOURCE_DIR}/lib/libavutil.58.dylib
    ${f6_SOURCE_DIR}/lib/libavcodec.60.dylib
    ${f6_SOURCE_DIR}/lib/libavformat.60.dylib
    ${f6_SOURCE_DIR}/lib/libavdevice.60.dylib
    ${f6_SOURCE_DIR}/lib/libavfilter.9.dylib
    )
elseif (UNIX)
  target_link_libraries(
    ffmpeg4
    INTERFACE
    ${f4_SOURCE_DIR}/lib/libavutil.so.56
    ${f4_SOURCE_DIR}/lib/libavcodec.so.58
    ${f4_SOURCE_DIR}/lib/libavformat.so.58
    ${f4_SOURCE_DIR}/lib/libavdevice.so.58
    ${f4_SOURCE_DIR}/lib/libavfilter.so.7
    )
  target_link_libraries(
    ffmpeg5
    INTERFACE
    ${f5_SOURCE_DIR}/lib/libavutil.so.57
    ${f5_SOURCE_DIR}/lib/libavcodec.so.59
    ${f5_SOURCE_DIR}/lib/libavformat.so.59
    ${f5_SOURCE_DIR}/lib/libavdevice.so.59
    ${f5_SOURCE_DIR}/lib/libavfilter.so.8
    )
  target_link_libraries(
    ffmpeg6
    INTERFACE
    ${f6_SOURCE_DIR}/lib/libavutil.so.58
    ${f6_SOURCE_DIR}/lib/libavcodec.so.60
    ${f6_SOURCE_DIR}/lib/libavformat.so.60
    ${f6_SOURCE_DIR}/lib/libavdevice.so.60
    ${f6_SOURCE_DIR}/lib/libavfilter.so.9
    )
elseif(MSVC)
  target_link_libraries(
    ffmpeg4
    INTERFACE
    ${f4_SOURCE_DIR}/bin/avutil.lib
    ${f4_SOURCE_DIR}/bin/avcodec.lib
    ${f4_SOURCE_DIR}/bin/avformat.lib
    ${f4_SOURCE_DIR}/bin/avdevice.lib
    ${f4_SOURCE_DIR}/bin/avfilter.lib
    )
  target_link_libraries(
    ffmpeg5
    INTERFACE
    ${f5_SOURCE_DIR}/bin/avutil.lib
    ${f5_SOURCE_DIR}/bin/avcodec.lib
    ${f5_SOURCE_DIR}/bin/avformat.lib
    ${f5_SOURCE_DIR}/bin/avdevice.lib
    ${f5_SOURCE_DIR}/bin/avfilter.lib
    )
  target_link_libraries(
    ffmpeg6
    INTERFACE
    ${f6_SOURCE_DIR}/bin/avutil.lib
    ${f6_SOURCE_DIR}/bin/avcodec.lib
    ${f6_SOURCE_DIR}/bin/avformat.lib
    ${f6_SOURCE_DIR}/bin/avdevice.lib
    ${f6_SOURCE_DIR}/bin/avfilter.lib
    )
endif()
# CMake file for searching existing FFmpeg installation and defining ffmpeg TARGET
message(STATUS "Searching existing FFmpeg installation")
message(STATUS FFMPEG_ROOT=$ENV{FFMPEG_ROOT})
if (NOT DEFINED ENV{FFMPEG_ROOT})
  message(FATAL_ERROR "Environment variable FFMPEG_ROOT is not set.")
endif()

set(_root $ENV{FFMPEG_ROOT})
# bin/ is searched too because Windows installations place DLL import libs there.
set(lib_dirs "${_root}/lib" "${_root}/bin")
set(include_dir "${_root}/include")

add_library(ffmpeg INTERFACE)
target_include_directories(ffmpeg INTERFACE "${include_dir}")

# Locate one FFmpeg component (its header and its library) strictly under
# FFMPEG_ROOT and link it into the `ffmpeg` interface target.
# REQUIRED + NO_DEFAULT_PATH: fail configuration rather than silently pick up
# a system-wide copy.
function (_find_ffmpeg_lib component)
  find_path("${component}_header"
    NAMES "lib${component}/${component}.h"
    PATHS "${include_dir}"
    DOC "The include directory for ${component}"
    REQUIRED
    NO_DEFAULT_PATH)
  find_library("lib${component}"
    NAMES "${component}"
    PATHS ${lib_dirs}
    DOC "${component} library"
    REQUIRED
    NO_DEFAULT_PATH)
  message(STATUS "Found ${component}: ${lib${component}}")
  target_link_libraries(
    ffmpeg
    INTERFACE
    ${lib${component}})
endfunction ()

_find_ffmpeg_lib(avutil)
_find_ffmpeg_lib(avcodec)
_find_ffmpeg_lib(avformat)
_find_ffmpeg_lib(avdevice)
_find_ffmpeg_lib(avfilter)
find_package(PkgConfig REQUIRED)
include(FetchContent)
include(ExternalProject)

# Shared staging prefix: every third-party codec installs into INSTALL_DIR so
# later projects can locate earlier ones via the env vars defined below.
set(INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../install)
set(ARCHIVE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../archives)
set(patch_dir ${PROJECT_SOURCE_DIR}/third_party/patches)
# Static, PIC archives only; keep configure output quiet and minimal.
set(COMMON_ARGS --quiet --disable-shared --enable-static --prefix=${INSTALL_DIR} --with-pic --disable-dependency-tracking --disable-debug --disable-examples --disable-doc)

# To pass custom environment variables to ExternalProject_Add command,
# we need to do `${CMAKE_COMMAND} -E env ${envs} <COMMAND>`.
# https://stackoverflow.com/a/62437353
# We construct the custom environment variables here
set(envs
  "PKG_CONFIG_PATH=${INSTALL_DIR}/lib/pkgconfig"
  "LDFLAGS=-L${INSTALL_DIR}/lib $ENV{LDFLAGS}"
  "CFLAGS=-I${INSTALL_DIR}/include -fvisibility=hidden $ENV{CFLAGS}"
  )
# Each codec below is built with autotools via ExternalProject_Add.
# The PATCH_COMMAND refreshes config.guess/config.sub so configure recognizes
# newer architectures; logging options keep build output out of the console
# unless something fails.

# opencore-amr: AMR-NB/AMR-WB codecs.
ExternalProject_Add(amr
  PREFIX ${CMAKE_CURRENT_BINARY_DIR}
  DOWNLOAD_DIR ${ARCHIVE_DIR}
  URL https://sourceforge.net/projects/opencore-amr/files/opencore-amr/opencore-amr-0.1.5.tar.gz
  URL_HASH SHA256=2c006cb9d5f651bfb5e60156dbff6af3c9d35c7bbcc9015308c0aff1e14cd341
  PATCH_COMMAND cp ${patch_dir}/config.guess ${patch_dir}/config.sub ${CMAKE_CURRENT_BINARY_DIR}/src/amr/
  CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/amr/configure ${COMMON_ARGS}
  DOWNLOAD_NO_PROGRESS ON
  LOG_DOWNLOAD ON
  LOG_UPDATE ON
  LOG_CONFIGURE ON
  LOG_BUILD ON
  LOG_INSTALL ON
  LOG_MERGED_STDOUTERR ON
  LOG_OUTPUT_ON_FAILURE ON
)

# lame: MP3 encoder.
ExternalProject_Add(lame
  PREFIX ${CMAKE_CURRENT_BINARY_DIR}
  DOWNLOAD_DIR ${ARCHIVE_DIR}
  URL https://downloads.sourceforge.net/project/lame/lame/3.99/lame-3.99.5.tar.gz
  URL_HASH SHA256=24346b4158e4af3bd9f2e194bb23eb473c75fb7377011523353196b19b9a23ff
  PATCH_COMMAND cp ${patch_dir}/config.guess ${patch_dir}/config.sub ${CMAKE_CURRENT_BINARY_DIR}/src/lame/
  CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/lame/configure ${COMMON_ARGS} --enable-nasm
  DOWNLOAD_NO_PROGRESS ON
  LOG_DOWNLOAD ON
  LOG_UPDATE ON
  LOG_CONFIGURE ON
  LOG_BUILD ON
  LOG_INSTALL ON
  LOG_MERGED_STDOUTERR ON
  LOG_OUTPUT_ON_FAILURE ON
)

# ogg: container format; dependency of flac/vorbis/opus below.
ExternalProject_Add(ogg
  PREFIX ${CMAKE_CURRENT_BINARY_DIR}
  DOWNLOAD_DIR ${ARCHIVE_DIR}
  URL https://ftp.osuosl.org/pub/xiph/releases/ogg/libogg-1.3.3.tar.gz
  URL_HASH SHA256=c2e8a485110b97550f453226ec644ebac6cb29d1caef2902c007edab4308d985
  PATCH_COMMAND cp ${patch_dir}/config.guess ${patch_dir}/config.sub ${CMAKE_CURRENT_BINARY_DIR}/src/ogg/
  CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/ogg/configure ${COMMON_ARGS}
  DOWNLOAD_NO_PROGRESS ON
  LOG_DOWNLOAD ON
  LOG_UPDATE ON
  LOG_CONFIGURE ON
  LOG_BUILD ON
  LOG_INSTALL ON
  LOG_MERGED_STDOUTERR ON
  LOG_OUTPUT_ON_FAILURE ON
)

# flac: lossless codec (C library only; C++ bindings disabled).
ExternalProject_Add(flac
  PREFIX ${CMAKE_CURRENT_BINARY_DIR}
  DEPENDS ogg
  DOWNLOAD_DIR ${ARCHIVE_DIR}
  URL https://ftp.osuosl.org/pub/xiph/releases/flac/flac-1.3.2.tar.xz
  URL_HASH SHA256=91cfc3ed61dc40f47f050a109b08610667d73477af6ef36dcad31c31a4a8d53f
  PATCH_COMMAND cp ${patch_dir}/config.guess ${patch_dir}/config.sub ${CMAKE_CURRENT_BINARY_DIR}/src/flac/
  CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/flac/configure ${COMMON_ARGS} --with-ogg --disable-cpplibs
  DOWNLOAD_NO_PROGRESS ON
  LOG_DOWNLOAD ON
  LOG_UPDATE ON
  LOG_CONFIGURE ON
  LOG_BUILD ON
  LOG_INSTALL ON
  LOG_MERGED_STDOUTERR ON
  LOG_OUTPUT_ON_FAILURE ON
)

# vorbis: lossy codec in Ogg containers.
ExternalProject_Add(vorbis
  PREFIX ${CMAKE_CURRENT_BINARY_DIR}
  DEPENDS ogg
  DOWNLOAD_DIR ${ARCHIVE_DIR}
  URL https://ftp.osuosl.org/pub/xiph/releases/vorbis/libvorbis-1.3.6.tar.gz
  URL_HASH SHA256=6ed40e0241089a42c48604dc00e362beee00036af2d8b3f46338031c9e0351cb
  PATCH_COMMAND cp ${patch_dir}/config.guess ${patch_dir}/config.sub ${CMAKE_CURRENT_BINARY_DIR}/src/vorbis/
  CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/vorbis/configure ${COMMON_ARGS} --with-ogg
  DOWNLOAD_NO_PROGRESS ON
  LOG_DOWNLOAD ON
  LOG_UPDATE ON
  LOG_CONFIGURE ON
  LOG_BUILD ON
  LOG_INSTALL ON
  LOG_MERGED_STDOUTERR ON
  LOG_OUTPUT_ON_FAILURE ON
)

# opus: low-latency lossy codec.
ExternalProject_Add(opus
  PREFIX ${CMAKE_CURRENT_BINARY_DIR}
  DEPENDS ogg
  DOWNLOAD_DIR ${ARCHIVE_DIR}
  URL https://ftp.osuosl.org/pub/xiph/releases/opus/opus-1.3.1.tar.gz
  URL_HASH SHA256=65b58e1e25b2a114157014736a3d9dfeaad8d41be1c8179866f144a2fb44ff9d
  PATCH_COMMAND cp ${patch_dir}/config.guess ${patch_dir}/config.sub ${CMAKE_CURRENT_BINARY_DIR}/src/opus/
  CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/opus/configure ${COMMON_ARGS} --with-ogg
  DOWNLOAD_NO_PROGRESS ON
  LOG_DOWNLOAD ON
  LOG_UPDATE ON
  LOG_CONFIGURE ON
  LOG_BUILD ON
  LOG_INSTALL ON
  LOG_MERGED_STDOUTERR ON
  LOG_OUTPUT_ON_FAILURE ON
)

# opusfile: high-level Opus decoding API (HTTP streaming disabled).
ExternalProject_Add(opusfile
  PREFIX ${CMAKE_CURRENT_BINARY_DIR}
  DEPENDS opus
  DOWNLOAD_DIR ${ARCHIVE_DIR}
  URL https://ftp.osuosl.org/pub/xiph/releases/opus/opusfile-0.12.tar.gz
  URL_HASH SHA256=118d8601c12dd6a44f52423e68ca9083cc9f2bfe72da7a8c1acb22a80ae3550b
  PATCH_COMMAND cp ${patch_dir}/config.guess ${patch_dir}/config.sub ${CMAKE_CURRENT_BINARY_DIR}/src/opusfile/
  CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/opusfile/configure ${COMMON_ARGS} --disable-http
  DOWNLOAD_NO_PROGRESS ON
  LOG_DOWNLOAD ON
  LOG_UPDATE ON
  LOG_CONFIGURE ON
  LOG_BUILD ON
  LOG_INSTALL ON
  LOG_MERGED_STDOUTERR ON
  LOG_OUTPUT_ON_FAILURE ON
)
# OpenMP is by default compiled against GNU OpenMP, which conflicts with the version of OpenMP that PyTorch uses.
# See https://github.com/pytorch/audio/pull/1026
# TODO: Add flags like https://github.com/suphoff/pytorch_parallel_extension_cpp/blob/master/setup.py
# Configure flags for sox: enable only the codecs built above, disable every
# audio I/O backend and optional dependency.
set(SOX_OPTIONS
  --disable-openmp
  --with-amrnb
  --with-amrwb
  --with-flac
  --with-lame
  --with-oggvorbis
  --with-opus
  --without-alsa
  --without-ao
  --without-coreaudio
  --without-oss
  --without-id3tag
  --without-ladspa
  --without-mad
  --without-magic
  --without-png
  --without-pulseaudio
  --without-sndfile
  --without-sndio
  --without-sunaudio
  --without-waveaudio
  --without-wavpack
  --without-twolame
  )

# Static archives produced by the projects above. Listed with libsox first and
# its dependencies after, since archive link order matters for static linking.
set(SOX_LIBRARIES
  ${INSTALL_DIR}/lib/libsox.a
  ${INSTALL_DIR}/lib/libopencore-amrnb.a
  ${INSTALL_DIR}/lib/libopencore-amrwb.a
  ${INSTALL_DIR}/lib/libmp3lame.a
  ${INSTALL_DIR}/lib/libFLAC.a
  ${INSTALL_DIR}/lib/libopusfile.a
  ${INSTALL_DIR}/lib/libopus.a
  ${INSTALL_DIR}/lib/libvorbisenc.a
  ${INSTALL_DIR}/lib/libvorbisfile.a
  ${INSTALL_DIR}/lib/libvorbis.a
  ${INSTALL_DIR}/lib/libogg.a
  )

set(sox_depends
  ogg flac vorbis opusfile lame amr
  )
# NOTE(review): this region appears to be a broken merge of two versions of
# this file — an ExternalProject_Add(sox ...) build-from-source recipe and a
# FetchContent-based "download sources, build a stub library" recipe. As
# written it is not valid CMake: the ExternalProject_Add(sox call is never
# closed before FetchContent_Declare/add_library below, and the dangling
# PATCH_COMMAND/CONFIGURE_COMMAND/BUILD_COMMAND lines after
# target_link_libraries(libsox ...) belong to the FetchContent_Declare call.
# Reconstruct from upstream history before building.
ExternalProject_Add(sox
  PREFIX ${CMAKE_CURRENT_BINARY_DIR}
  DEPENDS ${sox_depends}
  DOWNLOAD_DIR ${ARCHIVE_DIR}
FetchContent_Declare(
  sox_src
  URL https://downloads.sourceforge.net/project/sox/sox/14.4.2/sox-14.4.2.tar.bz2
  URL_HASH SHA256=81a6956d4330e75b5827316e44ae381e6f1e8928003c6aa45896da9041ea149c
  PATCH_COMMAND patch -p1 < ${patch_dir}/sox.patch && cp ${patch_dir}/config.guess ${patch_dir}/config.sub ${CMAKE_CURRENT_BINARY_DIR}/src/sox/
  CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/sox/configure ${COMMON_ARGS} ${SOX_OPTIONS}
  BUILD_BYPRODUCTS ${SOX_LIBRARIES}
  DOWNLOAD_NO_PROGRESS ON
  LOG_DOWNLOAD ON
  LOG_UPDATE ON
  LOG_CONFIGURE ON
  LOG_BUILD ON
  LOG_INSTALL ON
  LOG_MERGED_STDOUTERR ON
  LOG_OUTPUT_ON_FAILURE ON
)
add_library(libsox INTERFACE)
add_dependencies(libsox sox)
target_include_directories(libsox INTERFACE ${INSTALL_DIR}/include)
target_link_libraries(libsox INTERFACE ${SOX_LIBRARIES})
  PATCH_COMMAND ""
  CONFIGURE_COMMAND ""
  BUILD_COMMAND ""
)
# FetchContent_MakeAvailable will parse the downloaded content and setup the targets.
# We want to only download and not build, so we run Populate manually.
if(NOT sox_src_POPULATED)
  FetchContent_Populate(sox_src)
endif()
add_library(sox SHARED stub.c)
if(APPLE)
  set_target_properties(sox PROPERTIES SUFFIX .dylib)
endif(APPLE)
target_include_directories(sox PUBLIC ${sox_src_SOURCE_DIR}/src)
#include <sox.h>

/*
 * Stub implementations of the libsox entry points used by torchaudio.
 * Every function reports failure (-1 / NULL / 0) and performs no work.
 * NOTE(review): presumably a real libsox is substituted at runtime and this
 * stub only exists so the extension links without the library — confirm
 * against the accompanying CMake (which builds `sox` from this stub.c).
 */

int sox_add_effect(
    sox_effects_chain_t* chain,
    sox_effect_t* effp,
    sox_signalinfo_t* in,
    sox_signalinfo_t const* out) {
  return -1;
}

int sox_close(sox_format_t* ft) {
  return -1;
}

sox_effect_t* sox_create_effect(sox_effect_handler_t const* eh) {
  return NULL;
}

sox_effects_chain_t* sox_create_effects_chain(
    sox_encodinginfo_t const* in_enc,
    sox_encodinginfo_t const* out_enc) {
  return NULL;
}

void sox_delete_effect(sox_effect_t* effp) {}

void sox_delete_effects_chain(sox_effects_chain_t* ecp) {}

int sox_effect_options(sox_effect_t* effp, int argc, char* const argv[]) {
  return -1;
}

const sox_effect_handler_t* sox_find_effect(char const* name) {
  return NULL;
}

int sox_flow_effects(
    sox_effects_chain_t* chain,
    int callback(sox_bool all_done, void* client_data),
    void* client_data) {
  return -1;
}

const sox_effect_fn_t* sox_get_effect_fns(void) {
  return NULL;
}

const sox_format_tab_t* sox_get_format_fns(void) {
  return NULL;
}

sox_globals_t* sox_get_globals(void) {
  return NULL;
}

sox_format_t* sox_open_read(
    char const* path,
    sox_signalinfo_t const* signal,
    sox_encodinginfo_t const* encoding,
    char const* filetype) {
  return NULL;
}

sox_format_t* sox_open_write(
    char const* path,
    sox_signalinfo_t const* signal,
    sox_encodinginfo_t const* encoding,
    char const* filetype,
    sox_oob_t const* oob,
    sox_bool overwrite_permitted(char const* filename)) {
  return NULL;
}

const char* sox_strerror(int sox_errno) {
  return NULL;
}

size_t sox_write(sox_format_t* ft, const sox_sample_t* buf, size_t len) {
  return 0;
}
/* Stub: always reports failure; fixed the stray `;` after the function body
 * (invalid in C89, a warning elsewhere) and declared a proper (void) prototype. */
int sox_init(void) {
  return -1;
}
/* Stub: always reports failure; fixed the stray `;` after the function body
 * (invalid in C89, a warning elsewhere) and declared a proper (void) prototype. */
int sox_quit(void) {
  return -1;
}
......@@ -14,20 +14,22 @@ primary_labels_mapping = {
"bug fix": "Bug Fixes",
"new feature": "New Features",
"improvement": "Improvements",
"example": "Examples",
"prototype": "Prototypes",
"other": "Other",
"None": "Missing",
}
secondary_labels_mapping = {
"module: I/O": "I/O",
"module: io": "I/O",
"module: ops": "Ops",
"module: models": "Models",
"module: pipelines": "Pipelines",
"module: datasets": "Datasets",
"module: docs": "Documentation",
"module: tests": "Tests",
"tutorial": "Tutorials",
"recipe": "Recipes",
"example": "Examples",
"build": "Build",
"style": "Style",
"perf": "Performance",
......
......@@ -7,7 +7,7 @@ from pathlib import Path
import torch
from setuptools import Extension
from setuptools.command.build_ext import build_ext
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
from torch.utils.cpp_extension import CUDA_HOME
__all__ = [
"get_ext_modules",
......@@ -34,12 +34,13 @@ def _get_build(var, default=False):
# Build-time feature flags. Each can be overridden through the matching
# environment variable (resolved by _get_build); the second argument is the
# default. Fix: _USE_FFMPEG, _USE_ROCM and _USE_CUDA were each assigned twice
# (old and new lines of a diff were both kept); the first assignments were
# dead code and have been removed, keeping the later values.
_BUILD_SOX = False if platform.system() == "Windows" else _get_build("BUILD_SOX", True)
_BUILD_KALDI = False if platform.system() == "Windows" else _get_build("BUILD_KALDI", True)
_BUILD_RIR = _get_build("BUILD_RIR", True)
_BUILD_RNNT = _get_build("BUILD_RNNT", True)
_BUILD_CTC_DECODER = _get_build("BUILD_CTC_DECODER", True)
_USE_FFMPEG = _get_build("USE_FFMPEG", True)
# torch.backends.cuda.is_built() reflects compile-time GPU support, which is
# the right signal at build time (no GPU needs to be attached).
_USE_ROCM = _get_build("USE_ROCM", torch.backends.cuda.is_built() and torch.version.hip is not None)
_USE_CUDA = _get_build("USE_CUDA", torch.backends.cuda.is_built() and torch.version.hip is None)
_BUILD_ALIGN = _get_build("BUILD_ALIGN", True)
_BUILD_CUDA_CTC_DECODER = _get_build("BUILD_CUDA_CTC_DECODER", _USE_CUDA)
_USE_OPENMP = _get_build("USE_OPENMP", True) and "ATen parallel backend: OpenMP" in torch.__config__.parallel_info()
_TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
......@@ -47,23 +48,42 @@ _TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
def get_ext_modules():
    """Return placeholder ``Extension`` entries for the native libraries.

    The ``sources`` lists are empty because CMake performs the actual build;
    the Extension objects only tell setuptools which artifacts to expect and
    where to place them, driven by the module-level feature flags.

    NOTE(review): this function was reconstructed from a diff-interleaved
    paste (old and new lines were mixed together, leaving `if` statements
    with empty bodies) — verify against upstream history.
    """
    modules = [
        Extension(name="torchaudio.lib.libtorchaudio", sources=[]),
        Extension(name="torchaudio.lib._torchaudio", sources=[]),
    ]
    if _BUILD_SOX:
        modules.extend(
            [
                Extension(name="torchaudio.lib.libtorchaudio_sox", sources=[]),
                Extension(name="torchaudio.lib._torchaudio_sox", sources=[]),
            ]
        )
    if _BUILD_CUDA_CTC_DECODER:
        modules.extend(
            [
                Extension(name="torchaudio.lib.libctc_prefix_decoder", sources=[]),
                Extension(name="torchaudio.lib.pybind11_prefixctc", sources=[]),
            ]
        )
    if _USE_FFMPEG:
        if "FFMPEG_ROOT" in os.environ:
            # single version ffmpeg mode
            modules.extend(
                [
                    Extension(name="torchaudio.lib.libtorchaudio_ffmpeg", sources=[]),
                    Extension(name="torchaudio.lib._torchaudio_ffmpeg", sources=[]),
                ]
            )
        else:
            # Multi-version mode: one binding per supported FFmpeg major version.
            modules.extend(
                [
                    Extension(name="torchaudio.lib.libtorchaudio_ffmpeg4", sources=[]),
                    Extension(name="torchaudio.lib._torchaudio_ffmpeg4", sources=[]),
                    Extension(name="torchaudio.lib.libtorchaudio_ffmpeg5", sources=[]),
                    Extension(name="torchaudio.lib._torchaudio_ffmpeg5", sources=[]),
                    Extension(name="torchaudio.lib.libtorchaudio_ffmpeg6", sources=[]),
                    Extension(name="torchaudio.lib._torchaudio_ffmpeg6", sources=[]),
                ]
            )
    return modules
......@@ -84,10 +104,16 @@ class CMakeBuild(build_ext):
# However, the following `cmake` command will build all of them at the same time,
# so, we do not need to perform `cmake` twice.
# Therefore we call `cmake` only for `torchaudio._torchaudio`.
if ext.name != "torchaudio._torchaudio":
if ext.name != "torchaudio.lib.libtorchaudio":
return
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
# Note:
# the last part "lib" does not really matter. We want to get the full path of
# the root build directory. Passing "torchaudio" will be interpreted as
# `torchaudio.[so|dylib|pyd]`, so we need something `torchaudio.foo`, that is
# interpreted as `torchaudio/foo.so` then use dirname to get the `torchaudio`
# directory.
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath("torchaudio.lib")))
# required for auto-detection of auxiliary "native" libs
if not extdir.endswith(os.path.sep):
......@@ -102,9 +128,10 @@ class CMakeBuild(build_ext):
"-DCMAKE_VERBOSE_MAKEFILE=ON",
f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}",
f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}",
f"-DBUILD_RIR:BOOL={'ON' if _BUILD_RIR else 'OFF'}",
f"-DBUILD_RNNT:BOOL={'ON' if _BUILD_RNNT else 'OFF'}",
f"-DBUILD_CTC_DECODER:BOOL={'ON' if _BUILD_CTC_DECODER else 'OFF'}",
f"-DBUILD_ALIGN:BOOL={'ON' if _BUILD_ALIGN else 'OFF'}",
f"-DBUILD_CUDA_CTC_DECODER:BOOL={'ON' if _BUILD_CUDA_CTC_DECODER else 'OFF'}",
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}",
f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}",
......@@ -123,10 +150,6 @@ class CMakeBuild(build_ext):
if platform.system() != "Windows" and CUDA_HOME is not None:
cmake_args += [f"-DCMAKE_CUDA_COMPILER='{CUDA_HOME}/bin/nvcc'"]
cmake_args += [f"-DCUDA_TOOLKIT_ROOT_DIR='{CUDA_HOME}'"]
if platform.system() != "Windows" and ROCM_HOME is not None:
cmake_args += [f"-DCMAKE_HIP_COMPILER='{ROCM_HOME}/bin/hipcc'"]
cmake_args += [f"-DCUDA_TOOLKIT_ROOT_DIR='{ROCM_HOME}'"]
# Default to Ninja
if "CMAKE_GENERATOR" not in os.environ or platform.system() == "Windows":
......
from torchaudio import ( # noqa: F401
from . import ( # noqa: F401
_extension,
compliance,
datasets,
......@@ -11,7 +11,7 @@ from torchaudio import ( # noqa: F401
transforms,
utils,
)
from torchaudio.backend import get_audio_backend, list_audio_backends, set_audio_backend
from ._backend.common import AudioMetaData # noqa
try:
from .version import __version__, git_version, abi, dtk, torch_version, dcu_version # noqa: F401
......@@ -19,7 +19,27 @@ try:
except ImportError:
pass
def _is_backend_dispatcher_enabled():
import os
return os.getenv("TORCHAUDIO_USE_BACKEND_DISPATCHER", default="1") == "1"
# Select the backend implementation at import time: the dispatcher-based
# package when enabled (the default), otherwise the legacy backend module.
# Both expose the same four names.
if _is_backend_dispatcher_enabled():
    from ._backend import _init_backend, get_audio_backend, list_audio_backends, set_audio_backend
else:
    from .backend import _init_backend, get_audio_backend, list_audio_backends, set_audio_backend

_init_backend()

# for backward compatibility. This has to happen after _backend is imported.
from . import backend  # noqa: F401
__all__ = [
"AudioMetaData",
"io",
"compliance",
"datasets",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment