"...text-generation-inference.git" did not exist on "08e91814180c5a737749f9deadfc45fd0968037a"
Commit eb823114 authored by Davis King's avatar Davis King
Browse files

Checking in changes to the matrix object that allow it to

factor expressions containing trans() operators.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402833
parent 0803ff8f
......@@ -217,10 +217,14 @@ namespace dlib
dest_exp& dest,
const EXP& src,
typename src_exp::type alpha,
bool add_to
bool add_to,
bool transpose
)
{
matrix_assign_default(dest,src,alpha,add_to);
if (transpose == false)
matrix_assign_default(dest,src,alpha,add_to);
else
matrix_assign_default(dest,trans(src),alpha,add_to);
}
// If we know this is a matrix multiply then apply the
......@@ -231,7 +235,8 @@ namespace dlib
dest_exp& dest,
const matrix_multiply_exp<EXP1,EXP2>& src,
typename src_exp::type alpha,
bool add_to
bool add_to,
bool transpose
)
{
// At some point I need to improve the default (i.e. non BLAS) matrix
......@@ -239,15 +244,15 @@ namespace dlib
if (alpha == static_cast<typename src_exp::type>(1))
{
if (add_to)
{
default_matrix_multiply(dest, src.lhs, src.rhs);
}
else
if (add_to == false)
{
zero_matrix(dest);
default_matrix_multiply(dest, src.lhs, src.rhs);
}
if (transpose == false)
default_matrix_multiply(dest, src.lhs, src.rhs);
else
default_matrix_multiply(dest, trans(src.rhs), trans(src.lhs));
}
else
{
......@@ -255,13 +260,26 @@ namespace dlib
{
typename dest_exp::matrix_type temp(dest.nr(),dest.nc());
zero_matrix(temp);
default_matrix_multiply(temp, src.lhs, src.rhs);
if (transpose == false)
default_matrix_multiply(temp, src.lhs, src.rhs);
else
{
default_matrix_multiply(temp, trans(src.rhs), trans(src.lhs));
cout << "\ndo default mul" << endl;
}
matrix_assign_default(dest,temp, alpha,true);
}
else
{
zero_matrix(dest);
default_matrix_multiply(dest, src.lhs, src.rhs);
if (transpose == false)
default_matrix_multiply(dest, src.lhs, src.rhs);
else
default_matrix_multiply(dest, trans(src.rhs), trans(src.lhs));
matrix_assign_default(dest,dest, alpha, false);
}
}
......@@ -281,7 +299,8 @@ namespace dlib
dest_exp& dest, \
const src_exp& src, \
typename src_exp::type alpha, \
bool add_to \
bool add_to, \
bool transpose \
) { \
typedef typename dest_exp::type T;
......@@ -301,7 +320,8 @@ namespace dlib
dest_exp& dest,
const src_exp& src,
typename src_exp::type alpha,
bool add_to
bool add_to,
bool transpose
);
/*!
requires
......@@ -318,7 +338,8 @@ namespace dlib
dest_exp& dest,
const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
bool add_to,
bool transpose
);
/*!
requires
......@@ -335,7 +356,26 @@ namespace dlib
dest_exp& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
bool add_to,
bool transpose
);
/*!
requires
- src.aliases(dest) == false
- dest.nr() == src.nr()
- dest.nc() == src.nc()
!*/
template <
typename dest_exp,
typename src_exp
>
void matrix_assign_blas_proxy (
dest_exp& dest,
const matrix_unary_exp<src_exp,op_trans>& src,
typename src_exp::type alpha,
bool add_to,
bool transpose
);
/*!
requires
......@@ -352,7 +392,8 @@ namespace dlib
dest_exp& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
bool add_to,
bool transpose
);
/*!
requires
......@@ -435,10 +476,11 @@ namespace dlib
dest_exp& dest,
const src_exp& src,
typename src_exp::type alpha,
bool add_to
bool add_to,
bool transpose
)
{
matrix_assign_blas_helper<dest_exp,src_exp>::assign(dest,src,alpha,add_to);
matrix_assign_blas_helper<dest_exp,src_exp>::assign(dest,src,alpha,add_to, transpose);
}
// ------------------------------------------------------------------------------------
......@@ -451,17 +493,22 @@ namespace dlib
dest_exp& dest,
const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
bool add_to,
bool transpose
)
{
if (has_matrix_multiply<src_exp>::value || has_matrix_multiply<src_exp2>::value)
{
matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
matrix_assign_blas_proxy(dest, src.rhs, alpha, true);
cout << "\n1 trans: " << transpose << " \n";
matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to, transpose);
matrix_assign_blas_proxy(dest, src.rhs, alpha, true, transpose);
}
else
{
matrix_assign_default(dest, src, alpha, add_to);
if (transpose == false)
matrix_assign_default(dest, src, alpha, add_to);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
}
}
......@@ -475,10 +522,28 @@ namespace dlib
dest_exp& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
bool add_to,
bool transpose
)
{
matrix_assign_blas_proxy(dest, src.m, alpha*src.s, add_to);
matrix_assign_blas_proxy(dest, src.m, alpha*src.s, add_to, transpose);
}
// ------------------------------------------------------------------------------------
template <
typename dest_exp,
typename src_exp
>
void matrix_assign_blas_proxy (
dest_exp& dest,
const matrix_unary_exp<src_exp,op_trans>& src,
typename src_exp::type alpha,
bool add_to,
bool transpose
)
{
matrix_assign_blas_proxy(dest, src.m, alpha, add_to, !transpose);
}
// ------------------------------------------------------------------------------------
......@@ -491,18 +556,22 @@ namespace dlib
dest_exp& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
bool add_to,
bool transpose
)
{
if (has_matrix_multiply<src_exp>::value || has_matrix_multiply<src_exp2>::value)
{
matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
matrix_assign_blas_proxy(dest, src.rhs, -alpha, true);
matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to, transpose);
matrix_assign_blas_proxy(dest, src.rhs, -alpha, true, transpose);
}
else
{
matrix_assign_default(dest, src, alpha, add_to);
if (transpose == false)
matrix_assign_default(dest, src, alpha, add_to);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
}
}
......@@ -525,12 +594,12 @@ namespace dlib
if (src.aliases(dest))
{
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
matrix_assign_blas_proxy(temp,src,1,false);
matrix_assign_blas_proxy(temp,src,1,false, false);
temp.swap(dest);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
matrix_assign_blas_proxy(dest,src,1,false, false);
}
}
......@@ -548,12 +617,12 @@ namespace dlib
if (src.aliases(dest.m))
{
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
matrix_assign_blas_proxy(temp,src,1,false);
matrix_assign_blas_proxy(temp,src,1,false, false);
matrix_assign_default(dest,temp);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
matrix_assign_blas_proxy(dest,src,1,false, false);
}
}
......@@ -571,12 +640,12 @@ namespace dlib
if (src.aliases(dest.m))
{
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
matrix_assign_blas_proxy(temp,src,1,false);
matrix_assign_blas_proxy(temp,src,1,false, false);
matrix_assign_default(dest,temp);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
matrix_assign_blas_proxy(dest,src,1,false, false);
}
}
......@@ -594,12 +663,12 @@ namespace dlib
if (src.aliases(dest.m))
{
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
matrix_assign_blas_proxy(temp,src,1,false);
matrix_assign_blas_proxy(temp,src,1,false, false);
matrix_assign_default(dest,temp);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
matrix_assign_blas_proxy(dest,src,1,false, false);
}
}
......@@ -621,12 +690,12 @@ namespace dlib
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, 1, true);
matrix_assign_blas_proxy(dest, src.rhs, 1, true, false);
}
else
{
matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, 1, true);
matrix_assign_blas_proxy(temp, src.rhs, 1, true, false);
temp.swap(dest);
}
}
......@@ -667,12 +736,12 @@ namespace dlib
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, -1, true);
matrix_assign_blas_proxy(dest, src.rhs, -1, true, false);
}
else
{
matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, -1, true);
matrix_assign_blas_proxy(temp, src.rhs, -1, true, false);
temp.swap(dest);
}
}
......
......@@ -326,8 +326,6 @@ namespace dlib
//cout << "BLAS GEMM: m*m" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
......@@ -340,7 +338,11 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, CblasNoTrans, CblasNoTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, CblasTrans, CblasTrans, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -364,7 +366,11 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -388,18 +394,20 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(m)*trans(m))
{
//cout << "BLAS GEMM: trans(m)*trans(m)" << endl;
cout << "BLAS GEMM: trans(m)*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
......@@ -412,7 +420,10 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, CblasTrans, CblasTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, CblasNoTrans, CblasNoTrans, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -438,7 +449,11 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -462,7 +477,10 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -486,7 +504,10 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
} DLIB_END_BLAS_BINDING
// ----------------------------------------------------------------------------------------
......@@ -795,7 +816,11 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -818,7 +843,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -841,7 +869,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -864,7 +895,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// ----------------------------------------------------------------------------------------
......@@ -891,7 +925,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -914,7 +951,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -937,7 +977,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -960,7 +1003,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// ----------------------------------------------------------------------------------------
......
......@@ -26,7 +26,7 @@ namespace
logger dlog("test.matrix3");
const double eps_mul = 200000;
const double eps_mul = 500000;
template <typename T, typename U>
void check_equal (
......@@ -101,7 +101,7 @@ namespace
dlib::rand::float_1a rnd;
matrix<type> a(rows,cols), temp, temp2;
matrix<type> a(rows,cols), temp, temp2, temp3;
for (int i = 0; i < 6; ++i)
{
......@@ -173,51 +173,99 @@ namespace
// GEMM tests
dlog << LTRACE << "1.1";
check_equal(tmp(at*a), at*a);
check_equal(tmp(trans(at*a)), trans(at*a));
check_equal(tmp(2.4*trans(4*trans(at*a) + at*3*a)), 2.4*trans(4*trans(at*a) + at*3*a));
dlog << LTRACE << "1.2";
check_equal(tmp(trans(a)*a), trans(a)*a);
check_equal(tmp(trans(trans(a)*a)), trans(trans(a)*a));
dlog << LTRACE << "1.3";
check_equal(tmp(at*trans(at)), at*trans(at));
check_equal(tmp(trans(at*trans(at))), trans(at*trans(at)));
dlog << LTRACE << "1.4";
check_equal(tmp(trans(at)*trans(a)), a*at);
check_equal(tmp(trans(trans(at)*trans(a))), trans(a*at));
dlog << LTRACE << "1.5";
print_spinner();
c_check_equal(tmp(conj(trans(c_a))*c_a), trans(conj(c_a))*c_a);
c_check_equal(tmp(trans(conj(trans(c_a))*c_a)), trans(trans(conj(c_a))*c_a));
dlog << LTRACE << "1.6";
c_check_equal(tmp(c_at*trans(conj(c_at))), c_at*conj(trans(c_at)));
c_check_equal(tmp(trans(c_at*trans(conj(c_at)))), trans(c_at*conj(trans(c_at))));
dlog << LTRACE << "1.7";
c_check_equal(tmp(conj(trans(c_at))*trans(conj(c_a))), conj(trans(c_at))*trans(conj(c_a)));
c_check_equal(tmp(trans(conj(trans(c_at))*trans(conj(c_a)))), trans(conj(trans(c_at))*trans(conj(c_a))));
dlog << LTRACE << "1.8";
check_equal(tmp(a*trans(rowm(a,1))) , a*trans(rowm(a,1)));
check_equal(tmp(a*colm(at,1)) , a*colm(at,1));
check_equal(tmp(subm(a,1,1,2,2)*subm(a,1,2,2,2)), subm(a,1,1,2,2)*subm(a,1,2,2,2));
temp = at*a;
temp2 = temp;
dlog << LTRACE << "1.9";
check_equal(tmp(trans(a*trans(rowm(a,1)))) , trans(a*trans(rowm(a,1))));
dlog << LTRACE << "1.10";
check_equal(tmp(trans(a*colm(at,1))) , trans(a*colm(at,1)));
dlog << LTRACE << "1.11";
check_equal(tmp(trans(subm(a,1,1,2,2)*subm(a,1,2,2,2))), trans(subm(a,1,1,2,2)*subm(a,1,2,2,2)));
dlog << LTRACE << "1.12";
temp += 3.5*at*a;
assign_no_blas(temp2, temp2 + 3.5*at*a);
check_equal(temp, temp2);
{
temp = at*a;
temp2 = temp;
temp -= at*3.5*a;
assign_no_blas(temp2, temp2 - at*3.5*a);
check_equal(temp, temp2);
temp += 3.5*at*a;
assign_no_blas(temp2, temp2 + 3.5*at*a);
check_equal(temp, temp2);
temp = temp + 4*at*a;
assign_no_blas(temp2, temp2 + 4*at*a);
check_equal(temp, temp2);
temp -= at*3.5*a;
assign_no_blas(temp2, temp2 - at*3.5*a);
check_equal(temp, temp2);
temp = temp - 2.4*at*a;
assign_no_blas(temp2, temp2 - 2.4*at*a);
check_equal(temp, temp2);
temp = temp + 4*at*a;
assign_no_blas(temp2, temp2 + 4*at*a);
check_equal(temp, temp2);
temp = temp - 2.4*at*a;
assign_no_blas(temp2, temp2 - 2.4*at*a);
check_equal(temp, temp2);
}
dlog << LTRACE << "1.13";
{
temp = trans(at*a);
temp2 = temp;
temp3 = temp;
dlog << LTRACE << "1.14";
temp += trans(3.5*at*a);
assign_no_blas(temp2, temp2 + trans(3.5*at*a));
check_equal(temp, temp2);
dlog << LTRACE << "1.15";
temp -= trans(at*3.5*a);
assign_no_blas(temp2, temp2 - trans(at*3.5*a));
check_equal(temp, temp2);
dlog << LTRACE << "1.16";
temp = trans(temp + 4*at*a);
assign_no_blas(temp3, trans(temp2 + 4*at*a));
check_equal(temp, temp3);
temp2 = temp;
dlog << LTRACE << "1.17";
temp = trans(temp - 2.4*at*a);
assign_no_blas(temp3, trans(temp2 - 2.4*at*a));
check_equal(temp, temp3);
}
dlog << LTRACE << "1.18";
// GEMV tests
check_equal(tmp(a*cv4), a*cv4);
check_equal(tmp(trans(a*cv4)), trans(a*cv4));
check_equal(tmp(rv3*a), rv3*a);
check_equal(tmp(trans(cv4)*at), trans(cv4)*at);
check_equal(tmp(a*trans(rv4)), a*trans(rv4));
check_equal(tmp(trans(a*trans(rv4))), trans(a*trans(rv4)));
check_equal(tmp(trans(a)*cv3), trans(a)*cv3);
check_equal(tmp(rv4*trans(a)), rv4*trans(a));
......@@ -291,21 +339,76 @@ namespace
// Test BLAS GER
temp.set_size(cols,cols);
set_all_elements(temp,3);
temp2 = temp;
{
temp.set_size(cols,cols);
set_all_elements(temp,3);
temp2 = temp;
dlog << LTRACE << "8";
temp += cv4*rv4;
assign_no_blas(temp2, temp2 + cv4*rv4);
check_equal(temp, temp2);
dlog << LTRACE << "8";
temp += cv4*rv4;
assign_no_blas(temp2, temp2 + cv4*rv4);
check_equal(temp, temp2);
dlog << LTRACE << "8.3";
temp = temp + cv4*rv4;
assign_no_blas(temp2, temp2 + cv4*rv4);
check_equal(temp, temp2);
dlog << LTRACE << "8.9";
dlog << LTRACE << "8.3";
temp = temp + cv4*rv4;
assign_no_blas(temp2, temp2 + cv4*rv4);
check_equal(temp, temp2);
dlog << LTRACE << "8.9";
}
{
temp.set_size(cols,cols);
set_all_elements(temp,3);
temp2 = temp;
temp3 = 0;
dlog << LTRACE << "8.10";
temp += trans(cv4*rv4);
assign_no_blas(temp3, temp2 + trans(cv4*rv4));
check_equal(temp, temp3);
temp3 = 0;
dlog << LTRACE << "8.11";
temp2 = temp;
temp = trans(temp + cv4*rv4);
assign_no_blas(temp3, trans(temp2 + cv4*rv4));
check_equal(temp, temp3);
dlog << LTRACE << "8.12";
}
{
matrix<complex<type> > temp, temp2, temp3;
matrix<complex<type>,0,1 > cv4;
matrix<complex<type>,1,0 > rv4;
cv4.set_size(cols);
rv4.set_size(cols);
temp.set_size(cols,cols);
set_all_elements(temp,complex<type>(3,5));
temp(cols-1, cols-4) = 9;
temp2 = temp;
temp3.set_size(cols,cols);
temp3 = 0;
for (long i = 0; i < rv4.size(); ++i)
{
rv4(i) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
cv4(i) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
}
dlog << LTRACE << "8.13";
temp += trans(cv4*rv4);
assign_no_blas(temp3, temp2 + trans(cv4*rv4));
c_check_equal(temp, temp3);
temp3 = 0;
dlog << LTRACE << "8.14";
temp2 = temp;
temp = trans(temp + cv4*rv4);
assign_no_blas(temp3, trans(temp2 + cv4*rv4));
c_check_equal(temp, temp3);
dlog << LTRACE << "8.15";
}
......@@ -340,6 +443,7 @@ namespace
// Test DOT
check_equal( tmp(rv4*cv4), rv4*cv4);
check_equal( tmp(trans(rv4*cv4)), trans(rv4*cv4));
check_equal( tmp(trans(cv4)*trans(rv4)), trans(cv4)*trans(rv4));
check_equal( tmp(rv4*3.9*cv4), rv4*3.9*cv4);
check_equal( tmp(trans(cv4)*3.9*trans(rv4)), trans(cv4)*3.9*trans(rv4));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment