Commit 7729723b authored by hwangjeff, committed by Facebook GitHub Bot
Browse files

Modify `info_audio` to compute and return number of frames if not found in stream info (#2740)

Summary:
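Modifies `info_audio` to compute and return the number of frames when it is not found in the stream info (i.e. the reported value is 0); in that case the audio is decoded and the frame count is taken from the loaded waveform. This resolves the `num_frames == 0` issue for MP3 reported in https://github.com/pytorch/audio/issues/2524.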

Pull Request resolved: https://github.com/pytorch/audio/pull/2740

Reviewed By: nateanl

Differential Revision: D40168639

Pulled By: nateanl

fbshipit-source-id: bb45baa0f9cd56844315b04e40ab9835d825fc24
parent 1a18c41d
...@@ -324,15 +324,15 @@ class TestLoadWithoutExtension(PytorchTestCase): ...@@ -324,15 +324,15 @@ class TestLoadWithoutExtension(PytorchTestCase):
path = get_asset_path("mp3_without_ext") path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path) sinfo = sox_io_backend.info(path)
assert sinfo.sample_rate == 16000 assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 0 assert sinfo.num_frames == 80000
assert sinfo.num_channels == 1 assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert sinfo.encoding == "MP3" assert sinfo.encoding == "MP3"
with open(path, "rb") as fileobj: with open(path, "rb") as fileobj:
sinfo = sox_io_backend.info(fileobj) sinfo = sox_io_backend.info(fileobj, format="mp3")
assert sinfo.sample_rate == 16000 assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 0 assert sinfo.num_frames == 80000
assert sinfo.num_channels == 1 assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0 assert sinfo.bits_per_sample == 0
assert sinfo.encoding == "MP3" assert sinfo.encoding == "MP3"
...@@ -427,7 +427,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -427,7 +427,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames) sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
...@@ -457,7 +457,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -457,7 +457,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
with self._set_buffer_size(16384): with self._set_buffer_size(16384):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments) sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames num_frames = 0 if ext in ["vorbis"] else num_frames
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
...@@ -485,7 +485,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -485,7 +485,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames) sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
...@@ -513,7 +513,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -513,7 +513,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames) sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames num_frames = {"vorbis": 0, "mp3": 1728}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
...@@ -541,7 +541,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -541,7 +541,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames) sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
...@@ -583,7 +583,7 @@ class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase): ...@@ -583,7 +583,7 @@ class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames) sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype) bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ["mp3", "vorbis"] else num_frames num_frames = {"vorbis": 0, "mp3": 49536}.get(ext, num_frames)
assert sinfo.sample_rate == sample_rate assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels assert sinfo.num_channels == num_channels
......
...@@ -11,9 +11,14 @@ def _info_audio( ...@@ -11,9 +11,14 @@ def _info_audio(
): ):
i = s.find_best_audio_stream() i = s.find_best_audio_stream()
sinfo = s.get_src_stream_info(i) sinfo = s.get_src_stream_info(i)
if sinfo[5] == 0:
waveform, _ = _load_audio(s)
num_frames = waveform.size(1)
else:
num_frames = sinfo[5]
return AudioMetaData( return AudioMetaData(
int(sinfo[8]), int(sinfo[8]),
sinfo[5], num_frames,
sinfo[9], sinfo[9],
sinfo[6], sinfo[6],
sinfo[1].upper(), sinfo[1].upper(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment