Unverified Commit 47d00080 authored by Jason Hurt's avatar Jason Hurt Committed by GitHub
Browse files

Reject saving GSM when not compatible (#1384)

parent 6d81ab8b
......@@ -305,7 +305,15 @@ class SaveTest(SaveTestBase):
)
def test_save_gsm(self, test_mode):
self.assert_save_consistency(
"gsm", test_mode=test_mode)
"gsm", num_channels=1, test_mode=test_mode)
with self.assertRaises(
RuntimeError, msg="gsm format only supports single channel audio."):
self.assert_save_consistency(
"gsm", num_channels=2, test_mode=test_mode)
with self.assertRaises(
RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
self.assert_save_consistency(
"gsm", sample_rate=16000, test_mode=test_mode)
@parameterized.expand([
("wav", "PCM_S", 16),
......
......@@ -101,6 +101,13 @@ void save_audio_file(
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "htk format only supports single channel audio.");
} else if (filetype == "gsm") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "gsm format only supports single channel audio.");
TORCH_CHECK(
sample_rate == 8000,
"gsm format only supports a sampling rate of 8kHz.");
}
const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
......@@ -243,6 +250,16 @@ void save_audio_fileobj(
throw std::runtime_error(
"htk format only supports single channel audio.");
}
} else if (filetype == "gsm") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
if (num_channels != 1) {
throw std::runtime_error(
"gsm format only supports single channel audio.");
}
if (sample_rate != 8000) {
throw std::runtime_error(
"gsm format only supports a sampling rate of 8kHz.");
}
}
const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment