#include "matrix/kaldi-vector.h"
#include <torch/torch.h>
#include "matrix/kaldi-matrix.h"

namespace {

// Validates that `tensor` can back a kaldi::VectorBase whose elements have
// scalar type `expected`. The tensor must be 1-D, have exactly that dtype,
// and live on CPU. Throws c10::Error (via TORCH_CHECK) otherwise.
//
// TORCH_CHECK is used instead of TORCH_INTERNAL_ASSERT because these are
// conditions on caller-supplied input, not internal invariants. A failed
// internal assert presents itself as a PyTorch bug.
void check_vector_tensor(const torch::Tensor& tensor, torch::ScalarType expected) {
  TORCH_CHECK(
      tensor.ndimension() == 1,
      "Input tensor has to be 1-dimensional, but got ",
      tensor.ndimension(),
      " dimensions.");
  TORCH_CHECK(
      tensor.scalar_type() == expected,
      "Input tensor has to be of dtype ",
      expected,
      ", but got ",
      tensor.scalar_type(),
      ".");
  TORCH_CHECK(tensor.device().is_cpu(), "Input tensor has to be on CPU.");
  // NOTE(review): contiguity is not checked. If VectorBase indexes data_
  // with unit stride, a strided 1-D view would be read incorrectly.
  // Confirm against kaldi-vector.h before relying on non-contiguous input.
}

// Checks that `tensor` is a valid backing store for VectorBase<Real> and
// returns it unchanged. Returning the tensor lets the check run inside a
// member-initializer list, before the raw data pointer is taken.
// Only float and double are supported; the primary template is
// intentionally left undefined so other element types fail to link.
template <typename Real>
const torch::Tensor& assert_vector_shape(const torch::Tensor& tensor);

template <>
const torch::Tensor& assert_vector_shape<float>(const torch::Tensor& tensor) {
  check_vector_tensor(tensor, torch::kFloat32);
  return tensor;
}

template <>
const torch::Tensor& assert_vector_shape<double>(const torch::Tensor& tensor) {
  check_vector_tensor(tensor, torch::kFloat64);
  return tensor;
}

} // namespace

namespace kaldi {

// Wraps `tensor` without copying its data. tensor_ holds a handle to the
// same tensor (torch::Tensor copies share storage), and data_ is the raw
// element pointer returned by data_ptr<Real>().
template <typename Real>
VectorBase<Real>::VectorBase(torch::Tensor tensor)
    : tensor_(tensor),
      // Validate before calling data_ptr<Real>(). Otherwise a wrong-dtype
      // tensor is rejected by data_ptr's own error instead of the check above.
      data_(assert_vector_shape<Real>(tensor).template data_ptr<Real>()) {}

// Creates an empty (length-0) vector. The dtype must be given explicitly:
// without it, torch::empty uses the global default dtype (float32 unless
// changed), and the double instantiation failed its own dtype check.
template <typename Real>
VectorBase<Real>::VectorBase()
    : VectorBase(torch::empty({0}, torch::dtype<Real>())) {}

template class Vector<float>;
template class Vector<double>;
template class VectorBase<float>;
template class VectorBase<double>;

} // namespace kaldi