"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "8e37055993c423ba11dd7b470ec959381db796c4"
Commit 8c41c850 authored by Davis King's avatar Davis King
Browse files

Made many of the mat() converters bind the resulting matrix expressions into

the BLAS bindings.
parent b5038f78
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "matrix_assign_fwd.h" #include "matrix_assign_fwd.h"
#include "matrix_default_mul.h" #include "matrix_default_mul.h"
#include "matrix_conj_trans.h" #include "matrix_conj_trans.h"
#include "matrix_mat.h"
namespace dlib namespace dlib
{ {
...@@ -159,6 +160,29 @@ namespace dlib ...@@ -159,6 +160,29 @@ namespace dlib
const static int value = general_matrix; const static int value = general_matrix;
}; };
template < typename T, typename MM >
struct matrix_type_id<matrix_op<op_array2d_to_mat<array2d<T,MM> > > >
{ const static int value = general_matrix; };
template < typename T, typename MM >
struct matrix_type_id<matrix_op<op_array_to_mat<array<T,MM> > > >
{ const static int value = column_matrix; };
template < typename value_type, typename alloc >
struct matrix_type_id<matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > > >
{ const static int value = column_matrix; };
template < typename value_type, typename alloc >
struct matrix_type_id<matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > > >
{ const static int value = column_matrix; };
template < typename T >
struct matrix_type_id<matrix_op<op_pointer_to_col_vect<T> > >
{ const static int value = column_matrix; };
template < typename T >
struct matrix_type_id<matrix_op<op_pointer_to_mat<T> > >
{ const static int value = general_matrix; };
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template <typename T, typename U> template <typename T, typename U>
......
...@@ -408,8 +408,34 @@ namespace dlib ...@@ -408,8 +408,34 @@ namespace dlib
template <typename T, long NR, long NC, typename MM> template <typename T, long NR, long NC, typename MM>
int get_ld (const assignable_sub_matrix<T,NR,NC,MM,column_major_layout>& m) { return m.m.nr(); } int get_ld (const assignable_sub_matrix<T,NR,NC,MM,column_major_layout>& m) { return m.m.nr(); }
template <typename T, typename MM>
int get_ld (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& m) { return m.nc(); }
template <typename T, typename MM>
int get_ld (const matrix_op<op_array_to_mat<array<T,MM> > >& m) { return m.nc(); }
template < typename value_type, typename alloc >
int get_ld (const matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > >& m) { return m.nc(); }
template < typename value_type, typename alloc >
int get_ld (const matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > >& m) { return m.nc(); }
template <typename T>
int get_ld (const matrix_op<op_pointer_to_col_vect<T> >& m) { return m.nc(); }
template <typename T>
int get_ld (const matrix_op<op_pointer_to_mat<T> >& m) { return m.nc(); }
// -------- // --------
template <typename T, typename MM>
int get_inc (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& ) { return 1; }
template <typename T, typename MM>
int get_inc (const matrix_op<op_array_to_mat<array<T,MM> > >& ) { return 1; }
template < typename value_type, typename alloc >
int get_inc (const matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > >& ) { return 1; }
template < typename value_type, typename alloc >
int get_inc (const matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > >& ) { return 1; }
template <typename T>
int get_inc (const matrix_op<op_pointer_to_col_vect<T> >& ) { return 1; }
template <typename T>
int get_inc (const matrix_op<op_pointer_to_mat<T> >& ) { return 1; }
template <typename T, long NR, long NC, typename MM, typename L> template <typename T, long NR, long NC, typename MM, typename L>
int get_inc (const matrix<T,NR,NC,MM,L>& ) { return 1; } int get_inc (const matrix<T,NR,NC,MM,L>& ) { return 1; }
...@@ -522,6 +548,19 @@ namespace dlib ...@@ -522,6 +548,19 @@ namespace dlib
template <typename T, long NR, long NC, typename MM, typename L> template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_sub_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); } T* get_ptr (assignable_sub_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, typename MM>
const T* get_ptr (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& m) { return &m.op.array[0][0]; }
template <typename T, typename MM>
const T* get_ptr (const matrix_op<op_array_to_mat<array<T,MM> > >& m) { return &m.op.vect[0]; }
template < typename T, typename alloc >
const T* get_ptr (const matrix_op<op_std_vect_to_mat<std::vector<T,alloc> > >& m) { return &m.op.vect[0]; }
template < typename T, typename alloc >
const T* get_ptr (const matrix_op<op_std_vect_to_mat<std_vector_c<T,alloc> > >& m) { return &m.op.vect[0]; }
template <typename T>
const T* get_ptr (const matrix_op<op_pointer_to_col_vect<T> >& m) { return m.op.ptr; }
template <typename T>
const T* get_ptr (const matrix_op<op_pointer_to_mat<T> >& m) { return m.op.ptr; }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -42,6 +42,61 @@ namespace ...@@ -42,6 +42,61 @@ namespace
) )
{} {}
void test_mat_bindings()
{
using namespace dlib;
using namespace dlib::blas_bindings;
matrix<double,1,0> rv(10);
matrix<double,0,1> cv(10);
double val;
rv = 1; cv = 1;
counter_dot() = 0;
val = rv*cv;
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
rv = 1; cv = 1;
counter_dot() = 0;
val = rv*mat(&cv(0),cv.size());
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
rv = 1; cv = 1;
counter_dot() = 0;
val = trans(mat(&rv(0),rv.size()))*mat(&cv(0),cv.size());
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
std::vector<double> sv(10,1);
rv = 1;
counter_dot() = 0;
val = trans(mat(&rv(0),rv.size()))*mat(sv);
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = trans(mat(sv))*mat(sv);
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
std_vector_c<double> svc(10,1);
counter_dot() = 0;
val = trans(mat(svc))*mat(svc);
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
dlib::array<double> arr(10);
for (unsigned int i = 0; i < arr.size(); ++i)
arr[i] = 1;
counter_dot() = 0;
val = trans(mat(arr))*mat(arr);
DLIB_TEST(val == 10);
DLIB_TEST(counter_dot() == 1);
}
template <typename matrix_type, typename cv_type, typename rv_type> template <typename matrix_type, typename cv_type, typename rv_type>
void test_dot_stuff( void test_dot_stuff(
matrix_type& m, matrix_type& m,
...@@ -238,6 +293,8 @@ namespace ...@@ -238,6 +293,8 @@ namespace
} }
test_mat_bindings();
print_spinner(); print_spinner();
} }
}; };
......
...@@ -258,6 +258,21 @@ namespace ...@@ -258,6 +258,21 @@ namespace
test_gemm_stuff_conj(c); test_gemm_stuff_conj(c);
} }
{
using namespace dlib;
using namespace dlib::blas_bindings;
array2d<double> a(100,100);
array2d<double> b(100,100);
matrix<double> c;
counter_gemm() = 0;
c = mat(a)*mat(b);
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
c = trans(2*mat(a)*mat(b));
DLIB_TEST(counter_gemm() == 1);
}
print_spinner(); print_spinner();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment