#include "matrix/kaldi-matrix.h"

// NOTE(review): the header name of this include was missing from the text
// under review; <torch/torch.h> is assumed because the code below uses
// torch::Tensor, torch::kFloat32/kFloat64 and the TORCH_* macros. Confirm.
#include <torch/torch.h>

namespace {

// Validates that `tensor_` is usable as the storage of a matrix whose element
// type is `Real`. Only the explicit specializations below are defined, so any
// other element type fails at link time rather than silently passing.
//
// NOTE(review): the template parameter/argument lists were missing from the
// text under review. The parameter name `Real` is an assumption; the
// float/double arguments are inferred from the kFloat32/kFloat64 dtype checks.
template <typename Real>
void assert_matrix_shape(const torch::Tensor& tensor_);

// Single-precision case: requires a 2-dimensional float32 tensor on the CPU.
template <>
void assert_matrix_shape<float>(const torch::Tensor& tensor_) {
  TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 2);
  TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat32);
  TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
}

// Double-precision case: requires a 2-dimensional float64 tensor on the CPU.
template <>
void assert_matrix_shape<double>(const torch::Tensor& tensor_) {
  TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 2);
  TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat64);
  TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
}

} // namespace

namespace kaldi {

// Stores `tensor` as the backing tensor of the matrix, then validates it with
// the assert_matrix_shape specialization for this element type.
template <typename Real>
MatrixBase<Real>::MatrixBase(torch::Tensor tensor) : tensor_(tensor) {
  assert_matrix_shape<Real>(tensor_);
}

// Explicit instantiations: the constructor above is defined in this
// translation unit, so the supported element types are instantiated here.
// NOTE(review): float-then-double order per class is assumed; confirm.
template class Matrix<float>;
template class Matrix<double>;
template class MatrixBase<float>;
template class MatrixBase<double>;
template class SubMatrix<float>;
template class SubMatrix<double>;

} // namespace kaldi