Unverified Commit e0f4c0ec authored by moto's avatar moto Committed by GitHub
Browse files

Fix SignalInfo member name to frame (#734)

This PR fixes the wrong member name of SignalInfo introduced in #718. 

 - `num_samples` == `num_frames` * `num_channels`.
parent 7427bf56
......@@ -35,7 +35,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product(
......@@ -55,7 +55,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product(
......@@ -74,7 +74,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
# mp3 does not preserve the number of samples
# assert info.get_num_samples() == sample_rate * duration
# assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product(
......@@ -92,7 +92,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product(
......@@ -110,5 +110,5 @@ class TestInfo(TempDirMixin, PytorchTestCase):
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
......@@ -44,5 +44,5 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
ts_info = ts_info_func(audio_path)
assert py_info.get_sample_rate() == ts_info.get_sample_rate()
assert py_info.get_num_samples() == ts_info.get_num_samples()
assert py_info.get_num_frames() == ts_info.get_num_frames()
assert py_info.get_num_channels() == ts_info.get_num_channels()
......@@ -12,7 +12,7 @@ static auto registerSignalInfo =
.def(torch::init<int64_t, int64_t, int64_t>())
.def("get_sample_rate", &SignalInfo::getSampleRate)
.def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_samples", &SignalInfo::getNumSamples);
.def("get_num_frames", &SignalInfo::getNumFrames);
static auto registerGetInfo = torch::RegisterOperators().op(
torch::RegisterOperators::options()
......
......@@ -4,10 +4,10 @@ namespace torchaudio {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_)
const int64_t num_frames_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_samples(num_samples_){};
num_frames(num_frames_){};
int64_t SignalInfo::getSampleRate() const {
return sample_rate;
......@@ -17,7 +17,7 @@ int64_t SignalInfo::getNumChannels() const {
return num_channels;
}
int64_t SignalInfo::getNumSamples() const {
return num_samples;
int64_t SignalInfo::getNumFrames() const {
return num_frames;
}
} // namespace torchaudio
......@@ -7,15 +7,15 @@ namespace torchaudio {
struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_samples;
int64_t num_frames;
SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_);
const int64_t num_frames_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumSamples() const;
int64_t getNumFrames() const;
};
} // namespace torchaudio
......
......@@ -30,10 +30,10 @@ def _init_dummy_module():
without extension.
This class has to implement the same interface as C++ equivalent.
"""
def __init__(self, sample_rate: int, num_channels: int, num_samples: int):
def __init__(self, sample_rate: int, num_channels: int, num_frames: int):
self.sample_rate = sample_rate
self.num_channels = num_channels
self.num_samples = num_samples
self.num_frames = num_frames
def get_sample_rate(self):
return self.sample_rate
......@@ -41,8 +41,8 @@ def _init_dummy_module():
def get_num_channels(self):
return self.num_channels
def get_num_samples(self):
return self.num_samples
def get_num_frames(self):
return self.num_frames
DummyModule = namedtuple('torchaudio', ['SignalInfo'])
module = DummyModule(SignalInfo)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment