Commit 502d5811 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Raise an error is StreamWriter is not opened (#3152)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3152

In StreamWriter, if the destination is not opened when attempting to write data, it causes segmentation fault.
This commit adds guard so that instead of segfault, it will error-out.

Reviewed By: nateanl

Differential Revision: D43852649

fbshipit-source-id: aef5db7c1508f8a7db5834c2ab6de3cad09f9d60
parent cea12eaf
...@@ -92,6 +92,22 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -92,6 +92,22 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
with open(self.get_temp_path(path), "rb") as fileobj: with open(self.get_temp_path(path), "rb") as fileobj:
return fileobj.read() return fileobj.read()
def test_unopened_error(self):
"""If dst is not opened when attempting to write data, runtime error should be raised"""
path = self.get_dst("test.mp4")
s = StreamWriter(path, format="mp4")
s.set_metadata(metadata={"artist": "torchaudio", "title": self.id()})
s.add_audio_stream(sample_rate=16000, num_channels=2)
s.add_video_stream(frame_rate=30, width=16, height=16)
dummy = torch.zeros((3, 2))
with self.assertRaises(RuntimeError):
s.write_audio_chunk(0, dummy)
dummy = torch.zeros((3, 3, 16, 16))
with self.assertRaises(RuntimeError):
s.write_video_chunk(1, dummy)
@skipIfNoModule("tinytag") @skipIfNoModule("tinytag")
def test_metadata_overwrite(self): def test_metadata_overwrite(self):
"""When set_metadata is called multiple times, only entries from the last call are saved""" """When set_metadata is called multiple times, only entries from the last call are saved"""
......
...@@ -180,6 +180,7 @@ void StreamWriter::open(const c10::optional<OptionDict>& option) { ...@@ -180,6 +180,7 @@ void StreamWriter::open(const c10::optional<OptionDict>& option) {
" (", " (",
av_err2string(ret), av_err2string(ret),
")"); ")");
is_open = true;
} }
void StreamWriter::close() { void StreamWriter::close() {
...@@ -196,9 +197,11 @@ void StreamWriter::close() { ...@@ -196,9 +197,11 @@ void StreamWriter::close() {
// avio_closep can be only applied to AVIOContext opened by avio_open // avio_closep can be only applied to AVIOContext opened by avio_open
avio_closep(&(pFormatContext->pb)); avio_closep(&(pFormatContext->pb));
} }
is_open = false;
} }
void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) { void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK( TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()), 0 <= i && i < static_cast<int>(processes.size()),
"Invalid stream index. Index must be in range of [0, ", "Invalid stream index. Index must be in range of [0, ",
...@@ -209,6 +212,7 @@ void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) { ...@@ -209,6 +212,7 @@ void StreamWriter::write_audio_chunk(int i, const torch::Tensor& waveform) {
} }
void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) { void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK( TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()), 0 <= i && i < static_cast<int>(processes.size()),
"Invalid stream index. Index must be in range of [0, ", "Invalid stream index. Index must be in range of [0, ",
...@@ -219,6 +223,7 @@ void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) { ...@@ -219,6 +223,7 @@ void StreamWriter::write_video_chunk(int i, const torch::Tensor& frames) {
} }
void StreamWriter::flush() { void StreamWriter::flush() {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
for (auto& p : processes) { for (auto& p : processes) {
p.flush(); p.flush();
} }
......
...@@ -16,6 +16,7 @@ class StreamWriter { ...@@ -16,6 +16,7 @@ class StreamWriter {
AVBufferRefPtr pHWBufferRef; AVBufferRefPtr pHWBufferRef;
std::vector<EncodeProcess> processes; std::vector<EncodeProcess> processes;
AVPacketPtr pkt; AVPacketPtr pkt;
bool is_open = false;
protected: protected:
/// @cond /// @cond
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment