"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "79940695def96c005402301bbf1b22460a44ec33"
Commit c596cebb authored by Davis King's avatar Davis King
Browse files

Improved the BLAS binding system. It should now break expressions up into

their proper basic BLAS function calls.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402764
parent 6b5f9a8a
...@@ -24,6 +24,39 @@ namespace dlib ...@@ -24,6 +24,39 @@ namespace dlib
namespace blas_bindings namespace blas_bindings
{ {
// ------------------------------------------------------------------------------------
// This template struct is used to tell us if a matrix expression contains a matrix multiply.
template <typename T>
struct has_matrix_multiply
{
const static bool value = false;
};
template <typename T, typename U>
struct has_matrix_multiply<matrix_multiply_exp<T,U> >
{ const static bool value = true; };
template <typename T, typename U>
struct has_matrix_multiply<matrix_add_exp<T,U> >
{ const static bool value = has_matrix_multiply<T>::value || has_matrix_multiply<U>::value; };
template <typename T, typename U>
struct has_matrix_multiply<matrix_subtract_exp<T,U> >
{ const static bool value = has_matrix_multiply<T>::value || has_matrix_multiply<U>::value; };
template <typename T, bool Tb>
struct has_matrix_multiply<matrix_mul_scal_exp<T,Tb> >
{ const static bool value = has_matrix_multiply<T>::value; };
template <typename T>
struct has_matrix_multiply<matrix_div_scal_exp<T> >
{ const static bool value = has_matrix_multiply<T>::value; };
template <typename T, typename OP>
struct has_matrix_multiply<matrix_unary_exp<T,OP> >
{ const static bool value = has_matrix_multiply<T>::value; };
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template <typename T, typename U> template <typename T, typename U>
...@@ -110,10 +143,12 @@ namespace dlib ...@@ -110,10 +143,12 @@ namespace dlib
template <typename EXP> template <typename EXP>
static void assign ( static void assign (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const EXP& src const EXP& src,
typename src_exp::type alpha,
bool add_to
) )
{ {
matrix_assign_default(dest,src); matrix_assign_default(dest,src,alpha,add_to);
} }
// If we know this is a matrix multiply then apply the // If we know this is a matrix multiply then apply the
...@@ -122,99 +157,356 @@ namespace dlib ...@@ -122,99 +157,356 @@ namespace dlib
template <typename EXP1, typename EXP2> template <typename EXP1, typename EXP2>
static void assign ( static void assign (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const matrix_multiply_exp<EXP1,EXP2>& src const matrix_multiply_exp<EXP1,EXP2>& src,
typename src_exp::type alpha,
bool add_to
) )
{ {
set_all_elements(dest,0); // At some point I need to improve the default (i.e. non BLAS) matrix
default_matrix_multiply(dest, src.lhs, src.rhs); // multiplication algorithm...
}
template <typename EXP1, typename EXP2> if (alpha == 1)
static void assign (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_multiply_exp<EXP1,EXP2> >& src
)
{
if (&dest == &src.lhs)
{ {
default_matrix_multiply(dest, src.rhs.lhs, src.rhs.rhs); if (add_to)
{
default_matrix_multiply(dest, src.lhs, src.rhs);
}
else
{
set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs);
}
} }
else else
{ {
dest = src.lhs; if (add_to)
default_matrix_multiply(dest, src.rhs.lhs, src.rhs.rhs); {
matrix<T,NR,NC,MM,L> temp(dest);
default_matrix_multiply(temp, src.lhs, src.rhs);
dest = alpha*temp;
}
else
{
set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs);
dest = alpha*dest;
}
} }
} }
};
template <typename EXP1, typename EXP2> // This is a macro to help us add overloads for the matrix_assign_blas_helper template.
static void assign ( // Using this macro it is easy to add overloads for arbitrary matrix expressions.
matrix<T,NR,NC,MM,L>& dest, #define DLIB_ADD_BLAS_BINDING( dest_layout, src_expression) \
const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_add_exp<EXP1,EXP2> >& src template <typename T> struct BOOST_JOIN(blas,__LINE__) \
) { const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \
template < typename T, long NR, long NC, typename MM, typename src_exp > \
struct matrix_assign_blas_helper<T,NR,NC,MM,dest_layout, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \
static void assign ( \
matrix<T,NR,NC,MM,dest_layout>& dest, \
const src_exp& src, \
typename src_exp::type alpha, \
bool add_to \
) {
#define DLIB_END_BLAS_BINDING }};
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------- Forward Declarations -------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, bool Sb
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
);
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
);
/*!
This function catches the expressions of the form:
M = M + exp;
and converts them into the appropriate matrix_assign_blas() call.
This is an important case to catch because it is the expression used
to represent the += matrix operator.
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
);
/*!
This function catches the expressions of the form:
M = M - exp;
and converts them into the appropriate matrix_assign_blas() call.
This is an important case to catch because it is the expression used
to represent the -= matrix operator.
!*/
// End of forward declarations for overloaded matrix_assign_blas functions
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src,
typename src_exp::type alpha,
bool add_to
)
{
matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp>::assign(dest,src,alpha,add_to);
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
)
{
if (src_exp::cost > 9 || src_exp2::cost > 9)
{ {
if (EXP1::cost > 50 || EXP2::cost > 5) matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
{ matrix_assign_blas_proxy(dest, src.rhs, alpha, true);
matrix_assign(dest, src.lhs + src.rhs.lhs);
matrix_assign(dest, src.lhs + src.rhs.rhs);
}
else
{
matrix_assign_default(dest,src);
}
} }
else
{
matrix_assign_default(dest, src, alpha, add_to);
}
}
// ------------------------------------------------------------------------------------
template <typename EXP2> template <
static void assign ( typename T, long NR, long NC, typename MM, typename L,
matrix<T,NR,NC,MM,L>& dest, typename src_exp, bool Sb
const matrix_add_exp<matrix<T,NR,NC,MM,L>,EXP2>& src >
) void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
)
{
matrix_assign_blas_proxy(dest, src.m, alpha*src.s, add_to);
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
)
{
if (src_exp::cost > 9 || src_exp2::cost > 9)
{
matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
matrix_assign_blas_proxy(dest, src.rhs, -alpha, true);
}
else
{
matrix_assign_default(dest, src, alpha, add_to);
}
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// Once we get into this function it means that we are dealing with a matrix of float,
// double, complex<float>, or complex<double> and the src_exp contains at least one
// matrix multiply.
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
if (src.aliases(dest))
{
matrix<T,NR,NC,MM,L> temp;
matrix_assign_blas_proxy(temp,src,1,false);
temp.swap(dest);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
}
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
)
{
if (src_exp::cost > 5)
{ {
if (EXP2::cost > 50 && &dest != &src.lhs) if (src.rhs.aliases(dest) == false)
{ {
dest = src.lhs; if (&src.lhs != &dest)
matrix_assign(dest, dest + src.rhs); {
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, 1, true);
} }
else else
{ {
matrix_assign_default(dest,src); matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, 1, true);
temp.swap(dest);
} }
} }
else
{
matrix_assign_default(dest,src);
}
}
// ------------------------------------------------------------------------------------
template <
template <typename EXP1, typename EXP2> typename T, long NR, long NC, typename MM, typename L,
static void assign ( typename src_exp
matrix<T,NR,NC,MM,L>& dest, >
const matrix_add_exp<EXP1,EXP2>& src void matrix_assign_blas (
) matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
)
{
if (src_exp::cost > 5)
{ {
if (EXP1::cost > 50 || EXP2::cost > 50) if (src.rhs.aliases(dest) == false)
{ {
matrix_assign(dest,src.lhs); if (&src.lhs != &dest)
matrix_assign(dest, dest + src.rhs); {
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, -1, true);
} }
else else
{ {
matrix_assign_default(dest,src); matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, -1, true);
temp.swap(dest);
} }
} }
}; else
{
// This is a macro to help us add overloads for the matrix_assign_blas_helper template. matrix_assign_default(dest,src);
// Using this macro it is easy to add overloads for arbitrary matrix expressions. }
#define DLIB_ADD_BLAS_BINDING( dest_layout, src_expression) \ }
template <typename T> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \
template < typename T, long NR, long NC, typename MM, typename src_exp > \
struct matrix_assign_blas_helper<T,NR,NC,MM,dest_layout, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \
static void assign ( \
matrix<T,NR,NC,MM,dest_layout>& dest, \
const src_exp& src \
) {
#define DLIB_END_BLAS_BINDING }};
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
} // end of namespace blas_bindings } // end of namespace blas_bindings
...@@ -227,13 +519,14 @@ namespace dlib ...@@ -227,13 +519,14 @@ namespace dlib
inline typename enable_if_c<(is_same_type<T,float>::value || inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value || is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value || is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value) is_same_type<T,std::complex<double> >::value) &&
blas_bindings::has_matrix_multiply<src_exp>::value
>::type matrix_assign_big ( >::type matrix_assign_big (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const src_exp& src const src_exp& src
) )
{ {
blas_bindings::matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp>::assign(dest,src); blas_bindings::matrix_assign_blas(dest,src);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -44,6 +44,12 @@ namespace dlib ...@@ -44,6 +44,12 @@ namespace dlib
EXP1& dest, EXP1& dest,
const EXP2& src const EXP2& src
) )
/*!
requires
- src.destructively_aliases(dest) == false
ensures
- #dest == src
!*/
{ {
for (long r = 0; r < src.nr(); ++r) for (long r = 0; r < src.nr(); ++r)
{ {
...@@ -54,6 +60,83 @@ namespace dlib ...@@ -54,6 +60,83 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
template <typename EXP1, typename EXP2>
inline static void matrix_assign_default (
EXP1& dest,
const EXP2& src,
typename EXP2::type alpha,
bool add_to
)
/*!
requires
- src.destructively_aliases(dest) == false
ensures
- if (add_to == false) then
- #dest == alpha*src
- else
- #dest == dest + alpha*src
!*/
{
if (add_to)
{
if (alpha == 1)
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) += src(r,c);
}
}
}
else if (alpha == -1)
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) -= src(r,c);
}
}
}
else
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) += alpha*src(r,c);
}
}
}
}
else
{
if (alpha == 1)
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) = src(r,c);
}
}
}
else
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) = alpha*src(r,c);
}
}
}
}
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -83,7 +166,6 @@ namespace dlib ...@@ -83,7 +166,6 @@ namespace dlib
- src.destructively_aliases(dest) == false - src.destructively_aliases(dest) == false
ensures ensures
- #dest == src - #dest == src
- the part of dest outside the above sub matrix remains unchanged
!*/ !*/
{ {
// Call src.ref() here so that the derived type of the matrix_exp shows // Call src.ref() here so that the derived type of the matrix_exp shows
...@@ -107,7 +189,6 @@ namespace dlib ...@@ -107,7 +189,6 @@ namespace dlib
- src.destructively_aliases(dest) == false - src.destructively_aliases(dest) == false
ensures ensures
- #dest == src - #dest == src
- the part of dest outside the above sub matrix remains unchanged
!*/ !*/
{ {
matrix_assign_default(dest,src.ref()); matrix_assign_default(dest,src.ref());
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include "cblas.h" #include "cblas.h"
#endif #endif
#include <iostream>
namespace dlib namespace dlib
{ {
...@@ -219,45 +221,18 @@ namespace dlib ...@@ -219,45 +221,18 @@ namespace dlib
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T alpha = 1;
const T* A = &src.lhs(0,0); const T* A = &src.lhs(0,0);
const int lda = src.lhs.nc(); const int lda = src.lhs.nc();
const T* B = &src.rhs(0,0); const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc(); const int ldb = src.rhs.nc();
const T beta = 0; const T beta = add_to?1:0;
T* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.nc(); const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout,rm + rm*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const T alpha = 1;
const T* A = &src.rhs.lhs(0,0);
const int lda = src.rhs.lhs.nc();
const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = src.rhs.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*rm) DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*rm)
...@@ -268,94 +243,18 @@ namespace dlib ...@@ -268,94 +243,18 @@ namespace dlib
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T alpha = 1;
const T* A = &src.lhs.m(0,0); const T* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc(); const int lda = src.lhs.m.nc();
const T* B = &src.rhs(0,0); const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc(); const int ldb = src.rhs.nc();
const T beta = 0; const T beta = add_to?1:0;
T* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.nc(); const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, rm + s*trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const int K = static_cast<int>(src.rhs.m.lhs.nc());
const T alpha = src.rhs.s;
const T* A = &src.rhs.m.lhs.m(0,0);
const int lda = src.rhs.m.lhs.m.nc();
const T* B = &src.rhs.m.rhs(0,0);
const int ldb = src.rhs.m.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = dest.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, s*trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.m.nr());
const int N = static_cast<int>(src.m.nc());
const int K = static_cast<int>(src.m.lhs.nc());
const T alpha = src.s;
const T* A = &src.m.lhs.m(0,0);
const int lda = src.m.lhs.m.nc();
const T* B = &src.m.rhs(0,0);
const int ldb = src.m.rhs.nc();
const T beta = 0;
T* C = &dest(0,0);
const int ldc = dest.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, rm + trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const T alpha = 1;
const T* A = &src.rhs.lhs.m(0,0);
const int lda = src.rhs.lhs.m.nc();
const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
#endif // DLIB_USE_BLAS #endif // DLIB_USE_BLAS
} }
......
...@@ -23,24 +23,25 @@ namespace dlib ...@@ -23,24 +23,25 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template < /*! This file defines the default_matrix_multiply() function. It is a function
typename matrix_dest_type, that conforms to the following definition:
typename EXP1,
typename EXP2 template <
> typename matrix_dest_type,
typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true && ma::matrix_is_vector<EXP2>::value == true>::type typename EXP1,
default_matrix_multiply ( typename EXP2
matrix_dest_type& dest, >
const EXP1& lhs, void default_matrix_multiply (
const EXP2& rhs matrix_dest_type& dest,
); const EXP1& lhs,
/*! const EXP2& rhs
requires );
- (lhs*rhs).destructively_aliases(dest) == false requires
- dest.nr() == (lhs*rhs).nr() - (lhs*rhs).destructively_aliases(dest) == false
- dest.nc() == (lhs*rhs).nc() - dest.nr() == (lhs*rhs).nr()
ensures - dest.nc() == (lhs*rhs).nc()
- #dest == dest + lhs*rhs ensures
- #dest == dest + lhs*rhs
!*/ !*/
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment