Commit 4bc4ca75 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Support changing the number of channels in StreamReader (#3216)

Summary:
This commit adds a `num_channels` argument to `StreamReader.add_basic_audio_stream`,
which allows one to change the number of channels of the decoded audio on the fly.

Pull Request resolved: https://github.com/pytorch/audio/pull/3216

Reviewed By: hwangjeff

Differential Revision: D44516925

Pulled By: mthrok

fbshipit-source-id: 3e5a11b3fdbb19071f712a8148e27aff60341df3
parent 09ccf7cc
...@@ -6,6 +6,7 @@ from parameterized import parameterized, parameterized_class ...@@ -6,6 +6,7 @@ from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
get_asset_path, get_asset_path,
get_image, get_image,
get_sinusoid,
get_wav_data, get_wav_data,
is_ffmpeg_available, is_ffmpeg_available,
nested_params, nested_params,
...@@ -836,6 +837,7 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase) ...@@ -836,6 +837,7 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
self._test_wav(src, original, fmt=None) self._test_wav(src, original, fmt=None)
def test_audio_stream_format(self): def test_audio_stream_format(self):
"`format` argument properly changes the sample format of decoded audio"
num_channels = 2 num_channels = 2
src, s32 = self.get_src(8000, dtype="int32", num_channels=num_channels) src, s32 = self.get_src(8000, dtype="int32", num_channels=num_channels)
args = { args = {
...@@ -878,6 +880,40 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase) ...@@ -878,6 +880,40 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
self.assertEqual(chunks[10], f64) self.assertEqual(chunks[10], f64)
self.assertEqual(chunks[11], f64) self.assertEqual(chunks[11], f64)
@nested_params([4000, 16000])
def test_basic_audio_stream_sample_rate(self, sr):
    """Check that the ``sample_rate`` argument resamples the decoded audio.

    A 2-channel sinusoid is written to a temporary WAV file at 8000 Hz and
    read back through ``add_basic_audio_stream(..., sample_rate=sr)``.
    The source stream must still report the original rate, the output
    stream must report ``sr``, and the single decoded chunk must have
    shape ``[sr, 2]``.
    """
    src_sr = 8000
    src_num_channels = 2

    # Build the reference file (time-major layout: frames x channels).
    waveform = get_sinusoid(
        sample_rate=src_sr,
        n_channels=src_num_channels,
        channels_first=False,
    )
    wav_path = self.get_temp_path("ref.wav")
    save_wav(wav_path, waveform, src_sr, channels_first=False)

    # frames_per_chunk=-1 is used so that the whole stream is popped as one chunk.
    reader = StreamReader(wav_path)
    reader.add_basic_audio_stream(frames_per_chunk=-1, format="flt", sample_rate=sr)

    src_info = reader.get_src_stream_info(0)
    self.assertEqual(src_info.sample_rate, src_sr)
    out_info = reader.get_out_stream_info(0)
    self.assertEqual(out_info.sample_rate, sr)

    reader.process_all_packets()
    [chunk] = reader.pop_chunks()

    # NOTE(review): frame count == sr presumably relies on get_sinusoid
    # producing one second of audio by default -- confirm against the helper.
    expected_shape = [sr, src_num_channels]
    self.assertEqual(chunk.shape, expected_shape)
@nested_params([1, 2, 3, 8, 16])
def test_basic_audio_stream_num_channels(self, num_channels):
    """`num_channels` argument changes the number of channels of decoded audio"""
    src_num_channels, sr = 2, 8000
    # Reference file: 2-channel sinusoid stored time-major (frames x channels).
    data = get_sinusoid(sample_rate=sr, n_channels=src_num_channels, channels_first=False)
    path = self.get_temp_path("ref.wav")
    save_wav(path, data, sr, channels_first=False)

    s = StreamReader(path)
    s.add_basic_audio_stream(frames_per_chunk=-1, format="flt", num_channels=num_channels)
    # The source stream keeps its original channel count; only the output changes.
    self.assertEqual(s.get_src_stream_info(0).num_channels, src_num_channels)
    self.assertEqual(s.get_out_stream_info(0).num_channels, num_channels)

    s.process_all_packets()
    (chunks,) = s.pop_chunks()
    # NOTE(review): frame count == sr presumably relies on get_sinusoid
    # producing one second of audio by default -- confirm against the helper.
    self.assertEqual(chunks.shape, [sr, num_channels])
@nested_params( @nested_params(
["int16", "uint8", "int32"], # "float", "double", "int64"] ["int16", "uint8", "int32"], # "float", "double", "int64"]
[1, 2, 4, 8], [1, 2, 4, 8],
......
...@@ -230,12 +230,17 @@ def _parse_oi(i): ...@@ -230,12 +230,17 @@ def _parse_oi(i):
raise ValueError(f"Unexpected media_type: {i.media_type}({i})") raise ValueError(f"Unexpected media_type: {i.media_type}({i})")
def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str]): def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str], num_channels: Optional[int]):
descs = [] descs = []
if sample_rate is not None: if sample_rate is not None:
descs.append(f"aresample={sample_rate}") descs.append(f"aresample={sample_rate}")
if fmt is not None: if fmt is not None or num_channels is not None:
descs.append(f"aformat=sample_fmts={fmt}") parts = []
if fmt is not None:
parts.append(f"sample_fmts={fmt}")
if num_channels is not None:
parts.append(f"channel_layouts={num_channels}c")
descs.append(f"aformat={':'.join(parts)}")
return ",".join(descs) if descs else None return ",".join(descs) if descs else None
...@@ -630,6 +635,7 @@ class StreamReader: ...@@ -630,6 +635,7 @@ class StreamReader:
decoder_option: Optional[Dict[str, str]] = None, decoder_option: Optional[Dict[str, str]] = None,
format: Optional[str] = "fltp", format: Optional[str] = "fltp",
sample_rate: Optional[int] = None, sample_rate: Optional[int] = None,
num_channels: Optional[int] = None,
): ):
"""Add output audio stream """Add output audio stream
...@@ -662,6 +668,8 @@ class StreamReader: ...@@ -662,6 +668,8 @@ class StreamReader:
Default: ``"fltp"``. Default: ``"fltp"``.
sample_rate (int or None, optional): If provided, resample the audio. sample_rate (int or None, optional): If provided, resample the audio.
num_channels (int or None, optional): If provided, change the number of channels.
""" """
self.add_audio_stream( self.add_audio_stream(
frames_per_chunk, frames_per_chunk,
...@@ -669,7 +677,7 @@ class StreamReader: ...@@ -669,7 +677,7 @@ class StreamReader:
stream_index, stream_index,
decoder, decoder,
decoder_option, decoder_option,
_get_afilter_desc(sample_rate, format), _get_afilter_desc(sample_rate, format, num_channels),
) )
@_format_video_args @_format_video_args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment