You need to sign in or sign up before continuing.
Commit 5bbbb1d5 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

[BC-Breaking] Remove compute_kaldi_pitch (#3368)

Summary:
This commit removes compute_kaldi_pitch function and the underlying Kaldi integration from torchaudio.

Kaldi pitch function was added in a short period of time by integrating the original Kaldi implementation, instead of reimplementing it in PyTorch.

The Kaldi integration employed a hack which replaces the base vector/matrix implementation of Kaldi with PyTorch Tensor so that there is only one blas library within torchaudio.

Recently, we have been making torchaudio leaner, and we have not seen wide adoption of the kaldi_pitch feature, so we decided to remove it.

See some of the discussion https://github.com/pytorch/audio/issues/1269

Pull Request resolved: https://github.com/pytorch/audio/pull/3368

Differential Revision: D46406176

Pulled By: mthrok

fbshipit-source-id: ee5e24d825188f379979ddccd680c7323b119b1e
parent 2ba36b47
...@@ -38,7 +38,7 @@ fi ...@@ -38,7 +38,7 @@ fi
printf "\x1b[34mRunning clang-format:\x1b[0m\n" printf "\x1b[34mRunning clang-format:\x1b[0m\n"
"${this_dir}"/run_clang_format.py \ "${this_dir}"/run_clang_format.py \
-r torchaudio/csrc third_party/kaldi/src \ -r torchaudio/csrc \
--clang-format-executable "${clangformat_path}" \ --clang-format-executable "${clangformat_path}" \
&& git diff --exit-code && git diff --exit-code
status=$? status=$?
......
[submodule "kaldi"]
path = third_party/kaldi/submodule
url = https://github.com/kaldi-asr/kaldi
ignore = dirty
...@@ -53,7 +53,6 @@ endif() ...@@ -53,7 +53,6 @@ endif()
# Options # Options
option(BUILD_SOX "Build libsox statically" ON) option(BUILD_SOX "Build libsox statically" ON)
option(BUILD_KALDI "Build kaldi statically" ON)
option(BUILD_RIR "Enable RIR simulation" ON) option(BUILD_RIR "Enable RIR simulation" ON)
option(BUILD_RNNT "Enable RNN transducer" ON) option(BUILD_RNNT "Enable RNN transducer" ON)
option(BUILD_ALIGN "Enable forced alignment" ON) option(BUILD_ALIGN "Enable forced alignment" ON)
......
...@@ -406,54 +406,3 @@ def plot_pitch(waveform, sr, pitch): ...@@ -406,54 +406,3 @@ def plot_pitch(waveform, sr, pitch):
plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch) plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch)
######################################################################
# Kaldi Pitch (beta)
# ------------------
#
# Kaldi Pitch feature [1] is a pitch detection mechanism tuned for automatic
# speech recognition (ASR) applications. This is a beta feature in ``torchaudio``,
# and it is available as :py:func:`torchaudio.functional.compute_kaldi_pitch`.
#
# 1. A pitch extraction algorithm tuned for automatic speech recognition
#
# Ghahremani, B. BabaAli, D. Povey, K. Riedhammer, J. Trmal and S.
# Khudanpur
#
# 2014 IEEE International Conference on Acoustics, Speech and Signal
# Processing (ICASSP), Florence, 2014, pp. 2494-2498, doi:
# 10.1109/ICASSP.2014.6854049.
# [`abstract <https://ieeexplore.ieee.org/document/6854049>`__],
# [`paper <https://danielpovey.com/files/2014_icassp_pitch.pdf>`__]
#
pitch_feature = F.compute_kaldi_pitch(SPEECH_WAVEFORM, SAMPLE_RATE)
pitch, nfcc = pitch_feature[..., 0], pitch_feature[..., 1]
######################################################################
#
def plot_kaldi_pitch(waveform, sr, pitch, nfcc):
    """Overlay the waveform with Kaldi pitch and NFCC curves on twin y-axes."""
    duration = waveform.shape[1] / sr

    def _time(n):
        # Evenly spaced time stamps spanning the whole clip for n points.
        return torch.linspace(0, duration, n)

    _, ax = plt.subplots(1, 1)
    ax.set_title("Kaldi Pitch Feature")
    ax.grid(True)
    ax.plot(_time(waveform.shape[1]), waveform[0], linewidth=1, color="gray", alpha=0.3)
    pitch_line = ax.plot(_time(pitch.shape[1]), pitch[0], linewidth=2, label="Pitch", color="green")
    ax.set_ylim((-1.3, 1.3))
    ax_right = ax.twinx()
    nfcc_line = ax_right.plot(
        _time(nfcc.shape[1]), nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--"
    )
    lines = pitch_line + nfcc_line
    ax.legend(lines, [line.get_label() for line in lines], loc=0)
    plt.show(block=False)
plot_kaldi_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch, nfcc)
...@@ -124,7 +124,8 @@ def _fetch_archives(src): ...@@ -124,7 +124,8 @@ def _fetch_archives(src):
def _fetch_third_party_libraries(): def _fetch_third_party_libraries():
_init_submodule() # Revert this when a submodule is added again
# _init_submodule()
if os.name != "nt": if os.name != "nt":
_fetch_archives(_parse_sources()) _fetch_archives(_parse_sources())
......
...@@ -13,7 +13,6 @@ from .case_utils import ( ...@@ -13,7 +13,6 @@ from .case_utils import (
skipIfNoExec, skipIfNoExec,
skipIfNoFFmpeg, skipIfNoFFmpeg,
skipIfNoHWAccel, skipIfNoHWAccel,
skipIfNoKaldi,
skipIfNoMacOS, skipIfNoMacOS,
skipIfNoModule, skipIfNoModule,
skipIfNoQengine, skipIfNoQengine,
...@@ -52,7 +51,6 @@ __all__ = [ ...@@ -52,7 +51,6 @@ __all__ = [
"skipIfNoExec", "skipIfNoExec",
"skipIfNoMacOS", "skipIfNoMacOS",
"skipIfNoModule", "skipIfNoModule",
"skipIfNoKaldi",
"skipIfNoRIR", "skipIfNoRIR",
"skipIfNoSox", "skipIfNoSox",
"skipIfNoSoxBackend", "skipIfNoSoxBackend",
......
...@@ -234,11 +234,6 @@ skipIfNoSox = _skipIf( ...@@ -234,11 +234,6 @@ skipIfNoSox = _skipIf(
reason="Sox features are not available.", reason="Sox features are not available.",
key="NO_SOX", key="NO_SOX",
) )
skipIfNoKaldi = _skipIf(
not torchaudio._extension._IS_KALDI_AVAILABLE,
reason="Kaldi features are not available.",
key="NO_KALDI",
)
skipIfNoRIR = _skipIf( skipIfNoRIR = _skipIf(
not torchaudio._extension._IS_RIR_AVAILABLE, not torchaudio._extension._IS_RIR_AVAILABLE,
reason="RIR features are not available.", reason="RIR features are not available.",
......
...@@ -257,18 +257,6 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -257,18 +257,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
atol=1e-7, atol=1e-7,
) )
@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
    """compute_kaldi_pitch gives consistent results batched vs. unbatched."""
    sample_rate, n_channels = 44100, 2
    noise = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=self.batch_size * n_channels)
    batched = noise.view(self.batch_size, n_channels, noise.size(-1))
    pitch_fn = partial(F.compute_kaldi_pitch, sample_rate=sample_rate)
    self.assert_batch_consistency(pitch_fn, inputs=(batched,))
def test_lfilter(self): def test_lfilter(self):
signal_length = 2048 signal_length = 2048
x = torch.randn(self.batch_size, signal_length) x = torch.randn(self.batch_size, signal_length)
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .kaldi_compatibility_test_impl import Kaldi, KaldiCPUOnly from .kaldi_compatibility_test_impl import Kaldi
class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase):
dtype = torch.float32
device = torch.device("cpu")
class TestKaldiFloat32(Kaldi, PytorchTestCase): class TestKaldiFloat32(Kaldi, PytorchTestCase):
......
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
from parameterized import parameterized from torchaudio_unittest.common_utils import skipIfNoExec, TempDirMixin, TestBaseMixin
from torchaudio_unittest.common_utils import (
get_sinusoid,
load_params,
save_wav,
skipIfNoExec,
TempDirMixin,
TestBaseMixin,
)
from torchaudio_unittest.common_utils.kaldi_utils import convert_args, run_kaldi from torchaudio_unittest.common_utils.kaldi_utils import convert_args, run_kaldi
...@@ -32,25 +24,3 @@ class Kaldi(TempDirMixin, TestBaseMixin): ...@@ -32,25 +24,3 @@ class Kaldi(TempDirMixin, TestBaseMixin):
command = ["apply-cmvn-sliding"] + convert_args(**kwargs) + ["ark:-", "ark:-"] command = ["apply-cmvn-sliding"] + convert_args(**kwargs) + ["ark:-", "ark:-"]
kaldi_result = run_kaldi(command, "ark", tensor) kaldi_result = run_kaldi(command, "ark", tensor)
self.assert_equal(result, expected=kaldi_result) self.assert_equal(result, expected=kaldi_result)
class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
    """Compare torchaudio's pitch feature against Kaldi's CLI tool (CPU only)."""

    def assert_equal(self, output, *, expected, rtol=None, atol=None):
        # Move the reference to the configured dtype/device before comparing.
        expected = expected.to(dtype=self.dtype, device=self.device)
        self.assertEqual(output, expected, rtol=rtol, atol=atol)

    @parameterized.expand(load_params("kaldi_test_pitch_args.jsonl"))
    @skipIfNoExec("compute-kaldi-pitch-feats")
    def test_pitch_feats(self, kwargs):
        """compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
        rate = kwargs["sample_rate"]
        result = F.compute_kaldi_pitch(get_sinusoid(dtype="float32", sample_rate=rate)[0], **kwargs)
        # Kaldi's CLI reads a 16-bit wav file, so write an int16 copy to disk.
        wave_file = self.get_temp_path("test.wav")
        save_wav(wave_file, get_sinusoid(dtype="int16", sample_rate=rate), rate)
        command = ["compute-kaldi-pitch-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
        kaldi_result = run_kaldi(command, "scp", wave_file)
        self.assert_equal(result, expected=kaldi_result)
...@@ -585,18 +585,6 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -585,18 +585,6 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor = common_utils.get_whitenoise(sample_rate=44100) tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, (tensor,)) self._assert_consistency(func, (tensor,))
@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
    """TorchScript and eager execution agree for compute_kaldi_pitch."""
    supported = self.dtype == torch.float32 and self.device == torch.device("cpu")
    if not supported:
        raise unittest.SkipTest("Only float32, cpu is supported.")

    def func(tensor):
        sample_rate: float = 44100.0
        return F.compute_kaldi_pitch(tensor, sample_rate)

    self._assert_consistency(func, (common_utils.get_whitenoise(sample_rate=44100),))
def test_resample_sinc(self): def test_resample_sinc(self):
def func(tensor): def func(tensor):
sr1, sr2 = 16000, 8000 sr1, sr2 = 16000, 8000
......
...@@ -9,10 +9,3 @@ file(MAKE_DIRECTORY install/lib) ...@@ -9,10 +9,3 @@ file(MAKE_DIRECTORY install/lib)
if (BUILD_SOX) if (BUILD_SOX)
add_subdirectory(sox) add_subdirectory(sox)
endif() endif()
################################################################################
# kaldi
################################################################################
if (BUILD_KALDI)
add_subdirectory(kaldi)
endif()
# Build a minimal static "kaldi" library containing only the sources that
# torchaudio's pitch feature needs, on top of the custom PyTorch-backed
# vector/matrix implementation in src/.
set(KALDI_REPO ${CMAKE_CURRENT_SOURCE_DIR}/submodule)
# version.h is generated by get_version.sh; its absence means the submodule
# is pristine, so patch it and generate the version header once.
if (NOT EXISTS ${KALDI_REPO}/src/base/version.h)
# Apply custom patch
# Discard any stale local edits first so the patch applies cleanly.
execute_process(
WORKING_DIRECTORY ${KALDI_REPO}
COMMAND "git" "checkout" "."
)
# kaldi.patch comments out the OpenFST dependency and unused matrix/util
# headers so the handful of sources below build stand-alone.
execute_process(
WORKING_DIRECTORY ${KALDI_REPO}
COMMAND git apply ../kaldi.patch
)
# Update the version string
execute_process(
WORKING_DIRECTORY ${KALDI_REPO}/src/base
COMMAND bash get_version.sh
)
endif()
# src/ holds the torch-backed kaldi-vector/kaldi-matrix replacements;
# everything else is compiled unmodified from the submodule.
set(KALDI_SOURCES
src/matrix/kaldi-vector.cc
src/matrix/kaldi-matrix.cc
submodule/src/base/kaldi-error.cc
submodule/src/base/kaldi-math.cc
submodule/src/feat/feature-functions.cc
submodule/src/feat/pitch-functions.cc
submodule/src/feat/resample.cc
)
add_library(kaldi STATIC ${KALDI_SOURCES})
# src comes first so the replacement headers shadow the submodule's originals.
target_include_directories(kaldi PUBLIC src submodule/src)
target_include_directories(kaldi PRIVATE ${TORCH_INCLUDE_DIRS})
# Custom Kaldi build
This directory contains original Kaldi repository (as submodule), [the custom implementation of Kaldi's vector/matrix](./src) and the build script.
We use the custom build process so that the resulting library only contains what torchaudio needs.
We use the custom vector/matrix implementation so that we can use the same BLAS library that PyTorch is compiled with, and so that we can (hopefully, in the future) take advantage of other PyTorch features (such as differentiability and GPU support). The downside of this approach is that it adds a lot of overhead compared to the original Kaldi (operator dispatch and element-wise processing, which PyTorch is not efficient at). We can improve this gradually, and if you are interested in helping, please let us know by opening an issue.
diff --git a/src/base/kaldi-types.h b/src/base/kaldi-types.h
index 7ebf4f853..c15b288b2 100644
--- a/src/base/kaldi-types.h
+++ b/src/base/kaldi-types.h
@@ -41,6 +41,7 @@ typedef float BaseFloat;
// for discussion on what to do if you need compile kaldi
// without OpenFST, see the bottom of this this file
+/*
#include <fst/types.h>
namespace kaldi {
@@ -53,10 +54,10 @@ namespace kaldi {
typedef float float32;
typedef double double64;
} // end namespace kaldi
+*/
// In a theoretical case you decide compile Kaldi without the OpenFST
// comment the previous namespace statement and uncomment the following
-/*
namespace kaldi {
typedef int8_t int8;
typedef int16_t int16;
@@ -70,6 +71,5 @@ namespace kaldi {
typedef float float32;
typedef double double64;
} // end namespace kaldi
-*/
#endif // KALDI_BASE_KALDI_TYPES_H_
diff --git a/src/matrix/matrix-lib.h b/src/matrix/matrix-lib.h
index b6059b06c..4fb9e1b16 100644
--- a/src/matrix/matrix-lib.h
+++ b/src/matrix/matrix-lib.h
@@ -25,14 +25,14 @@
#include "base/kaldi-common.h"
#include "matrix/kaldi-vector.h"
#include "matrix/kaldi-matrix.h"
-#include "matrix/sp-matrix.h"
-#include "matrix/tp-matrix.h"
+// #include "matrix/sp-matrix.h"
+// #include "matrix/tp-matrix.h"
#include "matrix/matrix-functions.h"
#include "matrix/srfft.h"
#include "matrix/compressed-matrix.h"
-#include "matrix/sparse-matrix.h"
+// #include "matrix/sparse-matrix.h"
#include "matrix/optimization.h"
-#include "matrix/numpy-array.h"
+// #include "matrix/numpy-array.h"
#endif
diff --git a/src/util/common-utils.h b/src/util/common-utils.h
index cfb0c255c..48d199e97 100644
--- a/src/util/common-utils.h
+++ b/src/util/common-utils.h
@@ -21,11 +21,11 @@
#include "base/kaldi-common.h"
#include "util/parse-options.h"
-#include "util/kaldi-io.h"
-#include "util/simple-io-funcs.h"
-#include "util/kaldi-holder.h"
-#include "util/kaldi-table.h"
-#include "util/table-types.h"
-#include "util/text-utils.h"
+// #include "util/kaldi-io.h"
+// #include "util/simple-io-funcs.h"
+// #include "util/kaldi-holder.h"
+// #include "util/kaldi-table.h"
+// #include "util/table-types.h"
+// #include "util/text-utils.h"
#endif // KALDI_UTIL_COMMON_UTILS_H_
#include "matrix/kaldi-matrix.h"
#include <torch/torch.h>
namespace {
template <typename Real>
void assert_matrix_shape(const torch::Tensor& tensor_);
template <>
void assert_matrix_shape<float>(const torch::Tensor& tensor_) {
TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 2);
TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat32);
TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
}
template <>
void assert_matrix_shape<double>(const torch::Tensor& tensor_) {
TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 2);
TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat64);
TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
}
} // namespace
namespace kaldi {

// Wrap an existing tensor; shape/dtype/device are validated eagerly.
template <typename Real>
MatrixBase<Real>::MatrixBase(torch::Tensor tensor) : tensor_(tensor) {
  assert_matrix_shape<Real>(tensor_);
}

// Explicit instantiations: only float and double are supported.
template class Matrix<float>;
template class Matrix<double>;
template class MatrixBase<float>;
template class MatrixBase<double>;
template class SubMatrix<float>;
template class SubMatrix<double>;

} // namespace kaldi
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h
#ifndef KALDI_MATRIX_KALDI_MATRIX_H_
#define KALDI_MATRIX_KALDI_MATRIX_H_
#include <torch/torch.h>
#include "matrix/kaldi-vector.h"
#include "matrix/matrix-common.h"
using namespace torch::indexing;
namespace kaldi {
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L44-L48
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L44-L48
template <typename Real>
class MatrixBase {
 public:
  //////////////////////////////////////////////////////////////////////////////
  // PyTorch-specific items
  //////////////////////////////////////////////////////////////////////////////

  // Backing storage; assert_matrix_shape guarantees 2-D, CPU, matching dtype.
  torch::Tensor tensor_;

  /// Construct MatrixBase which is an interface to an existing torch::Tensor
  /// object.
  MatrixBase(torch::Tensor tensor);

  //////////////////////////////////////////////////////////////////////////////
  // Kaldi-compatible items (mirroring Kaldi's src/matrix/kaldi-matrix.h)
  //////////////////////////////////////////////////////////////////////////////

  inline MatrixIndexT NumRows() const {
    return tensor_.size(0);
  }

  inline MatrixIndexT NumCols() const {
    return tensor_.size(1);
  }

  /// Overwrite column `col` with the contents of v.
  void CopyColFromVec(const VectorBase<Real>& v, const MatrixIndexT col) {
    tensor_.index_put_({Slice(), col}, v.tensor_);
  }

  /// Element write access via a CPU accessor (no bounds check).
  inline Real& operator()(MatrixIndexT r, MatrixIndexT c) {
    // CPU only
    return tensor_.accessor<Real, 2>()[r][c];
  }

  /// Element read access.
  /// Fix: index the single element (r, c). The previous
  /// `index({Slice(r), Slice(c)})` selected the trailing submatrix [r:, c:]
  /// (Slice(n) means "from n to the end"), so .item() threw for every
  /// element except the bottom-right one.
  inline const Real operator()(MatrixIndexT r, MatrixIndexT c) const {
    return tensor_.index({r, c}).item().template to<Real>();
  }

  /// Copy the contents of M (optionally transposed) into this matrix;
  /// shapes are assumed compatible (torch broadcasting rules apply).
  template <typename OtherReal>
  void CopyFromMat(
      const MatrixBase<OtherReal>& M,
      MatrixTransposeType trans = kNoTrans) {
    auto src = M.tensor_;
    if (trans == kTrans)
      src = src.transpose(1, 0);
    tensor_.index_put_({Slice(), Slice()}, src);
  }

  /// Non-owning view of row i (shares storage).
  inline const SubVector<Real> Row(MatrixIndexT i) const {
    return SubVector<Real>(*this, i);
  }

  /// Non-owning view of rows [row_offset, row_offset + num_rows).
  inline SubMatrix<Real> RowRange(
      const MatrixIndexT row_offset,
      const MatrixIndexT num_rows) const {
    return SubMatrix<Real>(*this, row_offset, num_rows, 0, NumCols());
  }

 protected:
  /// Default: empty 0 x 0 matrix; Real must be a floating-point type.
  explicit MatrixBase() : tensor_(torch::empty({0, 0})) {
    KALDI_ASSERT_IS_FLOATING_TYPE(Real);
  }
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L781-L784
template <typename Real>
class Matrix : public MatrixBase<Real> {
 public:
  /// Empty 0 x 0 matrix.
  Matrix() : MatrixBase<Real>() {}

  /// r x c matrix; contents zeroed unless resize_type says otherwise.
  Matrix(
      const MatrixIndexT r,
      const MatrixIndexT c,
      MatrixResizeType resize_type = kSetZero,
      MatrixStrideType stride_type = kDefaultStride)
      : MatrixBase<Real>() {
    Resize(r, c, resize_type, stride_type);
  }

  /// Deep copy (optionally transposed).
  /// Fix: clone() the source tensor. The previous code wrapped M's tensor
  /// directly, so the "copy" aliased M's storage — unlike Kaldi's copy
  /// constructor and unlike Vector's copy constructor below, which clones.
  explicit Matrix(
      const MatrixBase<Real>& M,
      MatrixTransposeType trans = kNoTrans)
      : MatrixBase<Real>(
            (trans == kNoTrans ? M.tensor_ : M.tensor_.transpose(1, 0))
                .clone()) {}

  /// Copy from a matrix of another precision.
  /// NOTE(review): the tensor keeps OtherReal's dtype here, which
  /// assert_matrix_shape<Real> rejects when OtherReal != Real — confirm this
  /// overload is actually exercised before relying on it.
  template <typename OtherReal>
  explicit Matrix(
      const MatrixBase<OtherReal>& M,
      MatrixTransposeType trans = kNoTrans)
      : MatrixBase<Real>(
            trans == kNoTrans ? M.tensor_ : M.tensor_.transpose(1, 0)) {}

  /// Resize to r x c. kSetZero zeroes the contents, kUndefined leaves them
  /// unspecified, kCopyData preserves the overlapping top-left region
  /// (new cells are zero).
  void Resize(
      const MatrixIndexT r,
      const MatrixIndexT c,
      MatrixResizeType resize_type = kSetZero,
      MatrixStrideType stride_type = kDefaultStride) {
    auto& tensor_ = MatrixBase<Real>::tensor_;
    switch (resize_type) {
      case kSetZero:
        tensor_.resize_({r, c}).zero_();
        break;
      case kUndefined:
        tensor_.resize_({r, c});
        break;
      case kCopyData: {
        // Fix: clone() is required. `auto tmp = tensor_` only copies the
        // handle — both share one TensorImpl — so resize_().zero_() below
        // destroyed the very data kCopyData is supposed to preserve.
        auto tmp = tensor_.clone();
        auto tmp_rows = tmp.size(0);
        auto tmp_cols = tmp.size(1);
        tensor_.resize_({r, c}).zero_();
        auto rows = Slice(None, r < tmp_rows ? r : tmp_rows);
        auto cols = Slice(None, c < tmp_cols ? c : tmp_cols);
        tensor_.index_put_({rows, cols}, tmp.index({rows, cols}));
        break;
      }
    }
  }

  /// Assignment: resize to match, then copy contents.
  Matrix<Real>& operator=(const MatrixBase<Real>& other) {
    if (MatrixBase<Real>::NumRows() != other.NumRows() ||
        MatrixBase<Real>::NumCols() != other.NumCols())
      Resize(other.NumRows(), other.NumCols(), kUndefined);
    MatrixBase<Real>::CopyFromMat(other);
    return *this;
  }
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L940-L948
template <typename Real>
class SubMatrix : public MatrixBase<Real> {
 public:
  /// Non-owning view of rows [row_off, row_off + rows) and columns
  /// [col_off, col_off + cols) of `src`; shares storage via tensor indexing.
  SubMatrix(
      const MatrixBase<Real>& src,
      const MatrixIndexT row_off,
      const MatrixIndexT rows,
      const MatrixIndexT col_off,
      const MatrixIndexT cols)
      : MatrixBase<Real>(src.tensor_.index(
            {Slice(row_off, row_off + rows), Slice(col_off, col_off + cols)})) {
  }
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L1059-L1060
template <typename Real>
std::ostream& operator<<(std::ostream& Out, const MatrixBase<Real>& M) {
Out << M.tensor_;
return Out;
}
} // namespace kaldi
#endif
#include "matrix/kaldi-vector.h"
#include <torch/torch.h>
#include "matrix/kaldi-matrix.h"
namespace {
template <typename Real>
void assert_vector_shape(const torch::Tensor& tensor_);
template <>
void assert_vector_shape<float>(const torch::Tensor& tensor_) {
TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 1);
TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat32);
TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
}
template <>
void assert_vector_shape<double>(const torch::Tensor& tensor_) {
TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 1);
TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat64);
TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
}
} // namespace
namespace kaldi {

// Wrap an existing 1-D tensor and cache its raw data pointer.
template <typename Real>
VectorBase<Real>::VectorBase(torch::Tensor tensor)
    : tensor_(tensor), data_(tensor.data_ptr<Real>()) {
  assert_vector_shape<Real>(tensor_);
}

// The default constructor wraps an empty (0-element) tensor.
template <typename Real>
VectorBase<Real>::VectorBase() : VectorBase<Real>(torch::empty({0})) {}

// Explicit instantiations: only float and double are supported.
template class Vector<float>;
template class Vector<double>;
template class VectorBase<float>;
template class VectorBase<double>;

} // namespace kaldi
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h
#ifndef KALDI_MATRIX_KALDI_VECTOR_H_
#define KALDI_MATRIX_KALDI_VECTOR_H_
#include <torch/torch.h>
#include "matrix/matrix-common.h"
using namespace torch::indexing;
namespace kaldi {
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L36-L40
template <typename Real>
class VectorBase {
 public:
  //////////////////////////////////////////////////////////////////////////////
  // PyTorch-specific things
  //////////////////////////////////////////////////////////////////////////////

  // Backing storage; assert_vector_shape guarantees 1-D, CPU, matching dtype.
  torch::Tensor tensor_;

  /// Construct VectorBase which is an interface to an existing torch::Tensor
  /// object.
  VectorBase(torch::Tensor tensor);

  //////////////////////////////////////////////////////////////////////////////
  // Kaldi-compatible methods (mirroring Kaldi's src/matrix/kaldi-vector.h)
  //////////////////////////////////////////////////////////////////////////////

  /// Set all elements to zero.
  void SetZero() {
    Set(0);
  }

  /// Set all elements to f.
  void Set(Real f) {
    tensor_.index_put_({"..."}, f);
  }

  /// Number of elements.
  inline MatrixIndexT Dim() const {
    return tensor_.numel();
  }

  /// Raw pointer to the underlying CPU data.
  inline Real* Data() {
    return data_;
  }

  inline const Real* Data() const {
    return data_;
  }

  /// Element read access (no bounds check).
  inline Real operator()(MatrixIndexT i) const {
    return data_[i];
  }

  /// Element write access via a CPU accessor (no bounds check).
  inline Real& operator()(MatrixIndexT i) {
    return tensor_.accessor<Real, 1>()[i];
  }

  /// Non-owning view of elements [o, o + l).
  SubVector<Real> Range(const MatrixIndexT o, const MatrixIndexT l) {
    return SubVector<Real>(*this, o, l);
  }

  const SubVector<Real> Range(const MatrixIndexT o, const MatrixIndexT l)
      const {
    return SubVector<Real>(*this, o, l);
  }

  /// Copy elements from v; dimensions must match.
  void CopyFromVec(const VectorBase<Real>& v) {
    TORCH_INTERNAL_ASSERT(tensor_.sizes() == v.tensor_.sizes());
    tensor_.copy_(v.tensor_);
  }

  /// Clamp elements below floor_val up to floor_val; optionally report how
  /// many were floored.
  void ApplyFloor(Real floor_val, MatrixIndexT* floored_count = nullptr) {
    auto index = tensor_ < floor_val;
    // in-place update; the result was previously bound to an unused local
    tensor_.index_put_({index}, floor_val);
    if (floored_count) {
      *floored_count = index.sum().item().template to<MatrixIndexT>();
    }
  }

  /// Raise each element to `power` in place; asserts no NaN was produced
  /// (e.g. a negative base with a fractional exponent).
  void ApplyPow(Real power) {
    tensor_.pow_(power);
    TORCH_INTERNAL_ASSERT(!tensor_.isnan().sum().item().template to<int32_t>());
  }

  /// this += alpha * v (element-wise; mixed dtypes follow torch promotion).
  template <typename OtherReal>
  void AddVec(const Real alpha, const VectorBase<OtherReal>& v) {
    tensor_ += alpha * v.tensor_;
  }

  /// this += alpha * v^2 (element-wise).
  void AddVec2(const Real alpha, const VectorBase<Real>& v) {
    tensor_ += alpha * (v.tensor_.square());
  }

  /// this = beta * this + alpha * op(M) * v, where op transposes M when
  /// trans == kTrans. **beta previously defaulted to 0.0**
  void AddMatVec(
      const Real alpha,
      const MatrixBase<Real>& M,
      const MatrixTransposeType trans,
      const VectorBase<Real>& v,
      const Real beta) {
    auto mat = M.tensor_;
    if (trans == kTrans) {
      mat = mat.transpose(1, 0);
    }
    tensor_.addmv_(mat, v.tensor_, beta, alpha);
  }

  /// Element-wise multiply by v.
  void MulElements(const VectorBase<Real>& v) {
    tensor_ *= v.tensor_;
  }

  /// Add the constant c to every element.
  void Add(Real c) {
    tensor_ += c;
  }

  /// this = beta * this + alpha * (v .* r) (element-wise).
  void AddVecVec(
      Real alpha,
      const VectorBase<Real>& v,
      const VectorBase<Real>& r,
      Real beta) {
    tensor_ = beta * tensor_ + alpha * v.tensor_ * r.tensor_;
  }

  /// Multiply every element by alpha.
  void Scale(Real alpha) {
    tensor_ *= alpha;
  }

  /// Minimum element; +infinity for an empty vector.
  Real Min() const {
    if (tensor_.numel()) {
      return tensor_.min().item().template to<Real>();
    }
    return std::numeric_limits<Real>::infinity();
  }

  /// Minimum element and its index; the vector must be non-empty.
  Real Min(MatrixIndexT* index) const {
    TORCH_INTERNAL_ASSERT(tensor_.numel());
    torch::Tensor value, ind;
    std::tie(value, ind) = tensor_.min(0);
    *index = ind.item().to<MatrixIndexT>();
    return value.item().to<Real>();
  }

  /// Sum of all elements.
  Real Sum() const {
    return tensor_.sum().item().template to<Real>();
  }

  /// this = beta * this + alpha * (sum over the rows of M), implemented as
  /// M^T * ones.
  void AddRowSumMat(Real alpha, const MatrixBase<Real>& M, Real beta = 1.0) {
    Vector<Real> ones(M.NumRows());
    ones.Set(1.0);
    this->AddMatVec(alpha, M, kTrans, ones, beta);
  }

  /// this = beta * this + alpha * (sum over the columns of M), implemented
  /// as M * ones.
  void AddColSumMat(Real alpha, const MatrixBase<Real>& M, Real beta = 1.0) {
    Vector<Real> ones(M.NumCols());
    ones.Set(1.0);
    this->AddMatVec(alpha, M, kNoTrans, ones, beta);
  }

  /// this = beta * this + alpha-scaled diagonal of M * M^T (kNoTrans) or
  /// M^T * M (otherwise). NOTE(review): alpha is not applied to the diag
  /// term in this implementation — confirm callers always pass alpha == 1.
  void AddDiagMat2(
      Real alpha,
      const MatrixBase<Real>& M,
      MatrixTransposeType trans = kNoTrans,
      Real beta = 1.0) {
    auto mat = M.tensor_;
    if (trans == kNoTrans) {
      tensor_ =
          beta * tensor_ + torch::diag(torch::mm(mat, mat.transpose(1, 0)));
    } else {
      tensor_ =
          beta * tensor_ + torch::diag(torch::mm(mat.transpose(1, 0), mat));
    }
  }

 protected:
  // Default-constructs an empty (0-element) vector (defined in the .cc).
  explicit VectorBase();

  // Cached tensor_.data_ptr<Real>(); refreshed by Vector::Resize.
  Real* data_;

  KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase);
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L385-L390
template <typename Real>
class Vector : public VectorBase<Real> {
 public:
  //////////////////////////////////////////////////////////////////////////////
  // PyTorch-compatibility things
  //////////////////////////////////////////////////////////////////////////////

  /// Construct Vector which is an interface to an existing torch::Tensor
  /// object.
  Vector(torch::Tensor tensor) : VectorBase<Real>(tensor) {}

  //////////////////////////////////////////////////////////////////////////////
  // Kaldi-compatible methods
  //////////////////////////////////////////////////////////////////////////////

  /// Empty vector.
  Vector() : VectorBase<Real>() {}

  /// Vector of size s; contents zeroed unless resize_type says otherwise.
  explicit Vector(const MatrixIndexT s, MatrixResizeType resize_type = kSetZero)
      : VectorBase<Real>() {
    Resize(s, resize_type);
  }

  /// Deep copy. Note: unlike the original implementation, this is "explicit".
  explicit Vector(const Vector<Real>& v)
      : VectorBase<Real>(v.tensor_.clone()) {}

  /// Deep copy from any VectorBase.
  explicit Vector(const VectorBase<Real>& v)
      : VectorBase<Real>(v.tensor_.clone()) {}

  /// Exchange contents with another vector.
  /// Fix: also swap the cached raw data pointers — swapping only the tensors
  /// left both vectors' data_ pointing at the other's storage, so Data() and
  /// operator() read the wrong buffer after a Swap.
  void Swap(Vector<Real>* other) {
    auto tmp_tensor = VectorBase<Real>::tensor_;
    this->tensor_ = other->tensor_;
    other->tensor_ = tmp_tensor;
    auto tmp_data = this->data_;
    this->data_ = other->data_;
    other->data_ = tmp_data;
  }

  /// Resize to `length`. kSetZero zeroes the contents, kUndefined leaves
  /// them unspecified, kCopyData preserves the leading
  /// min(length, old_length) elements (new trailing elements are zero).
  void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero) {
    auto& tensor_ = this->tensor_;
    switch (resize_type) {
      case kSetZero:
        tensor_.resize_({length}).zero_();
        break;
      case kUndefined:
        tensor_.resize_({length});
        break;
      case kCopyData: {
        // Fix: clone() is required. `auto tmp = tensor_` only copies the
        // handle — both share one TensorImpl — so resize_().zero_() below
        // destroyed the data kCopyData is supposed to preserve.
        auto tmp = tensor_.clone();
        auto tmp_numel = tmp.numel();
        tensor_.resize_({length}).zero_();
        // Fix: Slice(None, n) is [:n]; the previous Slice(n) meant [n:],
        // which copied the wrong (or no) elements.
        auto keep = Slice(None, length < tmp_numel ? length : tmp_numel);
        tensor_.index_put_({keep}, tmp.index({keep}));
        break;
      }
    }
    // data_ptr<Real>() causes compiler error
    this->data_ = static_cast<Real*>(tensor_.data_ptr());
  }

  /// Assignment: resize to match, then copy contents.
  Vector<Real>& operator=(const VectorBase<Real>& other) {
    Resize(other.Dim(), kUndefined);
    this->CopyFromVec(other);
    return *this;
  }
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L482-L485
template <typename Real>
class SubVector : public VectorBase<Real> {
 public:
  /// Non-owning view of `length` elements of `src` starting at `origin`
  /// (shares storage).
  SubVector(
      const VectorBase<Real>& src,
      const MatrixIndexT origin,
      const MatrixIndexT length)
      : VectorBase<Real>(src.tensor_.index({Slice(origin, origin + length)})) {}

  /// Non-owning view of one row of a matrix (shares storage).
  SubVector(const MatrixBase<Real>& mat, MatrixIndexT row)
      : VectorBase<Real>(mat.tensor_.index({row})) {}
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L540-L543
template <typename Real>
std::ostream& operator<<(std::ostream& out, const VectorBase<Real>& v) {
out << v.tensor_;
return out;
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L573-L575
template <typename Real>
Real VecVec(const VectorBase<Real>& v1, const VectorBase<Real>& v2) {
return torch::dot(v1.tensor_, v2.tensor_).item().template to<Real>();
}
} // namespace kaldi
#endif
Subproject commit 3eea37dd09b55064e6362216f7e9a60641f29f09
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment