Unverified Commit e0f4c0ec authored by moto's avatar moto Committed by GitHub
Browse files

Fix SignalInfo member name to frame (#734)

This PR fixes the wrong member name of SignalInfo introduced in #718. 

 - `num_samples` == `num_frames` * `num_channels`.
parent 7427bf56
...@@ -35,7 +35,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -35,7 +35,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
...@@ -55,7 +55,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -55,7 +55,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
...@@ -74,7 +74,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -74,7 +74,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.get_sample_rate() == sample_rate
# mp3 does not preserve the number of samples # mp3 does not preserve the number of samples
# assert info.get_num_samples() == sample_rate * duration # assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
...@@ -92,7 +92,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -92,7 +92,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
...@@ -110,5 +110,5 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -110,5 +110,5 @@ class TestInfo(TempDirMixin, PytorchTestCase):
) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels assert info.get_num_channels() == num_channels
...@@ -44,5 +44,5 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -44,5 +44,5 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
ts_info = ts_info_func(audio_path) ts_info = ts_info_func(audio_path)
assert py_info.get_sample_rate() == ts_info.get_sample_rate() assert py_info.get_sample_rate() == ts_info.get_sample_rate()
assert py_info.get_num_samples() == ts_info.get_num_samples() assert py_info.get_num_frames() == ts_info.get_num_frames()
assert py_info.get_num_channels() == ts_info.get_num_channels() assert py_info.get_num_channels() == ts_info.get_num_channels()
...@@ -12,7 +12,7 @@ static auto registerSignalInfo = ...@@ -12,7 +12,7 @@ static auto registerSignalInfo =
.def(torch::init<int64_t, int64_t, int64_t>()) .def(torch::init<int64_t, int64_t, int64_t>())
.def("get_sample_rate", &SignalInfo::getSampleRate) .def("get_sample_rate", &SignalInfo::getSampleRate)
.def("get_num_channels", &SignalInfo::getNumChannels) .def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_samples", &SignalInfo::getNumSamples); .def("get_num_frames", &SignalInfo::getNumFrames);
static auto registerGetInfo = torch::RegisterOperators().op( static auto registerGetInfo = torch::RegisterOperators().op(
torch::RegisterOperators::options() torch::RegisterOperators::options()
......
...@@ -4,10 +4,10 @@ namespace torchaudio { ...@@ -4,10 +4,10 @@ namespace torchaudio {
SignalInfo::SignalInfo( SignalInfo::SignalInfo(
const int64_t sample_rate_, const int64_t sample_rate_,
const int64_t num_channels_, const int64_t num_channels_,
const int64_t num_samples_) const int64_t num_frames_)
: sample_rate(sample_rate_), : sample_rate(sample_rate_),
num_channels(num_channels_), num_channels(num_channels_),
num_samples(num_samples_){}; num_frames(num_frames_){};
int64_t SignalInfo::getSampleRate() const { int64_t SignalInfo::getSampleRate() const {
return sample_rate; return sample_rate;
...@@ -17,7 +17,7 @@ int64_t SignalInfo::getNumChannels() const { ...@@ -17,7 +17,7 @@ int64_t SignalInfo::getNumChannels() const {
return num_channels; return num_channels;
} }
int64_t SignalInfo::getNumSamples() const { int64_t SignalInfo::getNumFrames() const {
return num_samples; return num_frames;
} }
} // namespace torchaudio } // namespace torchaudio
...@@ -7,15 +7,15 @@ namespace torchaudio { ...@@ -7,15 +7,15 @@ namespace torchaudio {
struct SignalInfo : torch::CustomClassHolder { struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate; int64_t sample_rate;
int64_t num_channels; int64_t num_channels;
int64_t num_samples; int64_t num_frames;
SignalInfo( SignalInfo(
const int64_t sample_rate_, const int64_t sample_rate_,
const int64_t num_channels_, const int64_t num_channels_,
const int64_t num_samples_); const int64_t num_frames_);
int64_t getSampleRate() const; int64_t getSampleRate() const;
int64_t getNumChannels() const; int64_t getNumChannels() const;
int64_t getNumSamples() const; int64_t getNumFrames() const;
}; };
} // namespace torchaudio } // namespace torchaudio
......
...@@ -30,10 +30,10 @@ def _init_dummy_module(): ...@@ -30,10 +30,10 @@ def _init_dummy_module():
without extension. without extension.
This class has to implement the same interface as C++ equivalent. This class has to implement the same interface as C++ equivalent.
""" """
def __init__(self, sample_rate: int, num_channels: int, num_samples: int): def __init__(self, sample_rate: int, num_channels: int, num_frames: int):
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.num_channels = num_channels self.num_channels = num_channels
self.num_samples = num_samples self.num_frames = num_frames
def get_sample_rate(self): def get_sample_rate(self):
return self.sample_rate return self.sample_rate
...@@ -41,8 +41,8 @@ def _init_dummy_module(): ...@@ -41,8 +41,8 @@ def _init_dummy_module():
def get_num_channels(self): def get_num_channels(self):
return self.num_channels return self.num_channels
def get_num_samples(self): def get_num_frames(self):
return self.num_samples return self.num_frames
DummyModule = namedtuple('torchaudio', ['SignalInfo']) DummyModule = namedtuple('torchaudio', ['SignalInfo'])
module = DummyModule(SignalInfo) module = DummyModule(SignalInfo)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment