Commit 502d5811 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Raise an error is StreamWriter is not opened (#3152)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3152

In StreamWriter, if the destination is not opened when attempting to write data, it causes segmentation fault.
This commit adds guard so that instead of segfault, it will error-out.

Reviewed By: nateanl

Differential Revision: D43852649

fbshipit-source-id: aef5db7c1508f8a7db5834c2ab6de3cad09f9d60
parent cea12eaf
......@@ -92,6 +92,22 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
with open(self.get_temp_path(path), "rb") as fileobj:
return fileobj.read()
def test_unopened_error(self):
"""If dst is not opened when attempting to write data, runtime error should be raised"""
path = self.get_dst("test.mp4")
s = StreamWriter(path, format="mp4")
s.set_metadata(metadata={"artist": "torchaudio", "title": self.id()})
s.add_audio_stream(sample_rate=16000, num_channels=2)
s.add_video_stream(frame_rate=30, width=16, height=16)
dummy = torch.zeros((3, 2))
with self.assertRaises(RuntimeError):
s.write_audio_chunk(0, dummy)
dummy = torch.zeros((3, 3, 16, 16))
with self.assertRaises(RuntimeError):
s.write_video_chunk(1, dummy)
@skipIfNoModule("tinytag")
def test_metadata_overwrite(self):
"""When set_metadata is called multiple times, only entries from the last call are saved"""
......
......@@ -180,6 +180,7 @@ void StreamWriter::open(const c10::optional<OptionDict>& option) {
" (",
av_err2string(ret),
")");
is_open = true;
}
void StreamWriter::close() {
......@@ -196,9 +197,11 @@ void StreamWriter::close() {
// avio_closep can be only applied to AVIOContext opened by avio_open
avio_closep(&(pFormatContext->pb));
}
is_open = false;
}
void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
"Invalid stream index. Index must be in range of [0, ",
......@@ -209,6 +212,7 @@ void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
}
void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
"Invalid stream index. Index must be in range of [0, ",
......@@ -219,6 +223,7 @@ void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
}
void StreamWriter::flush() {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
for (auto& p : processes) {
p.flush();
}
......
......@@ -16,6 +16,7 @@ class StreamWriter {
AVBufferRefPtr pHWBufferRef;
std::vector<EncodeProcess> processes;
AVPacketPtr pkt;
bool is_open = false;
protected:
/// @cond
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment