"vscode:/vscode.git/clone" did not exist on "761546315cd08e1a4948eb398dfc38dcec0dc432"
kaldi-vector.cc 1.16 KB
Newer Older
moto's avatar
moto committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include "matrix/kaldi-vector.h"
#include <torch/torch.h>
#include "matrix/kaldi-matrix.h"

namespace {

// Validates that `tensor_` can back a VectorBase: it must be 1-dimensional,
// have scalar type `expected_dtype`, and reside on the CPU.
// Throws c10::Error (via TORCH_CHECK) naming the first violated requirement.
// These are checks on caller-supplied input, so TORCH_CHECK (with a
// descriptive message) is used rather than TORCH_INTERNAL_ASSERT, which is
// intended for internal invariants.
void check_vector_tensor(
    const torch::Tensor& tensor_,
    c10::ScalarType expected_dtype) {
  TORCH_CHECK(
      tensor_.ndimension() == 1,
      "Input tensor has to be 1-dimensional, but got ",
      tensor_.ndimension(),
      " dimensions.");
  TORCH_CHECK(
      tensor_.scalar_type() == expected_dtype,
      "Input tensor has to be of dtype ",
      expected_dtype,
      ", but got ",
      tensor_.scalar_type(),
      ".");
  TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
}

// Checks that `tensor_` is a valid backing tensor for VectorBase<Real>.
// Only the float and double specializations below are defined.
template <typename Real>
void assert_vector_shape(const torch::Tensor& tensor_);

// VectorBase<float> requires a 1-D float32 CPU tensor.
template <>
void assert_vector_shape<float>(const torch::Tensor& tensor_) {
  check_vector_tensor(tensor_, torch::kFloat32);
}

// VectorBase<double> requires a 1-D float64 CPU tensor.
template <>
void assert_vector_shape<double>(const torch::Tensor& tensor_) {
  check_vector_tensor(tensor_, torch::kFloat64);
}

} // namespace

namespace kaldi {

template <typename Real>
VectorBase<Real>::VectorBase(torch::Tensor tensor)
    : tensor_(tensor), data_(tensor.data_ptr<Real>()) {
  assert_vector_shape<Real>(tensor_);
};

template <typename Real>
VectorBase<Real>::VectorBase() : VectorBase<Real>(torch::empty({0})) {}

template class Vector<float>;
template class Vector<double>;
template class VectorBase<float>;
template class VectorBase<double>;

} // namespace kaldi