Commit 4bc4ca75 authored by moto, committed by Facebook GitHub Bot
Browse files

Support changing the number of channels in StreamReader (#3216)

Summary:
This commit adds a `num_channels` argument to `StreamReader.add_basic_audio_stream`,
which allows one to change the number of channels of the decoded audio on the fly.

Pull Request resolved: https://github.com/pytorch/audio/pull/3216

Reviewed By: hwangjeff

Differential Revision: D44516925

Pulled By: mthrok

fbshipit-source-id: 3e5a11b3fdbb19071f712a8148e27aff60341df3
parent 09ccf7cc
......@@ -6,6 +6,7 @@ from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import (
get_asset_path,
get_image,
get_sinusoid,
get_wav_data,
is_ffmpeg_available,
nested_params,
......@@ -836,6 +837,7 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
self._test_wav(src, original, fmt=None)
def test_audio_stream_format(self):
"`format` argument properly changes the sample format of decoded audio"
num_channels = 2
src, s32 = self.get_src(8000, dtype="int32", num_channels=num_channels)
args = {
......@@ -878,6 +880,40 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
self.assertEqual(chunks[10], f64)
self.assertEqual(chunks[11], f64)
@nested_params([4000, 16000])
def test_basic_audio_stream_sample_rate(self, sr):
    """Passing ``sample_rate`` makes the output stream use that sample rate.

    The source stream info must keep reporting the original rate, while the
    output stream info and the decoded chunk must reflect the requested one.
    """
    src_num_channels = 2
    src_sr = 8000

    # Write a two-channel reference WAV at the source sample rate.
    waveform = get_sinusoid(sample_rate=src_sr, n_channels=src_num_channels, channels_first=False)
    wav_path = self.get_temp_path("ref.wav")
    save_wav(wav_path, waveform, src_sr, channels_first=False)

    reader = StreamReader(wav_path)
    reader.add_basic_audio_stream(frames_per_chunk=-1, format="flt", sample_rate=sr)

    # Source metadata is untouched; only the output stream is resampled.
    src_info = reader.get_src_stream_info(0)
    self.assertEqual(src_info.sample_rate, src_sr)
    out_info = reader.get_out_stream_info(0)
    self.assertEqual(out_info.sample_rate, sr)

    reader.process_all_packets()
    (decoded,) = reader.pop_chunks()
    # Expect `sr` frames with the channel count unchanged.
    self.assertEqual(decoded.shape, [sr, src_num_channels])
@nested_params([1, 2, 3, 8, 16])
def test_basic_audio_stream_num_channels(self, num_channels):
    """`num_channels` argument changes the number of channels of decoded audio"""
    src_num_channels, sr = 2, 8000
    # Write a two-channel reference WAV to decode from.
    data = get_sinusoid(sample_rate=sr, n_channels=src_num_channels, channels_first=False)
    path = self.get_temp_path("ref.wav")
    save_wav(path, data, sr, channels_first=False)

    s = StreamReader(path)
    s.add_basic_audio_stream(frames_per_chunk=-1, format="flt", num_channels=num_channels)
    # The source stream keeps its original channel count; only the output changes.
    self.assertEqual(s.get_src_stream_info(0).num_channels, src_num_channels)
    self.assertEqual(s.get_out_stream_info(0).num_channels, num_channels)

    s.process_all_packets()
    (chunks,) = s.pop_chunks()
    # Expect `sr` frames, each with the requested number of channels.
    self.assertEqual(chunks.shape, [sr, num_channels])
@nested_params(
["int16", "uint8", "int32"], # "float", "double", "int64"]
[1, 2, 4, 8],
......
......@@ -230,12 +230,17 @@ def _parse_oi(i):
raise ValueError(f"Unexpected media_type: {i.media_type}({i})")
def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str]):
def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str], num_channels: Optional[int]):
descs = []
if sample_rate is not None:
descs.append(f"aresample={sample_rate}")
if fmt is not None or num_channels is not None:
parts = []
if fmt is not None:
descs.append(f"aformat=sample_fmts={fmt}")
parts.append(f"sample_fmts={fmt}")
if num_channels is not None:
parts.append(f"channel_layouts={num_channels}c")
descs.append(f"aformat={':'.join(parts)}")
return ",".join(descs) if descs else None
......@@ -630,6 +635,7 @@ class StreamReader:
decoder_option: Optional[Dict[str, str]] = None,
format: Optional[str] = "fltp",
sample_rate: Optional[int] = None,
num_channels: Optional[int] = None,
):
"""Add output audio stream
......@@ -662,6 +668,8 @@ class StreamReader:
Default: ``"fltp"``.
sample_rate (int or None, optional): If provided, resample the audio.
num_channels (int or None, optional): If provided, change the number of channels.
"""
self.add_audio_stream(
frames_per_chunk,
......@@ -669,7 +677,7 @@ class StreamReader:
stream_index,
decoder,
decoder_option,
_get_afilter_desc(sample_rate, format),
_get_afilter_desc(sample_rate, format, num_channels),
)
@_format_video_args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment