Commit 7729723b authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Modify `info_audio` to compute and return number of frames if not found in stream info (#2740)

Summary:
Modifies `info_audio` to compute and return number of frames if not found in stream info. This resolves the `num_frames == 0` issue for mp3 that's cited in https://github.com/pytorch/audio/issues/2524.

Pull Request resolved: https://github.com/pytorch/audio/pull/2740

Reviewed By: nateanl

Differential Revision: D40168639

Pulled By: nateanl

fbshipit-source-id: bb45baa0f9cd56844315b04e40ab9835d825fc24
parent 1a18c41d
......@@ -324,15 +324,15 @@ class TestLoadWithoutExtension(PytorchTestCase):
path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path)
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 0
assert sinfo.num_frames == 80000
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert sinfo.encoding == "MP3"
with open(path, "rb") as fileobj:
sinfo = sox_io_backend.info(fileobj)
sinfo = sox_io_backend.info(fileobj, format="mp3")
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 0
assert sinfo.num_frames == 80000
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0
assert sinfo.encoding == "MP3"
......@@ -427,7 +427,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......@@ -457,7 +457,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
with self._set_buffer_size(16384):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
num_frames = 0 if ext in ["vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......@@ -485,7 +485,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......@@ -513,7 +513,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
num_frames = {"vorbis": 0, "mp3": 1728}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......@@ -541,7 +541,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......@@ -583,7 +583,7 @@ class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames
num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
......
......@@ -11,9 +11,14 @@ def _info_audio(
):
i = s.find_best_audio_stream()
sinfo = s.get_src_stream_info(i)
if sinfo[5] == 0:
waveform, _ = _load_audio(s)
num_frames = waveform.size(1)
else:
num_frames = sinfo[5]
return AudioMetaData(
int(sinfo[8]),
sinfo[5],
num_frames,
sinfo[9],
sinfo[6],
sinfo[1].upper(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment