Commit 09ccf7cc authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Reduce io tests (#3217)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3217

This commit removes some tests for file-like object from StreamWriter test.

The rational is that testing things after the output file is opened are
same for file-like object and regular files. Things like filter-graph and
encoder format change does not affect how the encoded bynary are written.

Reviewed By: hwangjeff

Differential Revision: D44518626

fbshipit-source-id: 821ec20deca92e5e5c85bf4d47997eed51735374
parent c76fd58b
......@@ -91,10 +91,6 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
def get_dst(self, path):
return super().get_dst(self.get_temp_path(path))
def get_buf(self, path):
with open(self.get_temp_path(path), "rb") as fileobj:
return fileobj.read()
def test_unopened_error(self):
"""If dst is not opened when attempting to write data, runtime error should be raised"""
path = self.get_dst("test.mp4")
......@@ -227,6 +223,19 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
s.write_audio_chunk(0, audio)
s.write_video_chunk(1, video)
@skipIfNoFFmpeg
class StreamWriterCorrectnessTest(TempDirMixin, TorchaudioTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
torchaudio.utils.ffmpeg_utils.set_log_level(32)
@classmethod
def tearDownClass(cls):
torchaudio.utils.ffmpeg_utils.set_log_level(8)
super().tearDownClass()
@nested_params(
[
("gray8", "gray8"),
......@@ -252,16 +261,16 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
chunk = torch.randint(low=0, high=255, size=src_size, dtype=torch.uint8)
# Write data
dst = self.get_dst(filename)
dst = self.get_temp_path(filename)
s = StreamWriter(dst, format="rawvideo")
s.add_video_stream(frame_rate, width, height, format=src_fmt, encoder_format=encoder_fmt)
with s.open():
s.write_video_chunk(0, chunk)
# Fetch the written data
if self.test_fileobj:
dst.flush()
buf = self.get_buf(filename)
with open(dst, "rb") as fileobj:
buf = fileobj.read()
result = torch.frombuffer(buf, dtype=torch.uint8)
if encoder_fmt.endswith("p"):
result = result.reshape(src_size)
......@@ -286,14 +295,12 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
h, w = resolution
# Write data
dst = self.get_dst(filename)
dst = self.get_temp_path(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_video_stream(frame_rate=framerate, height=h, width=w, format=format)
chunk = torch.stack([torch.full((3, h, w), i, dtype=torch.uint8) for i in torch.linspace(0, 255, 256)])
with s.open():
s.write_video_chunk(0, chunk)
if self.test_fileobj:
dst.flush()
# Load data
s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
......@@ -329,15 +336,12 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, dtype="int16", channels_first=False)
# Write data
dst = self.get_dst(filename)
dst = self.get_temp_path(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, format="s16")
with s.open():
s.write_audio_chunk(0, data)
if self.test_fileobj:
dst.flush()
# Load data
s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
s.add_audio_stream(-1)
......@@ -364,15 +368,12 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, channels_first=False)
# Write data
dst = self.get_dst(filename)
dst = self.get_temp_path(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
with s.open():
s.write_audio_chunk(0, data)
if self.test_fileobj:
dst.flush()
# Load data
s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
s.add_audio_stream(-1)
......@@ -406,16 +407,13 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
width, height = 96, 128
# Write data
dst = self.get_dst(filename)
dst = self.get_temp_path(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
video = torch.randint(256, (90, 3, height, width), dtype=torch.uint8)
with writer.open():
writer.write_video_chunk(0, video)
if self.test_fileobj:
dst.flush()
# Load data
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
assert reader.get_src_stream_info(0).frame_rate == frame_rate
......@@ -430,7 +428,7 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
width, height = 96, 128
# Write data
dst = self.get_dst(filename)
dst = self.get_temp_path(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
......@@ -438,9 +436,6 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
with writer.open():
writer.write_video_chunk(0, video)
if self.test_fileobj:
dst.flush()
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_video_stream(1)
pts = [chunk.pts for (chunk,) in reader.stream()]
......@@ -459,7 +454,7 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
num_channels = 2
# Write data
dst = self.get_dst(filename)
dst = self.get_temp_path(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
......@@ -468,9 +463,6 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
with writer.open():
writer.write_audio_chunk(0, audio)
if self.test_fileobj:
dst.flush()
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
frames_per_chunk = sample_rate // 4
reader.add_audio_stream(frames_per_chunk, -1)
......@@ -509,7 +501,7 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
width, height = 8, 8
# Write data
dst = self.get_dst(filename)
dst = self.get_temp_path(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
......@@ -521,10 +513,6 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
reference_pts.append(pts)
writer.write_video_chunk(0, video, pts)
# check
if self.test_fileobj:
dst.flush()
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_video_stream(1)
pts = [chunk.pts for (chunk,) in reader.stream()]
......@@ -544,7 +532,7 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
num_channels = 2
# Write data
dst = self.get_dst(filename)
dst = self.get_temp_path(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
codec_config = torchaudio.io.StreamWriter.CodecConfig(bit_rate=198_000, compression_level=3)
writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, codec_config=codec_config)
......@@ -590,17 +578,13 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
original = get_audio_chunk("s16", num_channels=num_channels, sample_rate=sample_rate)
dst = self.get_dst(filename)
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format=ext)
w.add_audio_stream(sample_rate=8000, num_channels=num_channels, filter_desc="areverse", format="s16")
with w.open():
w.write_audio_chunk(0, original)
# check
if self.test_fileobj:
dst.flush()
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_audio_stream(-1)
reader.process_all_packets()
......@@ -617,17 +601,13 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
original = torch.zeros((num_frames, 3, height, width), dtype=torch.uint8)
dst = self.get_dst(filename)
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format=ext)
w.add_video_stream(frame_rate=rate, format="rgb24", height=height, width=width, filter_desc="framestep=2")
with w.open():
w.write_video_chunk(0, original)
# check
if self.test_fileobj:
dst.flush()
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_video_stream(-1)
reader.process_all_packets()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment