Commit 406e9c8d authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add resample option to AudioEffector (#3374)

Summary:
Currently, AudioEffector always resamples its output back to the original (input) sample rate. It is more flexible to allow overriding this with an arbitrary output sample rate.

Pull Request resolved: https://github.com/pytorch/audio/pull/3374

Differential Revision: D46235358

Pulled By: mthrok

fbshipit-source-id: 39a5d4e38d9b90380da31d0ce9ee8090668b54e4
parent 58a51b5b
...@@ -83,3 +83,20 @@ class EffectorTest(TorchaudioTestCase): ...@@ -83,3 +83,20 @@ class EffectorTest(TorchaudioTestCase):
output = effector.apply(original, sample_rate) output = effector.apply(original, sample_rate)
self.assertEqual(original.shape, output.shape) self.assertEqual(original.shape, output.shape)
def test_resample(self):
    """Passing ``output_sample_rate`` changes the sample rate of the result.

    Both the whole-tensor path (``apply``) and the chunked path (``stream``)
    are exercised with an 8 kHz input and a 16 kHz requested output rate.
    """
    src_rate, dst_rate = 8000, 16000
    n_ch = 3
    # Shape is (time, channel); the test expects exactly ``dst_rate`` frames.
    expected_shape = [dst_rate, n_ch]

    source = get_sinusoid(n_channels=n_ch, sample_rate=src_rate, channels_first=False)
    effector = AudioEffector(effect="lowpass")

    # Whole-tensor path: output rate is passed positionally.
    applied = effector.apply(source, src_rate, dst_rate)
    self.assertEqual(applied.shape, expected_shape)

    # Streaming path: every yielded chunk must have the expected shape.
    chunks = effector.stream(source, src_rate, output_sample_rate=dst_rate, frames_per_chunk=dst_rate)
    for piece in chunks:
        self.assertEqual(piece.shape, expected_shape)
...@@ -260,7 +260,7 @@ class AudioEffector: ...@@ -260,7 +260,7 @@ class AudioEffector:
self.codec_config = codec_config self.codec_config = codec_config
self.pad_end = pad_end self.pad_end = pad_end
def _get_reader(self, waveform, sample_rate, frames_per_chunk=None): def _get_reader(self, waveform, sample_rate, output_sample_rate, frames_per_chunk=None):
num_frames, num_channels = waveform.shape num_frames, num_channels = waveform.shape
if self.format is not None: if self.format is not None:
...@@ -283,7 +283,8 @@ class AudioEffector: ...@@ -283,7 +283,8 @@ class AudioEffector:
waveform, sample_rate, self.effect, muxer, encoder, self.codec_config, frames_per_chunk waveform, sample_rate, self.effect, muxer, encoder, self.codec_config, frames_per_chunk
) )
filter_desc = _get_afilter_desc(sample_rate, _get_sample_fmt(waveform.dtype), num_channels) output_sr = sample_rate if output_sample_rate is None else output_sample_rate
filter_desc = _get_afilter_desc(output_sr, _get_sample_fmt(waveform.dtype), num_channels)
if self.pad_end: if self.pad_end:
filter_desc = f"{filter_desc},apad=whole_len={num_frames}" filter_desc = f"{filter_desc},apad=whole_len={num_frames}"
...@@ -291,12 +292,17 @@ class AudioEffector: ...@@ -291,12 +292,17 @@ class AudioEffector:
reader.add_audio_stream(frames_per_chunk or -1, -1, filter_desc=filter_desc) reader.add_audio_stream(frames_per_chunk or -1, -1, filter_desc=filter_desc)
return reader return reader
def apply(self, waveform: Tensor, sample_rate: int) -> Tensor: def apply(self, waveform: Tensor, sample_rate: int, output_sample_rate: Optional[int] = None) -> Tensor:
"""Apply the effect and/or codecs to the whole tensor. """Apply the effect and/or codecs to the whole tensor.
Args: Args:
waveform (Tensor): The input waveform. Shape: ``(time, channel)`` waveform (Tensor): The input waveform. Shape: ``(time, channel)``
sample_rate (int): Sample rate of the waveform. sample_rate (int): Sample rate of the input waveform.
output_sample_rate (int or None, optional): Output sample rate.
If provided, override the output sample rate.
Otherwise, the resulting tensor is resampled to have
the same sample rate as the input.
Default: ``None``.
Returns: Returns:
Tensor: Tensor:
...@@ -309,18 +315,25 @@ class AudioEffector: ...@@ -309,18 +315,25 @@ class AudioEffector:
if waveform.numel() == 0: if waveform.numel() == 0:
return waveform return waveform
reader = self._get_reader(waveform, sample_rate) reader = self._get_reader(waveform, sample_rate, output_sample_rate)
reader.process_all_packets() reader.process_all_packets()
(applied,) = reader.pop_chunks() (applied,) = reader.pop_chunks()
return Tensor(applied) return Tensor(applied)
def stream(self, waveform: Tensor, sample_rate: int, frames_per_chunk: int) -> Iterator[Tensor]: def stream(
self, waveform: Tensor, sample_rate: int, frames_per_chunk: int, output_sample_rate: Optional[int] = None
) -> Iterator[Tensor]:
"""Apply the effect and/or codecs to the given tensor chunk by chunk. """Apply the effect and/or codecs to the given tensor chunk by chunk.
Args: Args:
waveform (Tensor): The input waveform. Shape: ``(time, channel)`` waveform (Tensor): The input waveform. Shape: ``(time, channel)``
sample_rate (int): Sample rate of the waveform. sample_rate (int): Sample rate of the waveform.
frames_per_chunk (int): The number of frames to return at a time. frames_per_chunk (int): The number of frames to return at a time.
output_sample_rate (int or None, optional): Output sample rate.
If provided, override the output sample rate.
Otherwise, the resulting tensor is resampled to have
the same sample rate as the input.
Default: ``None``.
Returns: Returns:
Iterator[Tensor]: Iterator[Tensor]:
...@@ -334,6 +347,6 @@ class AudioEffector: ...@@ -334,6 +347,6 @@ class AudioEffector:
if waveform.numel() == 0: if waveform.numel() == 0:
return waveform return waveform
reader = self._get_reader(waveform, sample_rate, frames_per_chunk) reader = self._get_reader(waveform, sample_rate, output_sample_rate, frames_per_chunk)
for (applied,) in reader.stream(): for (applied,) in reader.stream():
yield Tensor(applied) yield Tensor(applied)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment