Unverified Commit 7ee1c46b authored by moto's avatar moto Committed by GitHub
Browse files

Add Kaldi Pitch feature (#1243)

parent 9e58e75c
...@@ -38,7 +38,7 @@ fi ...@@ -38,7 +38,7 @@ fi
printf "\x1b[34mRunning clang-format:\x1b[0m\n" printf "\x1b[34mRunning clang-format:\x1b[0m\n"
"${this_dir}"/run_clang_format.py \ "${this_dir}"/run_clang_format.py \
-r torchaudio/csrc \ -r torchaudio/csrc third_party/kaldi/src \
--clang-format-executable "${clangformat_path}" \ --clang-format-executable "${clangformat_path}" \
&& git diff --exit-code && git diff --exit-code
status=$? status=$?
......
...@@ -2,3 +2,7 @@ ...@@ -2,3 +2,7 @@
path = third_party/transducer/submodule path = third_party/transducer/submodule
url = https://github.com/HawkAaron/warp-transducer url = https://github.com/HawkAaron/warp-transducer
ignore = dirty ignore = dirty
[submodule "kaldi"]
path = third_party/kaldi/submodule
url = https://github.com/kaldi-asr/kaldi
ignore = dirty
...@@ -47,6 +47,7 @@ endif() ...@@ -47,6 +47,7 @@ endif()
# Options # Options
option(BUILD_SOX "Build libsox statically" OFF) option(BUILD_SOX "Build libsox statically" OFF)
option(BUILD_KALDI "Build kaldi statically" ON)
option(BUILD_TRANSDUCER "Enable transducer" OFF) option(BUILD_TRANSDUCER "Enable transducer" OFF)
option(BUILD_LIBTORCHAUDIO "Build C++ Library" ON) option(BUILD_LIBTORCHAUDIO "Build C++ Library" ON)
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF) option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)
......
...@@ -68,6 +68,7 @@ class CMakeBuild(build_ext): ...@@ -68,6 +68,7 @@ class CMakeBuild(build_ext):
'-DCMAKE_VERBOSE_MAKEFILE=ON', '-DCMAKE_VERBOSE_MAKEFILE=ON',
f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}", f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}",
f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}", f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
"-DBUILD_KALDI:BOOL=ON",
f"-DBUILD_TRANSDUCER:BOOL={'ON' if _BUILD_TRANSDUCER else 'OFF'}", f"-DBUILD_TRANSDUCER:BOOL={'ON' if _BUILD_TRANSDUCER else 'OFF'}",
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON", "-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
"-DBUILD_LIBTORCHAUDIO:BOOL=OFF", "-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
......
{"sample_rate": 8000}
{"sample_rate": 8000, "frames_per_chunk": 200}
{"sample_rate": 8000, "frames_per_chunk": 200, "simulate_first_pass_online": true}
{"sample_rate": 16000}
{"sample_rate": 44100}
import subprocess
import torch
def convert_args(**kwargs):
    """Translate keyword arguments into Kaldi-style ``--name=value`` options.

    ``sample_rate`` is emitted as ``--sample-frequency``; underscores in every
    other name become dashes. Values equal to ``True``/``False`` are
    lower-cased, all other values are passed through ``str``.

    Returns:
        list of str: one option string per keyword argument, in call order.
    """
    def _render(name, val):
        # Kaldi calls the sampling rate "sample-frequency".
        flag = 'sample_frequency' if name == 'sample_rate' else name
        if val in [True, False]:
            text = str(val).lower()
        else:
            text = str(val)
        return '%s=%s' % ('--' + flag.replace('_', '-'), text)

    return [_render(name, val) for name, val in kwargs.items()]
def run_kaldi(command, input_type, input_value):
    """Run provided Kaldi command, pass a tensor and get the resulting tensor

    Args:
        command: list of str
            Command line passed to ``subprocess.Popen``.
        input_type: str
            'ark' or 'scp'
        input_value:
            Tensor for 'ark'
            string for 'scp' (path to an audio file)

    Returns:
        Tensor: the matrix read back from the command's stdout under the key
        that was written to its stdin.

    Raises:
        NotImplementedError: if ``input_type`` is neither 'ark' nor 'scp'.
    """
    import kaldi_io

    key = 'foo'
    # The context manager closes the pipes and waits for the child on exit
    # (including the error path), so no process or file handle is leaked.
    with subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as process:
        if input_type == 'ark':
            kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
        elif input_type == 'scp':
            process.stdin.write(f'{key} {input_value}'.encode('utf8'))
        else:
            raise NotImplementedError('Unexpected type')
        # Close stdin before reading so the command sees end-of-input.
        process.stdin.close()
        result = dict(kaldi_io.read_mat_ark(process.stdout))[key]
    return torch.from_numpy(result.copy())  # copy suppresses some torch warning
...@@ -184,3 +184,9 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -184,3 +184,9 @@ class TestFunctional(common_utils.TorchaudioTestCase):
waveform, sample_rate = torchaudio.load(filepath) waveform, sample_rate = torchaudio.load(filepath)
self.assert_batch_consistencies( self.assert_batch_consistencies(
F.vad, waveform, sample_rate=sample_rate) F.vad, waveform, sample_rate=sample_rate)
@common_utils.skipIfNoExtension
def test_compute_kaldi_pitch(self):
    """Run the batch-consistency check for ``F.compute_kaldi_pitch`` on white noise."""
    rate = 44100
    noise = common_utils.get_whitenoise(sample_rate=rate)
    self.assert_batch_consistencies(
        F.compute_kaldi_pitch,
        noise,
        sample_rate=rate,
    )
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .kaldi_compatibility_test_impl import KaldiCPUOnly
class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase):
    """Concrete ``KaldiCPUOnly`` suite bound to float32 tensors on the CPU."""
    device = torch.device('cpu')
    dtype = torch.float32
from parameterized import parameterized
import torchaudio.functional as F
from torchaudio_unittest.common_utils import (
get_sinusoid,
load_params,
save_wav,
skipIfNoExec,
TempDirMixin,
TestBaseMixin,
)
from torchaudio_unittest.common_utils.kaldi_utils import (
convert_args,
run_kaldi,
)
class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
    """Compatibility checks against Kaldi executables; reads ``self.dtype`` / ``self.device``."""

    def assert_equal(self, output, *, expected, rtol=None, atol=None):
        # Bring the reference onto the suite's dtype/device before comparing.
        reference = expected.to(dtype=self.dtype, device=self.device)
        self.assertEqual(output, reference, rtol=rtol, atol=atol)

    @parameterized.expand(load_params('kaldi_test_pitch_args.json'))
    @skipIfNoExec('compute-kaldi-pitch-feats')
    def test_pitch_feats(self, kwargs):
        """compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
        sample_rate = kwargs['sample_rate']

        # torchaudio side: float32 sinusoid, first channel only.
        float_wave = get_sinusoid(dtype='float32', sample_rate=sample_rate)
        result = F.compute_kaldi_pitch(float_wave[0], **kwargs)

        # Kaldi side: the int16 sinusoid is written to a wav file that the
        # command receives through an scp line on stdin.
        int_wave = get_sinusoid(dtype='int16', sample_rate=sample_rate)
        wave_file = self.get_temp_path('test.wav')
        save_wav(wave_file, int_wave, sample_rate)

        command = ['compute-kaldi-pitch-feats', *convert_args(**kwargs), 'scp:-', 'ark:-']
        kaldi_result = run_kaldi(command, 'scp', wave_file)
        self.assert_equal(result, expected=kaldi_result)
...@@ -547,3 +547,15 @@ class Functional(common_utils.TestBaseMixin): ...@@ -547,3 +547,15 @@ class Functional(common_utils.TestBaseMixin):
tensor = common_utils.get_whitenoise(sample_rate=44100) tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor) self._assert_consistency(func, tensor)
@common_utils.skipIfNoExtension
def test_compute_kaldi_pitch(self):
    """Run the consistency check for ``F.compute_kaldi_pitch`` (float32 on CPU only)."""
    supported = self.dtype == torch.float32 and self.device == torch.device('cpu')
    if not supported:
        raise unittest.SkipTest("Only float32, cpu is supported.")

    def func(tensor):
        sample_rate: float = 44100.
        return F.compute_kaldi_pitch(tensor, sample_rate)

    noise = common_utils.get_whitenoise(sample_rate=44100)
    self._assert_consistency(func, noise)
"""Test suites for checking numerical compatibility against Kaldi""" """Test suites for checking numerical compatibility against Kaldi"""
import subprocess
import kaldi_io
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
import torchaudio.compliance.kaldi import torchaudio.compliance.kaldi
...@@ -9,46 +6,19 @@ from parameterized import parameterized ...@@ -9,46 +6,19 @@ from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin, TestBaseMixin,
TempDirMixin,
load_params, load_params,
skipIfNoExec, skipIfNoExec,
get_asset_path, get_asset_path,
load_wav load_wav,
)
from torchaudio_unittest.common_utils.kaldi_utils import (
convert_args,
run_kaldi,
) )
def _convert_args(**kwargs): class Kaldi(TempDirMixin, TestBaseMixin):
args = []
for key, value in kwargs.items():
key = '--' + key.replace('_', '-')
value = str(value).lower() if value in [True, False] else str(value)
args.append('%s=%s' % (key, value))
return args
def _run_kaldi(command, input_type, input_value):
"""Run provided Kaldi command, pass a tensor and get the resulting tensor
Args:
input_type: str
'ark' or 'scp'
input_value:
Tensor for 'ark'
string for 'scp' (path to an audio file)
"""
key = 'foo'
process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
if input_type == 'ark':
kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
elif input_type == 'scp':
process.stdin.write(f'{key} {input_value}'.encode('utf8'))
else:
raise NotImplementedError('Unexpected type')
process.stdin.close()
result = dict(kaldi_io.read_mat_ark(process.stdout))['foo']
return torch.from_numpy(result.copy()) # copy supresses some torch warning
class Kaldi(TestBaseMixin):
def assert_equal(self, output, *, expected, rtol=None, atol=None): def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device) expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol) self.assertEqual(output, expected, rtol=rtol, atol=atol)
...@@ -65,8 +35,8 @@ class Kaldi(TestBaseMixin): ...@@ -65,8 +35,8 @@ class Kaldi(TestBaseMixin):
tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device) tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
result = F.sliding_window_cmn(tensor, **kwargs) result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + _convert_args(**kwargs) + ['ark:-', 'ark:-'] command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'ark', tensor) kaldi_result = run_kaldi(command, 'ark', tensor)
self.assert_equal(result, expected=kaldi_result) self.assert_equal(result, expected=kaldi_result)
@parameterized.expand(load_params('kaldi_test_fbank_args.json')) @parameterized.expand(load_params('kaldi_test_fbank_args.json'))
...@@ -76,8 +46,8 @@ class Kaldi(TestBaseMixin): ...@@ -76,8 +46,8 @@ class Kaldi(TestBaseMixin):
wave_file = get_asset_path('kaldi_file.wav') wave_file = get_asset_path('kaldi_file.wav')
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs) result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] command = ['compute-fbank-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file) kaldi_result = run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(load_params('kaldi_test_spectrogram_args.json')) @parameterized.expand(load_params('kaldi_test_spectrogram_args.json'))
...@@ -87,8 +57,8 @@ class Kaldi(TestBaseMixin): ...@@ -87,8 +57,8 @@ class Kaldi(TestBaseMixin):
wave_file = get_asset_path('kaldi_file.wav') wave_file = get_asset_path('kaldi_file.wav')
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs) result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs)
command = ['compute-spectrogram-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] command = ['compute-spectrogram-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file) kaldi_result = run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(load_params('kaldi_test_mfcc_args.json')) @parameterized.expand(load_params('kaldi_test_mfcc_args.json'))
...@@ -98,6 +68,6 @@ class Kaldi(TestBaseMixin): ...@@ -98,6 +68,6 @@ class Kaldi(TestBaseMixin):
wave_file = get_asset_path('kaldi_file.wav') wave_file = get_asset_path('kaldi_file.wav')
waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs) result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs)
command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] command = ['compute-mfcc-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file) kaldi_result = run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
...@@ -17,6 +17,14 @@ else() ...@@ -17,6 +17,14 @@ else()
endif() endif()
list(APPEND TORCHAUDIO_THIRD_PARTIES libsox) list(APPEND TORCHAUDIO_THIRD_PARTIES libsox)
################################################################################
# kaldi
################################################################################
if (BUILD_KALDI)
add_subdirectory(kaldi)
list(APPEND TORCHAUDIO_THIRD_PARTIES kaldi)
endif()
################################################################################ ################################################################################
# transducer # transducer
################################################################################ ################################################################################
......
# Build script for the trimmed-down static `kaldi` library.
# Location of the Kaldi submodule checkout, relative to this directory.
set(KALDI_REPO ${CMAKE_CURRENT_SOURCE_DIR}/submodule)
# Apply custom patch
# `git checkout .` first discards local modifications in the submodule work
# tree so that the patch below is applied to pristine sources on every
# configure run.
# NOTE(review): the exit status of these commands is not checked; a failed
# checkout/apply is not reported at configure time.
execute_process(
WORKING_DIRECTORY ${KALDI_REPO}
COMMAND "git" "checkout" "."
)
# The patch file lives one level above the submodule (next to this script).
execute_process(
WORKING_DIRECTORY ${KALDI_REPO}
COMMAND git apply ../kaldi.patch
)
# Update the version string
# Runs Kaldi's own get_version.sh inside src/base
# (presumably regenerates the version header there - TODO confirm).
execute_process(
WORKING_DIRECTORY ${KALDI_REPO}/src/base
COMMAND sh get_version.sh
)
# Sources: the two files under src/ are the custom vector/matrix
# implementation from this directory; the rest come from the Kaldi submodule.
set(KALDI_SOURCES
src/matrix/kaldi-vector.cc
src/matrix/kaldi-matrix.cc
submodule/src/base/kaldi-error.cc
submodule/src/base/kaldi-math.cc
submodule/src/feat/feature-functions.cc
submodule/src/feat/pitch-functions.cc
submodule/src/feat/resample.cc
)
add_library(kaldi STATIC ${KALDI_SOURCES})
# `src` is listed before `submodule/src` so headers found in both trees
# resolve to the custom ones first.
target_include_directories(kaldi PUBLIC src submodule/src)
target_link_libraries(kaldi ${TORCH_LIBRARIES})
# Custom Kaldi build
This directory contains the original Kaldi repository (as a submodule), [the custom implementation of Kaldi's vector/matrix](./src), and the build script.
We use the custom build process so that the resulting library only contains what torchaudio needs.
We use the custom vector/matrix implementation so that we can use the same BLAS library that PyTorch is compiled with, and so that we can (hopefully, in the future) take advantage of other PyTorch features (such as differentiability and GPU support). The downside of this approach is that it adds a lot of overhead compared to the original Kaldi (operator dispatch and element-wise processing, which PyTorch is not efficient at). We can improve this gradually, and if you are interested in helping, please let us know by opening an issue.
\ No newline at end of file
diff --git a/src/base/kaldi-types.h b/src/base/kaldi-types.h
index 7ebf4f853..c15b288b2 100644
--- a/src/base/kaldi-types.h
+++ b/src/base/kaldi-types.h
@@ -41,6 +41,7 @@ typedef float BaseFloat;
// for discussion on what to do if you need compile kaldi
// without OpenFST, see the bottom of this this file
+/*
#include <fst/types.h>
namespace kaldi {
@@ -53,10 +54,10 @@ namespace kaldi {
typedef float float32;
typedef double double64;
} // end namespace kaldi
+*/
// In a theoretical case you decide compile Kaldi without the OpenFST
// comment the previous namespace statement and uncomment the following
-/*
namespace kaldi {
typedef int8_t int8;
typedef int16_t int16;
@@ -70,6 +71,5 @@ namespace kaldi {
typedef float float32;
typedef double double64;
} // end namespace kaldi
-*/
#endif // KALDI_BASE_KALDI_TYPES_H_
diff --git a/src/matrix/matrix-lib.h b/src/matrix/matrix-lib.h
index b6059b06c..4fb9e1b16 100644
--- a/src/matrix/matrix-lib.h
+++ b/src/matrix/matrix-lib.h
@@ -25,14 +25,14 @@
#include "base/kaldi-common.h"
#include "matrix/kaldi-vector.h"
#include "matrix/kaldi-matrix.h"
-#include "matrix/sp-matrix.h"
-#include "matrix/tp-matrix.h"
+// #include "matrix/sp-matrix.h"
+// #include "matrix/tp-matrix.h"
#include "matrix/matrix-functions.h"
#include "matrix/srfft.h"
#include "matrix/compressed-matrix.h"
-#include "matrix/sparse-matrix.h"
+// #include "matrix/sparse-matrix.h"
#include "matrix/optimization.h"
-#include "matrix/numpy-array.h"
+// #include "matrix/numpy-array.h"
#endif
diff --git a/src/util/common-utils.h b/src/util/common-utils.h
index cfb0c255c..48d199e97 100644
--- a/src/util/common-utils.h
+++ b/src/util/common-utils.h
@@ -21,11 +21,11 @@
#include "base/kaldi-common.h"
#include "util/parse-options.h"
-#include "util/kaldi-io.h"
-#include "util/simple-io-funcs.h"
-#include "util/kaldi-holder.h"
-#include "util/kaldi-table.h"
-#include "util/table-types.h"
-#include "util/text-utils.h"
+// #include "util/kaldi-io.h"
+// #include "util/simple-io-funcs.h"
+// #include "util/kaldi-holder.h"
+// #include "util/kaldi-table.h"
+// #include "util/table-types.h"
+// #include "util/text-utils.h"
#endif // KALDI_UTIL_COMMON_UTILS_H_
#include "matrix/kaldi-matrix.h"
#include <torch/torch.h>
namespace {

/// Validate that `tensor_` can back a kaldi::MatrixBase<Real>:
/// two-dimensional, element type matching `Real`, and located on CPU.
/// Throws c10::Error otherwise.
///
/// The tensor comes from the caller of the public MatrixBase constructor, so
/// violations are reported with TORCH_CHECK (an input error) rather than
/// TORCH_INTERNAL_ASSERT (which reports an internal PyTorch bug).
template <typename Real>
void assert_matrix_shape(const torch::Tensor& tensor_);

template <>
void assert_matrix_shape<float>(const torch::Tensor& tensor_) {
  TORCH_CHECK(tensor_.ndimension() == 2, "Input tensor has to be 2D.");
  TORCH_CHECK(
      tensor_.dtype() == torch::kFloat32, "Input tensor has to be float32.");
  TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
}

template <>
void assert_matrix_shape<double>(const torch::Tensor& tensor_) {
  TORCH_CHECK(tensor_.ndimension() == 2, "Input tensor has to be 2D.");
  TORCH_CHECK(
      tensor_.dtype() == torch::kFloat64, "Input tensor has to be float64.");
  TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
}

} // namespace
namespace kaldi {

// Wrap an existing tensor; the tensor is validated right after it is stored.
template <typename Real>
MatrixBase<Real>::MatrixBase(torch::Tensor tensor) : tensor_(tensor) {
  assert_matrix_shape<Real>(tensor_);
}

// Explicit instantiations for the two supported element types.
template class MatrixBase<float>;
template class MatrixBase<double>;
template class Matrix<float>;
template class Matrix<double>;
template class SubMatrix<float>;
template class SubMatrix<double>;

} // namespace kaldi
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h
#ifndef KALDI_MATRIX_KALDI_MATRIX_H_
#define KALDI_MATRIX_KALDI_MATRIX_H_
#include <torch/torch.h>
#include "matrix/kaldi-vector.h"
#include "matrix/matrix-common.h"
using namespace torch::indexing;
namespace kaldi {
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L44-L48
template <typename Real>
class MatrixBase {
public:
////////////////////////////////////////////////////////////////////////////////
// PyTorch-specific items
////////////////////////////////////////////////////////////////////////////////
torch::Tensor tensor_;
/// Construct VectorBase which is an interface to an existing torch::Tensor
/// object.
MatrixBase(torch::Tensor tensor);
////////////////////////////////////////////////////////////////////////////////
// Kaldi-compatible items
////////////////////////////////////////////////////////////////////////////////
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L62-L63
inline MatrixIndexT NumRows() const {
return tensor_.size(0);
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L65-L66
inline MatrixIndexT NumCols() const {
return tensor_.size(1);
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L177-L178
void CopyColFromVec(const VectorBase<Real>& v, const MatrixIndexT col) {
tensor_.index_put_({Slice(), col}, v.tensor_);
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L99-L107
inline Real& operator()(MatrixIndexT r, MatrixIndexT c) {
// CPU only
return tensor_.accessor<Real, 2>()[r][c];
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L112-L120
inline const Real operator()(MatrixIndexT r, MatrixIndexT c) const {
return tensor_.index({Slice(r), Slice(c)}).item().template to<Real>();
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L138-L141
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.cc#L859-L898
template <typename OtherReal>
void CopyFromMat(
const MatrixBase<OtherReal>& M,
MatrixTransposeType trans = kNoTrans) {
auto src = M.tensor_;
if (trans == kTrans)
src = src.transpose(1, 0);
tensor_.index_put_({Slice(), Slice()}, src);
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L186-L191
inline const SubVector<Real> Row(MatrixIndexT i) const {
return SubVector<Real>(*this, i);
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L208-L211
inline SubMatrix<Real> RowRange(
const MatrixIndexT row_offset,
const MatrixIndexT num_rows) const {
return SubMatrix<Real>(*this, row_offset, num_rows, 0, NumCols());
}
protected:
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L749-L753
explicit MatrixBase() : tensor_(torch::empty({0, 0})) {
KALDI_ASSERT_IS_FLOATING_TYPE(Real);
}
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L781-L784
template <typename Real>
class Matrix : public MatrixBase<Real> {
 public:
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L786-L787
  /// Empty matrix.
  Matrix() : MatrixBase<Real>() {}

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L789-L793
  /// Matrix with `r` rows and `c` columns; see Resize() for `resize_type`.
  Matrix(
      const MatrixIndexT r,
      const MatrixIndexT c,
      MatrixResizeType resize_type = kSetZero,
      MatrixStrideType stride_type = kDefaultStride)
      : MatrixBase<Real>() {
    Resize(r, c, resize_type, stride_type);
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L808-L811
  // NOTE(review): this wraps M's tensor (or a transposed view of it) without
  // copying, so the new matrix shares storage with M - confirm this is
  // intended, since Kaldi's constructor copies.
  explicit Matrix(
      const MatrixBase<Real>& M,
      MatrixTransposeType trans = kNoTrans)
      : MatrixBase<Real>(
            trans == kNoTrans ? M.tensor_ : M.tensor_.transpose(1, 0)) {}

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L816-L819
  // NOTE(review): no dtype conversion happens here; the tensor is handed to
  // the MatrixBase<Real> constructor as is.
  template <typename OtherReal>
  explicit Matrix(
      const MatrixBase<OtherReal>& M,
      MatrixTransposeType trans = kNoTrans)
      : MatrixBase<Real>(
            trans == kNoTrans ? M.tensor_ : M.tensor_.transpose(1, 0)) {}

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L859-L874
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.cc#L817-L857
  /// Resize to `r` x `c`.
  ///  - kSetZero:   all elements are zero afterwards.
  ///  - kUndefined: contents are unspecified afterwards.
  ///  - kCopyData:  the overlapping top-left block keeps its old values, the
  ///                remaining elements are zero.
  /// `stride_type` is accepted for interface compatibility and not used.
  void Resize(
      const MatrixIndexT r,
      const MatrixIndexT c,
      MatrixResizeType resize_type = kSetZero,
      MatrixStrideType stride_type = kDefaultStride) {
    auto& tensor_ = MatrixBase<Real>::tensor_;
    switch (resize_type) {
      case kSetZero:
        tensor_.resize_({r, c}).zero_();
        break;
      case kUndefined:
        tensor_.resize_({r, c});
        break;
      case kCopyData: {
        // Snapshot the old contents into separate storage first. A plain
        // `auto tmp = tensor_` shares storage with tensor_, so the
        // resize_/zero_ below would wipe the data we want to keep.
        const auto old = tensor_.clone();
        const auto old_rows = old.size(0);
        const auto old_cols = old.size(1);
        tensor_.resize_({r, c}).zero_();
        const auto rows = Slice(None, r < old_rows ? r : old_rows);
        const auto cols = Slice(None, c < old_cols ? c : old_cols);
        tensor_.index_put_({rows, cols}, old.index({rows, cols}));
        break;
      }
    }
  }

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L876-L883
  /// Copy `other` into this matrix, resizing first when the shapes differ.
  Matrix<Real>& operator=(const MatrixBase<Real>& other) {
    if (MatrixBase<Real>::NumRows() != other.NumRows() ||
        MatrixBase<Real>::NumCols() != other.NumCols())
      Resize(other.NumRows(), other.NumCols(), kUndefined);
    MatrixBase<Real>::CopyFromMat(other);
    return *this;
  }
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L940-L948
template <typename Real>
class SubMatrix : public MatrixBase<Real> {
 public:
  /// Wrap the block of `T` made of rows [ro, ro + r) and columns
  /// [co, co + c).
  SubMatrix(
      const MatrixBase<Real>& T,
      const MatrixIndexT ro, // row offset, 0 < ro < NumRows()
      const MatrixIndexT r, // number of rows, r > 0
      const MatrixIndexT co, // column offset, 0 < co < NumCols()
      const MatrixIndexT c) // number of columns, c > 0
      : MatrixBase<Real>(
            // Slice along rows first, then along columns.
            T.tensor_.slice(0, ro, ro + r).slice(1, co, co + c)) {}
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L1059-L1060
/// Stream the underlying tensor of `M` to `Out` and return the stream.
template <typename Real>
std::ostream& operator<<(std::ostream& Out, const MatrixBase<Real>& M) {
  return Out << M.tensor_;
}
} // namespace kaldi
#endif
#include "matrix/kaldi-vector.h"
#include <torch/torch.h>
#include "matrix/kaldi-matrix.h"
namespace {

/// Validate that `tensor_` can back a kaldi::VectorBase<Real>:
/// one-dimensional, element type matching `Real`, and located on CPU.
/// Throws c10::Error otherwise.
///
/// The tensor comes from the caller of the public VectorBase constructor, so
/// violations are reported with TORCH_CHECK (an input error) rather than
/// TORCH_INTERNAL_ASSERT (which reports an internal PyTorch bug).
template <typename Real>
void assert_vector_shape(const torch::Tensor& tensor_);

template <>
void assert_vector_shape<float>(const torch::Tensor& tensor_) {
  TORCH_CHECK(tensor_.ndimension() == 1, "Input tensor has to be 1D.");
  TORCH_CHECK(
      tensor_.dtype() == torch::kFloat32, "Input tensor has to be float32.");
  TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
}

template <>
void assert_vector_shape<double>(const torch::Tensor& tensor_) {
  TORCH_CHECK(tensor_.ndimension() == 1, "Input tensor has to be 1D.");
  TORCH_CHECK(
      tensor_.dtype() == torch::kFloat64, "Input tensor has to be float64.");
  TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
}

} // namespace
namespace kaldi {
template <typename Real>
VectorBase<Real>::VectorBase(torch::Tensor tensor)
: tensor_(tensor), data_(tensor.data_ptr<Real>()) {
assert_vector_shape<Real>(tensor_);
};
template <typename Real>
VectorBase<Real>::VectorBase() : VectorBase<Real>(torch::empty({0})) {}
template class Vector<float>;
template class Vector<double>;
template class VectorBase<float>;
template class VectorBase<double>;
} // namespace kaldi
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h
#ifndef KALDI_MATRIX_KALDI_VECTOR_H_
#define KALDI_MATRIX_KALDI_VECTOR_H_
#include <torch/torch.h>
#include "matrix/matrix-common.h"
using namespace torch::indexing;
namespace kaldi {
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L36-L40
template <typename Real>
class VectorBase {
public:
////////////////////////////////////////////////////////////////////////////////
// PyTorch-specific things
////////////////////////////////////////////////////////////////////////////////
torch::Tensor tensor_;
/// Construct VectorBase which is an interface to an existing torch::Tensor
/// object.
VectorBase(torch::Tensor tensor);
////////////////////////////////////////////////////////////////////////////////
// Kaldi-compatible methods
////////////////////////////////////////////////////////////////////////////////
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L42-L43
void SetZero() {
Set(0);
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L48-L49
void Set(Real f) {
tensor_.index_put_({"..."}, f);
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L62-L63
inline MatrixIndexT Dim() const {
return tensor_.numel();
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L68-L69
inline Real* Data() {
return data_;
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L71-L72
inline const Real* Data() const {
return data_;
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L74-L79
inline Real operator()(MatrixIndexT i) const {
return data_[i];
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L81-L86
inline Real& operator()(MatrixIndexT i) {
return tensor_.accessor<Real, 1>()[i];
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L88-L95
SubVector<Real> Range(const MatrixIndexT o, const MatrixIndexT l) {
return SubVector<Real>(*this, o, l);
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L97-L105
const SubVector<Real> Range(const MatrixIndexT o, const MatrixIndexT l)
const {
return SubVector<Real>(*this, o, l);
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L107-L108
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L226-L233
void CopyFromVec(const VectorBase<Real>& v) {
TORCH_INTERNAL_ASSERT(tensor_.sizes() == v.tensor_.sizes());
tensor_.copy_(v.tensor_);
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L137-L139
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L816-L832
void ApplyFloor(Real floor_val, MatrixIndexT* floored_count = nullptr) {
auto index = tensor_ < floor_val;
auto tmp = tensor_.index_put_({index}, floor_val);
if (floored_count) {
*floored_count = index.sum().item().template to<MatrixIndexT>();
}
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L164-L165
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L449-L479
void ApplyPow(Real power) {
tensor_.pow_(power);
TORCH_INTERNAL_ASSERT(!tensor_.isnan().sum().item().template to<int32_t>());
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L181-L184
template <typename OtherReal>
void AddVec(const Real alpha, const VectorBase<OtherReal>& v) {
tensor_ += alpha * v.tensor_;
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L186-L187
void AddVec2(const Real alpha, const VectorBase<Real>& v) {
tensor_ += alpha * (v.tensor_.square());
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L196-L198
void AddMatVec(
const Real alpha,
const MatrixBase<Real>& M,
const MatrixTransposeType trans,
const VectorBase<Real>& v,
const Real beta) { // **beta previously defaulted to 0.0**
auto mat = M.tensor_;
if (trans == kTrans) {
mat = mat.transpose(1, 0);
}
tensor_.addmv_(mat, v.tensor_, beta, alpha);
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L221-L222
void MulElements(const VectorBase<Real>& v) {
tensor_ *= v.tensor_;
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L233-L234
void Add(Real c) {
tensor_ += c;
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L236-L239
void AddVecVec(
Real alpha,
const VectorBase<Real>& v,
const VectorBase<Real>& r,
Real beta) {
tensor_ = beta * tensor_ + alpha * v.tensor_ * r.tensor_;
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L246-L247
void Scale(Real alpha) {
tensor_ *= alpha;
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L305-L306
Real Min() const {
if (tensor_.numel()) {
return tensor_.min().item().template to<Real>();
}
return std::numeric_limits<Real>::infinity();
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L308-L310
Real Min(MatrixIndexT* index) const {
TORCH_INTERNAL_ASSERT(tensor_.numel());
torch::Tensor value, ind;
std::tie(value, ind) = tensor_.min(0);
*index = ind.item().to<MatrixIndexT>();
return value.item().to<Real>();
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L312-L313
/// Returns the sum of all elements.
Real Sum() const {
  const auto total = tensor_.sum();
  return total.item().template to<Real>();
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L320-L321
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L718-L736
/// this = beta * this + alpha * (sum of the rows of M).
void AddRowSumMat(Real alpha, const MatrixBase<Real>& M, Real beta = 1.0) {
  // Multiplying M^T by an all-ones vector adds up the rows of M.
  Vector<Real> all_ones(M.NumRows());
  all_ones.Set(1.0);
  AddMatVec(alpha, M, kTrans, all_ones, beta);
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L323-L324
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L738-L757
/// this = beta * this + alpha * (sum of the columns of M).
void AddColSumMat(Real alpha, const MatrixBase<Real>& M, Real beta = 1.0) {
  // Multiplying M by an all-ones vector adds up the columns of M.
  Vector<Real> all_ones(M.NumCols());
  all_ones.Set(1.0);
  AddMatVec(alpha, M, kNoTrans, all_ones, beta);
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L326-L330
/// this = beta * this + alpha * diag(M * M^T)   when `trans == kNoTrans`
/// this = beta * this + alpha * diag(M^T * M)   otherwise
///
/// Two fixes over the previous version:
///  - `alpha` is now applied to the diagonal term; it used to be ignored, so
///    any call with alpha != 1 produced the wrong result.
///  - The result is written into the existing storage instead of rebinding
///    `tensor_`, so view-backed vectors keep aliasing their parent tensor and
///    the cached `data_` pointer stays valid.
///
/// @param alpha Scale applied to the diagonal of the matrix product.
/// @param M     Input matrix.
/// @param trans Selects M * M^T (kNoTrans) or M^T * M (any other value).
/// @param beta  Scale applied to the current contents of this vector.
void AddDiagMat2(
    Real alpha,
    const MatrixBase<Real>& M,
    MatrixTransposeType trans = kNoTrans,
    Real beta = 1.0) {
  const auto& mat = M.tensor_;
  const auto product = (trans == kNoTrans)
      ? torch::mm(mat, mat.transpose(1, 0))
      : torch::mm(mat.transpose(1, 0), mat);
  // The right-hand side is materialised before copy_() writes in place.
  tensor_.copy_(beta * tensor_ + alpha * torch::diag(product));
}
protected:
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L362-L365
explicit VectorBase();
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L378-L379
Real* data_;
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L382
KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase);
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L385-L390
/// Owning vector backed by a torch::Tensor, mirroring the interface of
/// kaldi::Vector.
template <typename Real>
class Vector : public VectorBase<Real> {
 public:
  ////////////////////////////////////////////////////////////////////////////////
  // PyTorch-compatibility things
  ////////////////////////////////////////////////////////////////////////////////
  /// Construct VectorBase which is an interface to an existing torch::Tensor
  /// object.
  Vector(torch::Tensor tensor) : VectorBase<Real>(tensor){};
  ////////////////////////////////////////////////////////////////////////////////
  // Kaldi-compatible methods
  ////////////////////////////////////////////////////////////////////////////////
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L392-L393
  /// Constructs a vector through the default VectorBase constructor.
  Vector() : VectorBase<Real>(){};
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L395-L399
  /// Constructs a vector of `s` elements, initialised per `resize_type`.
  explicit Vector(const MatrixIndexT s, MatrixResizeType resize_type = kSetZero)
      : VectorBase<Real>() {
    Resize(s, resize_type);
  }
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L406-L410
  // Note: unlike the original implementation, this is "explicit".
  /// Deep-copies `v` (the underlying tensor is cloned).
  explicit Vector(const Vector<Real>& v)
      : VectorBase<Real>(v.tensor_.clone()) {}
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L412-L416
  /// Deep-copies any VectorBase (the underlying tensor is cloned).
  explicit Vector(const VectorBase<Real>& v)
      : VectorBase<Real>(v.tensor_.clone()) {}
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L434-L435
  /// Exchanges the contents of this vector and `*other` without copying data.
  void Swap(Vector<Real>* other) {
    auto swapped_tensor = this->tensor_;
    this->tensor_ = other->tensor_;
    other->tensor_ = swapped_tensor;
    // data_ caches the tensor's buffer address (see Resize), so it has to
    // travel with the tensor; otherwise both vectors would keep pointing at
    // each other's old buffers.
    Real* swapped_data = this->data_;
    this->data_ = other->data_;
    other->data_ = swapped_data;
  }
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L444-L451
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L189-L223
  /// Resizes the vector to `length` elements.
  ///
  /// @param length       New number of elements.
  /// @param resize_type  kSetZero: all elements become zero.
  ///                     kUndefined: contents are left unspecified.
  ///                     kCopyData: the leading min(old, new) elements are
  ///                     preserved and any newly added tail is zero.
  void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero) {
    auto& tensor_ = this->tensor_;
    switch (resize_type) {
      case kSetZero:
        tensor_.resize_({length}).zero_();
        break;
      case kUndefined:
        tensor_.resize_({length});
        break;
      case kCopyData: {
        // Take a real copy: a plain Tensor assignment only copies the handle,
        // so the resize_/zero_ below would also wipe the "saved" data.
        const auto previous = tensor_.clone();
        const auto previous_numel = previous.numel();
        tensor_.resize_({length}).zero_();
        // Slice(start, stop): a single-argument Slice(k) would mean "k:" and
        // select the tail instead of the leading elements we want to keep.
        const auto kept =
            Slice(0, length < previous_numel ? length : previous_numel);
        tensor_.index_put_({kept}, previous.index({kept}));
        break;
      }
    }
    // data_ptr<Real>() causes compiler error
    this->data_ = static_cast<Real*>(tensor_.data_ptr());
  }
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L463-L468
  /// Resizes this vector to match `other` and copies its elements.
  Vector<Real>& operator=(const VectorBase<Real>& other) {
    Resize(other.Dim(), kUndefined);
    this->CopyFromVec(other);
    return *this;
  }
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L482-L485
/// Vector whose tensor is obtained by indexing into another vector or a
/// matrix row, mirroring the interface of kaldi::SubVector.
template <typename Real>
class SubVector : public VectorBase<Real> {
  using Base = VectorBase<Real>;

 public:
  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L487-L499
  /// Covers the `length` elements of `t` starting at position `origin`.
  SubVector(
      const VectorBase<Real>& t,
      const MatrixIndexT origin,
      const MatrixIndexT length)
      : Base(t.tensor_.index({Slice(origin, origin + length)})) {}

  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L524-L528
  /// Covers row `row` of `matrix`.
  SubVector(const MatrixBase<Real>& matrix, MatrixIndexT row)
      : Base(matrix.tensor_.index({row})) {}
};
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L540-L543
/// Streams the vector by delegating to the underlying tensor's printer.
template <typename Real>
std::ostream& operator<<(std::ostream& out, const VectorBase<Real>& v) {
  return out << v.tensor_;
}
// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L573-L575
/// Returns the dot product of `v1` and `v2`.
template <typename Real>
Real VecVec(const VectorBase<Real>& v1, const VectorBase<Real>& v2) {
  const auto inner_product = torch::dot(v1.tensor_, v2.tensor_);
  return inner_product.item().template to<Real>();
}
} // namespace kaldi
#endif
Subproject commit 3eea37dd09b55064e6362216f7e9a60641f29f09
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment