Commit d69e8857 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Fix path-like object support in FFmpeg dispatcher (#3243)

Summary:
In dispatcher mode, the FFmpeg backend does not handle path-like objects (e.g. `pathlib.Path`), and the C++ implementation raises an error.

This commit fixes it by normalizing the path-like object to a string with `os.path.normpath` before it is passed to the C++ implementation.

Pull Request resolved: https://github.com/pytorch/audio/pull/3243

Reviewed By: nateanl

Differential Revision: D44719280

Pulled By: mthrok

fbshipit-source-id: 9dae459e2a5fb4992b4ef53fe4829fe8c35b2edd
parent 5053aa7f
import io import io
import itertools import itertools
import os import os
import pathlib
import tarfile import tarfile
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
...@@ -35,6 +36,24 @@ if _mod_utils.is_module_available("requests"): ...@@ -35,6 +36,24 @@ if _mod_utils.is_module_available("requests"):
class TestInfo(TempDirMixin, PytorchTestCase): class TestInfo(TempDirMixin, PytorchTestCase):
_info = partial(get_info_func(), backend="ffmpeg") _info = partial(get_info_func(), backend="ffmpeg")
def test_pathlike(self):
    """The FFmpeg dispatcher's info function accepts a ``pathlib.Path`` source."""
    dtype = "float32"
    sample_rate, num_channels, duration = 16000, 2, 1

    # Write the reference WAV file using a plain string path.
    wav_path = self.get_temp_path("data.wav")
    reference = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
    save_wav(wav_path, reference, sample_rate)

    # Query the same file through a path-like object.
    metadata = self._info(pathlib.Path(wav_path))

    assert metadata.sample_rate == sample_rate
    assert metadata.num_frames == sample_rate * duration
    assert metadata.num_channels == num_channels
    assert metadata.bits_per_sample == sox_utils.get_bit_depth(dtype)
    assert metadata.encoding == get_encoding("wav", dtype)
@parameterized.expand( @parameterized.expand(
list( list(
itertools.product( itertools.product(
......
import io import io
import itertools import itertools
import pathlib
import tarfile import tarfile
from functools import partial from functools import partial
...@@ -125,6 +126,21 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -125,6 +126,21 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
class TestLoad(LoadTestBase): class TestLoad(LoadTestBase):
"""Test the correctness of `self._load` for various formats""" """Test the correctness of `self._load` for various formats"""
def test_pathlike(self):
    """The FFmpeg dispatcher's load function accepts a ``pathlib.Path`` source."""
    dtype = "float32"
    sample_rate, num_channels, duration = 16000, 2, 1

    # Write the reference WAV file using a plain string path.
    wav_path = self.get_temp_path("data.wav")
    expected = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
    save_wav(wav_path, expected, sample_rate)

    # Load the same file back through a path-like object.
    loaded, loaded_rate = self._load(pathlib.Path(wav_path))

    self.assertEqual(loaded_rate, sample_rate)
    self.assertEqual(loaded, expected)
@parameterized.expand( @parameterized.expand(
list( list(
itertools.product( itertools.product(
......
import io import io
import os import os
import pathlib
import subprocess import subprocess
import sys import sys
from functools import partial from functools import partial
...@@ -146,6 +147,17 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase): ...@@ -146,6 +147,17 @@ class SaveTestBase(TempDirMixin, TorchaudioTestCase):
@skipIfNoExec("ffmpeg") @skipIfNoExec("ffmpeg")
@skipIfNoFFmpeg @skipIfNoFFmpeg
class SaveTest(SaveTestBase): class SaveTest(SaveTestBase):
def test_pathlike(self):
    """The FFmpeg dispatcher's save function accepts a ``pathlib.Path`` destination."""
    dtype = "float32"
    sample_rate, num_channels, duration = 16000, 2, 1

    target = pathlib.Path(self.get_temp_path("data.wav"))
    waveform = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)

    # No assertions on the written file: the test passes when the call
    # returns without raising.
    self._save(target, waveform, sample_rate)
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
[ [
......
...@@ -82,7 +82,7 @@ class FFmpegBackend(Backend): ...@@ -82,7 +82,7 @@ class FFmpegBackend(Backend):
if hasattr(uri, "read"): if hasattr(uri, "read"):
metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size) metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size)
else: else:
metadata = info_audio(uri, format) metadata = info_audio(os.path.normpath(uri), format)
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample) metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
metadata.encoding = _map_encoding(metadata.encoding) metadata.encoding = _map_encoding(metadata.encoding)
return metadata return metadata
...@@ -108,7 +108,7 @@ class FFmpegBackend(Backend): ...@@ -108,7 +108,7 @@ class FFmpegBackend(Backend):
buffer_size, buffer_size,
) )
else: else:
return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format) return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)
@staticmethod @staticmethod
def save( def save(
...@@ -122,7 +122,7 @@ class FFmpegBackend(Backend): ...@@ -122,7 +122,7 @@ class FFmpegBackend(Backend):
buffer_size: int = 4096, buffer_size: int = 4096,
) -> None: ) -> None:
save_audio( save_audio(
uri, os.path.normpath(uri),
src, src,
sample_rate, sample_rate,
channels_first, channels_first,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment