Commit 4c4419e0 (signature unverified), authored by Lucas Beyer, committed by GitHub
Browse files

Merge pull request #108 from STulling/master

Fix Windows MSVC install by updating Eigen Library
parents 4d5343c3 13b115ab
...@@ -15,7 +15,7 @@ namespace Eigen { ...@@ -15,7 +15,7 @@ namespace Eigen {
namespace internal { namespace internal {
// pack a selfadjoint block diagonal for use with the gebp_kernel // pack a selfadjoint block diagonal for use with the gebp_kernel
template<typename Scalar, typename Index, int Pack1, int Pack2, int StorageOrder> template<typename Scalar, typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs struct symm_pack_lhs
{ {
template<int BlockRows> inline template<int BlockRows> inline
...@@ -30,9 +30,9 @@ struct symm_pack_lhs ...@@ -30,9 +30,9 @@ struct symm_pack_lhs
for(Index k=i; k<i+BlockRows; k++) for(Index k=i; k<i+BlockRows; k++)
{ {
for(Index w=0; w<h; w++) for(Index w=0; w<h; w++)
blockA[count++] = conj(lhs(k, i+w)); // transposed blockA[count++] = numext::conj(lhs(k, i+w)); // transposed
blockA[count++] = real(lhs(k,k)); // real (diagonal) blockA[count++] = numext::real(lhs(k,k)); // real (diagonal)
for(Index w=h+1; w<BlockRows; w++) for(Index w=h+1; w<BlockRows; w++)
blockA[count++] = lhs(i+w, k); // normal blockA[count++] = lhs(i+w, k); // normal
...@@ -41,34 +41,41 @@ struct symm_pack_lhs ...@@ -41,34 +41,41 @@ struct symm_pack_lhs
// transposed copy // transposed copy
for(Index k=i+BlockRows; k<cols; k++) for(Index k=i+BlockRows; k<cols; k++)
for(Index w=0; w<BlockRows; w++) for(Index w=0; w<BlockRows; w++)
blockA[count++] = conj(lhs(k, i+w)); // transposed blockA[count++] = numext::conj(lhs(k, i+w)); // transposed
} }
void operator()(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows) void operator()(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows)
{ {
enum { PacketSize = packet_traits<Scalar>::size };
const_blas_data_mapper<Scalar,Index,StorageOrder> lhs(_lhs,lhsStride); const_blas_data_mapper<Scalar,Index,StorageOrder> lhs(_lhs,lhsStride);
Index count = 0; Index count = 0;
Index peeled_mc = (rows/Pack1)*Pack1; //Index peeled_mc3 = (rows/Pack1)*Pack1;
for(Index i=0; i<peeled_mc; i+=Pack1)
{ const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
pack<Pack1>(blockA, lhs, cols, i, count); const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
} const Index peeled_mc1 = Pack1>=1*PacketSize ? (rows/(1*PacketSize))*(1*PacketSize) : 0;
if(rows-peeled_mc>=Pack2) if(Pack1>=3*PacketSize)
{ for(Index i=0; i<peeled_mc3; i+=3*PacketSize)
pack<Pack2>(blockA, lhs, cols, peeled_mc, count); pack<3*PacketSize>(blockA, lhs, cols, i, count);
peeled_mc += Pack2;
} if(Pack1>=2*PacketSize)
for(Index i=peeled_mc3; i<peeled_mc2; i+=2*PacketSize)
pack<2*PacketSize>(blockA, lhs, cols, i, count);
if(Pack1>=1*PacketSize)
for(Index i=peeled_mc2; i<peeled_mc1; i+=1*PacketSize)
pack<1*PacketSize>(blockA, lhs, cols, i, count);
// do the same with mr==1 // do the same with mr==1
for(Index i=peeled_mc; i<rows; i++) for(Index i=peeled_mc1; i<rows; i++)
{ {
for(Index k=0; k<i; k++) for(Index k=0; k<i; k++)
blockA[count++] = lhs(i, k); // normal blockA[count++] = lhs(i, k); // normal
blockA[count++] = real(lhs(i, i)); // real (diagonal) blockA[count++] = numext::real(lhs(i, i)); // real (diagonal)
for(Index k=i+1; k<cols; k++) for(Index k=i+1; k<cols; k++)
blockA[count++] = conj(lhs(k, i)); // transposed blockA[count++] = numext::conj(lhs(k, i)); // transposed
} }
} }
}; };
...@@ -82,7 +89,8 @@ struct symm_pack_rhs ...@@ -82,7 +89,8 @@ struct symm_pack_rhs
Index end_k = k2 + rows; Index end_k = k2 + rows;
Index count = 0; Index count = 0;
const_blas_data_mapper<Scalar,Index,StorageOrder> rhs(_rhs,rhsStride); const_blas_data_mapper<Scalar,Index,StorageOrder> rhs(_rhs,rhsStride);
Index packet_cols = (cols/nr)*nr; Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
// first part: normal case // first part: normal case
for(Index j2=0; j2<k2; j2+=nr) for(Index j2=0; j2<k2; j2+=nr)
...@@ -91,91 +99,163 @@ struct symm_pack_rhs ...@@ -91,91 +99,163 @@ struct symm_pack_rhs
{ {
blockB[count+0] = rhs(k,j2+0); blockB[count+0] = rhs(k,j2+0);
blockB[count+1] = rhs(k,j2+1); blockB[count+1] = rhs(k,j2+1);
if (nr==4) if (nr>=4)
{ {
blockB[count+2] = rhs(k,j2+2); blockB[count+2] = rhs(k,j2+2);
blockB[count+3] = rhs(k,j2+3); blockB[count+3] = rhs(k,j2+3);
} }
if (nr>=8)
{
blockB[count+4] = rhs(k,j2+4);
blockB[count+5] = rhs(k,j2+5);
blockB[count+6] = rhs(k,j2+6);
blockB[count+7] = rhs(k,j2+7);
}
count += nr; count += nr;
} }
} }
// second part: diagonal block // second part: diagonal block
for(Index j2=k2; j2<(std::min)(k2+rows,packet_cols); j2+=nr) Index end8 = nr>=8 ? (std::min)(k2+rows,packet_cols8) : k2;
if(nr>=8)
{ {
// again we can split vertically in three different parts (transpose, symmetric, normal) for(Index j2=k2; j2<end8; j2+=8)
// transpose
for(Index k=k2; k<j2; k++)
{ {
blockB[count+0] = conj(rhs(j2+0,k)); // again we can split vertically in three different parts (transpose, symmetric, normal)
blockB[count+1] = conj(rhs(j2+1,k)); // transpose
if (nr==4) for(Index k=k2; k<j2; k++)
{ {
blockB[count+2] = conj(rhs(j2+2,k)); blockB[count+0] = numext::conj(rhs(j2+0,k));
blockB[count+3] = conj(rhs(j2+3,k)); blockB[count+1] = numext::conj(rhs(j2+1,k));
blockB[count+2] = numext::conj(rhs(j2+2,k));
blockB[count+3] = numext::conj(rhs(j2+3,k));
blockB[count+4] = numext::conj(rhs(j2+4,k));
blockB[count+5] = numext::conj(rhs(j2+5,k));
blockB[count+6] = numext::conj(rhs(j2+6,k));
blockB[count+7] = numext::conj(rhs(j2+7,k));
count += 8;
} }
count += nr; // symmetric
} Index h = 0;
// symmetric for(Index k=j2; k<j2+8; k++)
Index h = 0; {
for(Index k=j2; k<j2+nr; k++) // normal
{ for (Index w=0 ; w<h; ++w)
// normal blockB[count+w] = rhs(k,j2+w);
for (Index w=0 ; w<h; ++w)
blockB[count+w] = rhs(k,j2+w);
blockB[count+h] = real(rhs(k,k)); blockB[count+h] = numext::real(rhs(k,k));
// transpose // transpose
for (Index w=h+1 ; w<nr; ++w) for (Index w=h+1 ; w<8; ++w)
blockB[count+w] = conj(rhs(j2+w,k)); blockB[count+w] = numext::conj(rhs(j2+w,k));
count += nr; count += 8;
++h; ++h;
}
// normal
for(Index k=j2+8; k<end_k; k++)
{
blockB[count+0] = rhs(k,j2+0);
blockB[count+1] = rhs(k,j2+1);
blockB[count+2] = rhs(k,j2+2);
blockB[count+3] = rhs(k,j2+3);
blockB[count+4] = rhs(k,j2+4);
blockB[count+5] = rhs(k,j2+5);
blockB[count+6] = rhs(k,j2+6);
blockB[count+7] = rhs(k,j2+7);
count += 8;
}
} }
// normal }
for(Index k=j2+nr; k<end_k; k++) if(nr>=4)
{
for(Index j2=end8; j2<(std::min)(k2+rows,packet_cols4); j2+=4)
{ {
blockB[count+0] = rhs(k,j2+0); // again we can split vertically in three different parts (transpose, symmetric, normal)
blockB[count+1] = rhs(k,j2+1); // transpose
if (nr==4) for(Index k=k2; k<j2; k++)
{ {
blockB[count+0] = numext::conj(rhs(j2+0,k));
blockB[count+1] = numext::conj(rhs(j2+1,k));
blockB[count+2] = numext::conj(rhs(j2+2,k));
blockB[count+3] = numext::conj(rhs(j2+3,k));
count += 4;
}
// symmetric
Index h = 0;
for(Index k=j2; k<j2+4; k++)
{
// normal
for (Index w=0 ; w<h; ++w)
blockB[count+w] = rhs(k,j2+w);
blockB[count+h] = numext::real(rhs(k,k));
// transpose
for (Index w=h+1 ; w<4; ++w)
blockB[count+w] = numext::conj(rhs(j2+w,k));
count += 4;
++h;
}
// normal
for(Index k=j2+4; k<end_k; k++)
{
blockB[count+0] = rhs(k,j2+0);
blockB[count+1] = rhs(k,j2+1);
blockB[count+2] = rhs(k,j2+2); blockB[count+2] = rhs(k,j2+2);
blockB[count+3] = rhs(k,j2+3); blockB[count+3] = rhs(k,j2+3);
count += 4;
} }
count += nr;
} }
} }
// third part: transposed // third part: transposed
for(Index j2=k2+rows; j2<packet_cols; j2+=nr) if(nr>=8)
{ {
for(Index k=k2; k<end_k; k++) for(Index j2=k2+rows; j2<packet_cols8; j2+=8)
{ {
blockB[count+0] = conj(rhs(j2+0,k)); for(Index k=k2; k<end_k; k++)
blockB[count+1] = conj(rhs(j2+1,k));
if (nr==4)
{ {
blockB[count+2] = conj(rhs(j2+2,k)); blockB[count+0] = numext::conj(rhs(j2+0,k));
blockB[count+3] = conj(rhs(j2+3,k)); blockB[count+1] = numext::conj(rhs(j2+1,k));
blockB[count+2] = numext::conj(rhs(j2+2,k));
blockB[count+3] = numext::conj(rhs(j2+3,k));
blockB[count+4] = numext::conj(rhs(j2+4,k));
blockB[count+5] = numext::conj(rhs(j2+5,k));
blockB[count+6] = numext::conj(rhs(j2+6,k));
blockB[count+7] = numext::conj(rhs(j2+7,k));
count += 8;
}
}
}
if(nr>=4)
{
for(Index j2=(std::max)(packet_cols8,k2+rows); j2<packet_cols4; j2+=4)
{
for(Index k=k2; k<end_k; k++)
{
blockB[count+0] = numext::conj(rhs(j2+0,k));
blockB[count+1] = numext::conj(rhs(j2+1,k));
blockB[count+2] = numext::conj(rhs(j2+2,k));
blockB[count+3] = numext::conj(rhs(j2+3,k));
count += 4;
} }
count += nr;
} }
} }
// copy the remaining columns one at a time (=> the same with nr==1) // copy the remaining columns one at a time (=> the same with nr==1)
for(Index j2=packet_cols; j2<cols; ++j2) for(Index j2=packet_cols4; j2<cols; ++j2)
{ {
// transpose // transpose
Index half = (std::min)(end_k,j2); Index half = (std::min)(end_k,j2);
for(Index k=k2; k<half; k++) for(Index k=k2; k<half; k++)
{ {
blockB[count] = conj(rhs(j2,k)); blockB[count] = numext::conj(rhs(j2,k));
count += 1; count += 1;
} }
if(half==j2 && half<k2+rows) if(half==j2 && half<k2+rows)
{ {
blockB[count] = real(rhs(j2,j2)); blockB[count] = numext::real(rhs(j2,j2));
count += 1; count += 1;
} }
else else
...@@ -197,69 +277,85 @@ struct symm_pack_rhs ...@@ -197,69 +277,85 @@ struct symm_pack_rhs
template <typename Scalar, typename Index, template <typename Scalar, typename Index,
int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs, int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs, int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs,
int ResStorageOrder> int ResStorageOrder, int ResInnerStride>
struct product_selfadjoint_matrix; struct product_selfadjoint_matrix;
template <typename Scalar, typename Index, template <typename Scalar, typename Index,
int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs, int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs> int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs,
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,LhsSelfAdjoint,ConjugateLhs, RhsStorageOrder,RhsSelfAdjoint,ConjugateRhs,RowMajor> int ResInnerStride>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,LhsSelfAdjoint,ConjugateLhs, RhsStorageOrder,RhsSelfAdjoint,ConjugateRhs,RowMajor,ResInnerStride>
{ {
static EIGEN_STRONG_INLINE void run( static EIGEN_STRONG_INLINE void run(
Index rows, Index cols, Index rows, Index cols,
const Scalar* lhs, Index lhsStride, const Scalar* lhs, Index lhsStride,
const Scalar* rhs, Index rhsStride, const Scalar* rhs, Index rhsStride,
Scalar* res, Index resStride, Scalar* res, Index resIncr, Index resStride,
Scalar alpha) const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{ {
product_selfadjoint_matrix<Scalar, Index, product_selfadjoint_matrix<Scalar, Index,
EIGEN_LOGICAL_XOR(RhsSelfAdjoint,RhsStorageOrder==RowMajor) ? ColMajor : RowMajor, EIGEN_LOGICAL_XOR(RhsSelfAdjoint,RhsStorageOrder==RowMajor) ? ColMajor : RowMajor,
RhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsSelfAdjoint,ConjugateRhs), RhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsSelfAdjoint,ConjugateRhs),
EIGEN_LOGICAL_XOR(LhsSelfAdjoint,LhsStorageOrder==RowMajor) ? ColMajor : RowMajor, EIGEN_LOGICAL_XOR(LhsSelfAdjoint,LhsStorageOrder==RowMajor) ? ColMajor : RowMajor,
LhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsSelfAdjoint,ConjugateLhs), LhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsSelfAdjoint,ConjugateLhs),
ColMajor> ColMajor,ResInnerStride>
::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha); ::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resIncr, resStride, alpha, blocking);
} }
}; };
template <typename Scalar, typename Index, template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs, int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs> int RhsStorageOrder, bool ConjugateRhs,
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor> int ResInnerStride>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor,ResInnerStride>
{ {
static EIGEN_DONT_INLINE void run( static EIGEN_DONT_INLINE void run(
Index rows, Index cols, Index rows, Index cols,
const Scalar* _lhs, Index lhsStride, const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride, const Scalar* _rhs, Index rhsStride,
Scalar* res, Index resStride, Scalar* res, Index resIncr, Index resStride,
Scalar alpha) const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride>
EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor,ResInnerStride>::run(
Index rows, Index cols,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{ {
Index size = rows; Index size = rows;
const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
typedef gebp_traits<Scalar,Scalar> Traits; typedef gebp_traits<Scalar,Scalar> Traits;
Index kc = size; // cache block size along the K direction typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
Index mc = rows; // cache block size along the M direction typedef const_blas_data_mapper<Scalar, Index, (LhsStorageOrder == RowMajor) ? ColMajor : RowMajor> LhsTransposeMapper;
Index nc = cols; // cache block size along the N direction typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
computeProductBlockingSizes<Scalar,Scalar>(kc, mc, nc); typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
// kc must smaller than mc LhsMapper lhs(_lhs,lhsStride);
LhsTransposeMapper lhs_transpose(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
ResMapper res(_res, resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
// kc must be smaller than mc
kc = (std::min)(kc,mc); kc = (std::min)(kc,mc);
std::size_t sizeA = kc*mc;
std::size_t sizeB = kc*cols;
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
std::size_t sizeW = kc*Traits::WorkSpaceFactor; gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
std::size_t sizeB = sizeW + kc*cols;
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0);
ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0);
Scalar* blockB = allocatedBlockB + sizeW;
gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
symm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs; symm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs; gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr,RhsStorageOrder> pack_rhs;
gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder==RowMajor?ColMajor:RowMajor, true> pack_lhs_transposed; gemm_pack_lhs<Scalar, Index, LhsTransposeMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder==RowMajor?ColMajor:RowMajor, true> pack_lhs_transposed;
for(Index k2=0; k2<size; k2+=kc) for(Index k2=0; k2<size; k2+=kc)
{ {
...@@ -268,7 +364,7 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs ...@@ -268,7 +364,7 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs
// we have selected one row panel of rhs and one column panel of lhs // we have selected one row panel of rhs and one column panel of lhs
// pack rhs's panel into a sequential chunk of memory // pack rhs's panel into a sequential chunk of memory
// and expand each coeff to a constant packet for further reuse // and expand each coeff to a constant packet for further reuse
pack_rhs(blockB, &rhs(k2,0), rhsStride, actual_kc, cols); pack_rhs(blockB, rhs.getSubMapper(k2,0), actual_kc, cols);
// the select lhs's panel has to be split in three different parts: // the select lhs's panel has to be split in three different parts:
// 1 - the transposed panel above the diagonal block => transposed packed copy // 1 - the transposed panel above the diagonal block => transposed packed copy
...@@ -278,9 +374,9 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs ...@@ -278,9 +374,9 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs
{ {
const Index actual_mc = (std::min)(i2+mc,k2)-i2; const Index actual_mc = (std::min)(i2+mc,k2)-i2;
// transposed packed copy // transposed packed copy
pack_lhs_transposed(blockA, &lhs(k2, i2), lhsStride, actual_kc, actual_mc); pack_lhs_transposed(blockA, lhs_transpose.getSubMapper(i2, k2), actual_kc, actual_mc);
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha); gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
} }
// the block diagonal // the block diagonal
{ {
...@@ -288,53 +384,65 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs ...@@ -288,53 +384,65 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs
// symmetric packed copy // symmetric packed copy
pack_lhs(blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc); pack_lhs(blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc);
gebp_kernel(res+k2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha); gebp_kernel(res.getSubMapper(k2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
} }
for(Index i2=k2+kc; i2<size; i2+=mc) for(Index i2=k2+kc; i2<size; i2+=mc)
{ {
const Index actual_mc = (std::min)(i2+mc,size)-i2; const Index actual_mc = (std::min)(i2+mc,size)-i2;
gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder,false>() gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder,false>()
(blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc); (blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha); gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
} }
} }
} }
};
// matrix * selfadjoint product // matrix * selfadjoint product
template <typename Scalar, typename Index, template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs, int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs> int RhsStorageOrder, bool ConjugateRhs,
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor> int ResInnerStride>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>
{ {
static EIGEN_DONT_INLINE void run( static EIGEN_DONT_INLINE void run(
Index rows, Index cols, Index rows, Index cols,
const Scalar* _lhs, Index lhsStride, const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride, const Scalar* _rhs, Index rhsStride,
Scalar* res, Index resStride, Scalar* res, Index resIncr, Index resStride,
Scalar alpha) const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride>
EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>::run(
Index rows, Index cols,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{ {
Index size = cols; Index size = cols;
const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
typedef gebp_traits<Scalar,Scalar> Traits; typedef gebp_traits<Scalar,Scalar> Traits;
Index kc = size; // cache block size along the K direction typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
Index mc = rows; // cache block size along the M direction typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
Index nc = cols; // cache block size along the N direction LhsMapper lhs(_lhs,lhsStride);
computeProductBlockingSizes<Scalar,Scalar>(kc, mc, nc); ResMapper res(_res,resStride, resIncr);
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
std::size_t sizeB = sizeW + kc*cols; Index kc = blocking.kc(); // cache block size along the K direction
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, kc*mc, 0); Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
ei_declare_aligned_stack_constructed_variable(Scalar, allocatedBlockB, sizeB, 0); std::size_t sizeA = kc*mc;
Scalar* blockB = allocatedBlockB + sizeW; std::size_t sizeB = kc*cols;
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
symm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs; symm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs;
for(Index k2=0; k2<size; k2+=kc) for(Index k2=0; k2<size; k2+=kc)
...@@ -347,13 +455,12 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLh ...@@ -347,13 +455,12 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLh
for(Index i2=0; i2<rows; i2+=mc) for(Index i2=0; i2<rows; i2+=mc)
{ {
const Index actual_mc = (std::min)(i2+mc,rows)-i2; const Index actual_mc = (std::min)(i2+mc,rows)-i2;
pack_lhs(blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc); pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha); gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
} }
} }
} }
};
} // end namespace internal } // end namespace internal
...@@ -362,55 +469,59 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLh ...@@ -362,55 +469,59 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLh
***************************************************************************/ ***************************************************************************/
namespace internal { namespace internal {
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode> template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct traits<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false> > struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,RhsMode,false>
: traits<ProductBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>, Lhs, Rhs> >
{};
}
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>
: public ProductBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,RhsMode,false>, Lhs, Rhs >
{ {
EIGEN_PRODUCT_PUBLIC_INTERFACE(SelfadjointProductMatrix) typedef typename Product<Lhs,Rhs>::Scalar Scalar;
SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
enum { enum {
LhsIsUpper = (LhsMode&(Upper|Lower))==Upper, LhsIsUpper = (LhsMode&(Upper|Lower))==Upper,
LhsIsSelfAdjoint = (LhsMode&SelfAdjoint)==SelfAdjoint, LhsIsSelfAdjoint = (LhsMode&SelfAdjoint)==SelfAdjoint,
RhsIsUpper = (RhsMode&(Upper|Lower))==Upper, RhsIsUpper = (RhsMode&(Upper|Lower))==Upper,
RhsIsSelfAdjoint = (RhsMode&SelfAdjoint)==SelfAdjoint RhsIsSelfAdjoint = (RhsMode&SelfAdjoint)==SelfAdjoint
}; };
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const template<typename Dest>
static void run(Dest &dst, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
{ {
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(m_lhs); typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(m_rhs); typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
* RhsBlasTraits::extractScalarFactor(m_rhs); * RhsBlasTraits::extractScalarFactor(a_rhs);
typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar,
Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,1> BlockingType;
BlockingType blocking(lhs.rows(), rhs.cols(), lhs.cols(), 1, false);
internal::product_selfadjoint_matrix<Scalar, Index, internal::product_selfadjoint_matrix<Scalar, Index,
EIGEN_LOGICAL_XOR(LhsIsUpper, EIGEN_LOGICAL_XOR(LhsIsUpper,internal::traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint,
internal::traits<Lhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, LhsIsSelfAdjoint,
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsIsUpper,bool(LhsBlasTraits::NeedToConjugate)), NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsIsUpper,bool(LhsBlasTraits::NeedToConjugate)),
EIGEN_LOGICAL_XOR(RhsIsUpper, EIGEN_LOGICAL_XOR(RhsIsUpper,internal::traits<Rhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint,
internal::traits<Rhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint,
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsIsUpper,bool(RhsBlasTraits::NeedToConjugate)), NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsIsUpper,bool(RhsBlasTraits::NeedToConjugate)),
internal::traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor> internal::traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor,
Dest::InnerStrideAtCompileTime>
::run( ::run(
lhs.rows(), rhs.cols(), // sizes lhs.rows(), rhs.cols(), // sizes
&lhs.coeffRef(0,0), lhs.outerStride(), // lhs info &lhs.coeffRef(0,0), lhs.outerStride(), // lhs info
&rhs.coeffRef(0,0), rhs.outerStride(), // rhs info &rhs.coeffRef(0,0), rhs.outerStride(), // rhs info
&dst.coeffRef(0,0), dst.outerStride(), // result info &dst.coeffRef(0,0), dst.innerStride(), dst.outerStride(), // result info
actualAlpha // alpha actualAlpha, blocking // alpha
); );
} }
}; };
} // end namespace internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_H #endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_H
/*
Copyright (c) 2011, Intel Corporation. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of Intel Corporation nor the names of its contributors may
be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
********************************************************************************
* Content : Eigen bindings to BLAS F77
* Self adjoint matrix * matrix product functionality based on ?SYMM/?HEMM.
********************************************************************************
*/
#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H
#define EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H
namespace Eigen {
namespace internal {
/* Optimized selfadjoint matrix * matrix (?SYMM/?HEMM) product */
/* EIGEN_BLAS_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
 *
 * Stamps out the product_selfadjoint_matrix specialization for scalar type
 * EIGTYPE with a self-adjoint lhs, a general rhs, a column-major result and a
 * compile-time result inner stride of 1. Its run() forwards the whole product
 * to the BLAS routine BLASFUNC with side='L' and beta fixed to 1.
 *
 *   EIGTYPE   - Eigen scalar type of lhs, rhs and result
 *   BLASTYPE  - scalar type the pointers are cast to for the BLAS call
 *   EIGPREFIX - suffix pasted onto MatrixX to name the temporary rhs matrix type
 *   BLASFUNC  - BLAS routine to invoke
 *
 * ConjugateLhs, ConjugateRhs and the blocking argument are not consulted here.
 */
#define EIGEN_BLAS_SYMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLhs,RhsStorageOrder,false,ConjugateRhs,ColMajor,1> \
{ \
  static void run( \
    Index rows, Index cols, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res, Index resIncr, Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
  { \
    /* This specialization only handles a unit inner stride on the result. */ \
    EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
    eigen_assert(resIncr == 1); \
    \
    /* Problem sizes and leading dimensions, converted to the BLAS index type. */ \
    BlasIndex m   = convert_index<BlasIndex>(rows); \
    BlasIndex n   = convert_index<BlasIndex>(cols); \
    BlasIndex lda = convert_index<BlasIndex>(lhsStride); \
    BlasIndex ldb = convert_index<BlasIndex>(rhsStride); \
    BlasIndex ldc = convert_index<BlasIndex>(resStride); \
    \
    /* The lhs buffer is passed through untouched; a row-major lhs is */ \
    /* described to BLAS by selecting 'U' instead of 'L' for uplo.    */ \
    char side = 'L'; \
    char uplo = (LhsStorageOrder==RowMajor) ? 'U' : 'L'; \
    const EIGTYPE* lhs_ptr = _lhs; \
    \
    /* A row-major rhs is first copied (via adjoint()) into a temporary   */ \
    /* MatrixX, whose data pointer and outer stride then replace _rhs/ldb. */ \
    /* The temporary lives at function scope so it outlasts the BLAS call. */ \
    MatrixX##EIGPREFIX rhs_copy; \
    const EIGTYPE* rhs_ptr = _rhs; \
    if (RhsStorageOrder==RowMajor) { \
      Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs_view(_rhs,n,m,OuterStride<>(rhsStride)); \
      rhs_copy = rhs_view.adjoint(); \
      rhs_ptr = rhs_copy.data(); \
      ldb = convert_index<BlasIndex>(rhs_copy.outerStride()); \
    } \
    \
    EIGTYPE beta(1); \
    BLASFUNC(&side, &uplo, &m, &n, \
             (const BLASTYPE*)&numext::real_ref(alpha), \
             (const BLASTYPE*)lhs_ptr, &lda, \
             (const BLASTYPE*)rhs_ptr, &ldb, \
             (const BLASTYPE*)&numext::real_ref(beta), \
             (BLASTYPE*)res, &ldc); \
  } \
};
/*
 * EIGEN_BLAS_HEMM_L: same role as the ?SYMM left-side generator, but it also
 * honours the ConjugateLhs / ConjugateRhs flags and the storage orders by
 * building conjugated / transposed temporaries before calling BLASFUNC with
 * side = 'L'.
 *
 *   EIGTYPE   : Eigen scalar type of all operands.
 *   BLASTYPE  : scalar type the operand pointers are cast to for BLASFUNC.
 *   EIGPREFIX : suffix pasted onto MatrixX to name the rhs temporary type.
 *   BLASFUNC  : name of the BLAS routine to call.
 */
#define EIGEN_BLAS_HEMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLhs,RhsStorageOrder,false,ConjugateRhs,ColMajor,1> \
{ \
  static void run( \
    Index rows, Index cols, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res, Index resIncr, Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
  { \
    EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
    eigen_assert(resIncr == 1); \
\
    /* Problem sizes and leading dimensions, converted to the BLAS index type. */ \
    BlasIndex m   = convert_index<BlasIndex>(rows); \
    BlasIndex n   = convert_index<BlasIndex>(cols); \
    BlasIndex lda = convert_index<BlasIndex>(lhsStride); \
    BlasIndex ldb = convert_index<BlasIndex>(rhsStride); \
    BlasIndex ldc = convert_index<BlasIndex>(resStride); \
\
    char side = 'L'; \
    char uplo = (LhsStorageOrder==RowMajor) ? 'U' : 'L'; \
    EIGTYPE beta(1); \
\
    /* Temporaries live at function scope: a / b may point into them during the BLAS call. */ \
    Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> lhsConj; \
    MatrixX##EIGPREFIX rhsCopy; \
\
    /* Selfadjoint operand: pass a conjugated copy for (col-major, conjugated) and (row-major, not conjugated). */ \
    const bool lhsNeedsConjCopy = ((LhsStorageOrder==ColMajor) && ConjugateLhs) || ((LhsStorageOrder==RowMajor) && (!ConjugateLhs)); \
    const EIGTYPE* a = _lhs; \
    if (lhsNeedsConjCopy) { \
      Map<const Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder>, 0, OuterStride<> > lhsView(_lhs,m,m,OuterStride<>(lhsStride)); \
      lhsConj = lhsView.conjugate(); \
      a = lhsConj.data(); \
      lda = convert_index<BlasIndex>(lhsConj.outerStride()); \
    } \
\
    /* General operand: used in place only when it is column-major and not conjugated. */ \
    const EIGTYPE* b = _rhs; \
    if (RhsStorageOrder!=ColMajor || ConjugateRhs) { \
      if (RhsStorageOrder==ColMajor) { \
        /* Column-major here implies ConjugateRhs: conjugate without transposing. */ \
        Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhsView(_rhs,m,n,OuterStride<>(rhsStride)); \
        rhsCopy = rhsView.conjugate(); \
      } else { \
        /* Otherwise the buffer is mapped as n x m and transposed, with or without conjugation. */ \
        Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhsView(_rhs,n,m,OuterStride<>(rhsStride)); \
        if (ConjugateRhs) { \
          rhsCopy = rhsView.adjoint(); \
        } else { \
          rhsCopy = rhsView.transpose(); \
        } \
      } \
      b = rhsCopy.data(); \
      ldb = convert_index<BlasIndex>(rhsCopy.outerStride()); \
    } \
\
    BLASFUNC(&side, &uplo, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
  } \
};
#ifdef EIGEN_USE_MKL
EIGEN_BLAS_SYMM_L(double, double, d, dsymm)
EIGEN_BLAS_SYMM_L(float, float, f, ssymm)
EIGEN_BLAS_HEMM_L(dcomplex, MKL_Complex16, cd, zhemm)
EIGEN_BLAS_HEMM_L(scomplex, MKL_Complex8, cf, chemm)
#else
EIGEN_BLAS_SYMM_L(double, double, d, dsymm_)
EIGEN_BLAS_SYMM_L(float, float, f, ssymm_)
EIGEN_BLAS_HEMM_L(dcomplex, double, cd, zhemm_)
EIGEN_BLAS_HEMM_L(scomplex, float, cf, chemm_)
#endif
/* Optimized matrix * selfadjoint matrix (?SYMM/?HEMM) product */
/*
 * EIGEN_BLAS_SYMM_R: emits the product_selfadjoint_matrix specialization used
 * when the selfadjoint operand is on the right and the result is column-major,
 * and forwards the work to the BLAS routine BLASFUNC with side = 'R'.
 *
 * With side = 'R' the BLAS "A" argument is the selfadjoint operand (_rhs) and
 * the BLAS "B" argument is the general operand (_lhs), which is why lda is
 * taken from rhsStride and ldb from lhsStride.
 *
 *   EIGTYPE   : Eigen scalar type of all operands.
 *   BLASTYPE  : scalar type the operand pointers are cast to for BLASFUNC.
 *   EIGPREFIX : suffix pasted onto MatrixX to name the temporary matrix type.
 *   BLASFUNC  : name of the BLAS routine to call.
 *
 * The ConjugateLhs / ConjugateRhs template flags are not consulted by this
 * variant. beta is fixed to 1 when calling BLASFUNC.
 */
#define EIGEN_BLAS_SYMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateLhs,RhsStorageOrder,true,ConjugateRhs,ColMajor,1> \
{ \
  static void run( \
    Index rows, Index cols, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res, Index resIncr, Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
  { \
    EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
    eigen_assert(resIncr == 1); \
    char side='R', uplo='L'; \
    BlasIndex m, n, lda, ldb, ldc; \
    const EIGTYPE *a, *b; \
    EIGTYPE beta(1); \
    /* Must outlive the BLAS call because b may point into it. */ \
    MatrixX##EIGPREFIX b_tmp; \
\
    /* Set m, n */ \
    m = convert_index<BlasIndex>(rows); \
    n = convert_index<BlasIndex>(cols); \
\
    /* Set lda, ldb, ldc */ \
    lda = convert_index<BlasIndex>(rhsStride); \
    ldb = convert_index<BlasIndex>(lhsStride); \
    ldc = convert_index<BlasIndex>(resStride); \
\
    /* Selfadjoint operand: passed through unchanged; a row-major layout only flips the triangle flag. */ \
    if (RhsStorageOrder==RowMajor) uplo='U'; \
    a = _rhs; \
\
    /* General operand: a row-major lhs is mapped as n x m and its adjoint is copied for BLAS. */ \
    /* The map walks the _lhs buffer, so its outer stride has to be lhsStride (not rhsStride). */ \
    if (LhsStorageOrder==RowMajor) { \
      Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(_lhs,n,m,OuterStride<>(lhsStride)); \
      b_tmp = lhs.adjoint(); \
      b = b_tmp.data(); \
      ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
    } else b = _lhs; \
\
    BLASFUNC(&side, &uplo, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
  } \
};
/*
 * EIGEN_BLAS_HEMM_R: right-side counterpart of the ?HEMM generator. The
 * selfadjoint operand is _rhs (BLAS "A"), the general operand is _lhs
 * (BLAS "B"); conjugated / transposed temporaries are built as required by
 * the Conjugate flags and storage orders before calling BLASFUNC with
 * side = 'R'.
 *
 *   EIGTYPE   : Eigen scalar type of all operands.
 *   BLASTYPE  : scalar type the operand pointers are cast to for BLASFUNC.
 *   EIGPREFIX : suffix pasted onto MatrixX to name the lhs temporary type.
 *   BLASFUNC  : name of the BLAS routine to call.
 */
#define EIGEN_BLAS_HEMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, \
          int LhsStorageOrder, bool ConjugateLhs, \
          int RhsStorageOrder, bool ConjugateRhs> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateLhs,RhsStorageOrder,true,ConjugateRhs,ColMajor,1> \
{ \
  static void run( \
    Index rows, Index cols, \
    const EIGTYPE* _lhs, Index lhsStride, \
    const EIGTYPE* _rhs, Index rhsStride, \
    EIGTYPE* res, Index resIncr, Index resStride, \
    EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
  { \
    EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
    eigen_assert(resIncr == 1); \
\
    /* Problem sizes and leading dimensions; note that lda follows the rhs and ldb the lhs. */ \
    BlasIndex m   = convert_index<BlasIndex>(rows); \
    BlasIndex n   = convert_index<BlasIndex>(cols); \
    BlasIndex lda = convert_index<BlasIndex>(rhsStride); \
    BlasIndex ldb = convert_index<BlasIndex>(lhsStride); \
    BlasIndex ldc = convert_index<BlasIndex>(resStride); \
\
    char side = 'R'; \
    char uplo = (RhsStorageOrder==RowMajor) ? 'U' : 'L'; \
    EIGTYPE beta(1); \
\
    /* Temporaries live at function scope: a / b may point into them during the BLAS call. */ \
    Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> rhsConj; \
    MatrixX##EIGPREFIX lhsCopy; \
\
    /* Selfadjoint operand: pass a conjugated copy for (col-major, conjugated) and (row-major, not conjugated). */ \
    const bool rhsNeedsConjCopy = ((RhsStorageOrder==ColMajor) && ConjugateRhs) || ((RhsStorageOrder==RowMajor) && (!ConjugateRhs)); \
    const EIGTYPE* a = _rhs; \
    if (rhsNeedsConjCopy) { \
      Map<const Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder>, 0, OuterStride<> > rhsView(_rhs,n,n,OuterStride<>(rhsStride)); \
      rhsConj = rhsView.conjugate(); \
      a = rhsConj.data(); \
      lda = convert_index<BlasIndex>(rhsConj.outerStride()); \
    } \
\
    /* General operand: used in place only when it is column-major and not conjugated. */ \
    const EIGTYPE* b = _lhs; \
    if (LhsStorageOrder!=ColMajor || ConjugateLhs) { \
      if (LhsStorageOrder==ColMajor) { \
        /* Column-major here implies ConjugateLhs: conjugate without transposing. */ \
        Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhsView(_lhs,m,n,OuterStride<>(lhsStride)); \
        lhsCopy = lhsView.conjugate(); \
      } else { \
        /* Otherwise the buffer is mapped as n x m and transposed, with or without conjugation. */ \
        Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhsView(_lhs,n,m,OuterStride<>(lhsStride)); \
        if (ConjugateLhs) { \
          lhsCopy = lhsView.adjoint(); \
        } else { \
          lhsCopy = lhsView.transpose(); \
        } \
      } \
      b = lhsCopy.data(); \
      ldb = convert_index<BlasIndex>(lhsCopy.outerStride()); \
    } \
\
    BLASFUNC(&side, &uplo, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
  } \
};
#ifdef EIGEN_USE_MKL
EIGEN_BLAS_SYMM_R(double, double, d, dsymm)
EIGEN_BLAS_SYMM_R(float, float, f, ssymm)
EIGEN_BLAS_HEMM_R(dcomplex, MKL_Complex16, cd, zhemm)
EIGEN_BLAS_HEMM_R(scomplex, MKL_Complex8, cf, chemm)
#else
EIGEN_BLAS_SYMM_R(double, double, d, dsymm_)
EIGEN_BLAS_SYMM_R(float, float, f, ssymm_)
EIGEN_BLAS_HEMM_R(dcomplex, double, cd, zhemm_)
EIGEN_BLAS_HEMM_R(scomplex, float, cf, chemm_)
#endif
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_BLAS_H
...@@ -30,7 +30,16 @@ struct selfadjoint_matrix_vector_product ...@@ -30,7 +30,16 @@ struct selfadjoint_matrix_vector_product
static EIGEN_DONT_INLINE void run( static EIGEN_DONT_INLINE void run(
Index size, Index size,
const Scalar* lhs, Index lhsStride, const Scalar* lhs, Index lhsStride,
const Scalar* _rhs, Index rhsIncr, const Scalar* rhs,
Scalar* res,
Scalar alpha);
};
template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs, int Version>
EIGEN_DONT_INLINE void selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,Version>::run(
Index size,
const Scalar* lhs, Index lhsStride,
const Scalar* rhs,
Scalar* res, Scalar* res,
Scalar alpha) Scalar alpha)
{ {
...@@ -46,23 +55,13 @@ static EIGEN_DONT_INLINE void run( ...@@ -46,23 +55,13 @@ static EIGEN_DONT_INLINE void run(
conj_helper<Scalar,Scalar,NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(ConjugateLhs, IsRowMajor), ConjugateRhs> cj0; conj_helper<Scalar,Scalar,NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(ConjugateLhs, IsRowMajor), ConjugateRhs> cj0;
conj_helper<Scalar,Scalar,NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(ConjugateLhs, !IsRowMajor), ConjugateRhs> cj1; conj_helper<Scalar,Scalar,NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(ConjugateLhs, !IsRowMajor), ConjugateRhs> cj1;
conj_helper<Scalar,Scalar,NumTraits<Scalar>::IsComplex, ConjugateRhs> cjd; conj_helper<RealScalar,Scalar,false, ConjugateRhs> cjd;
conj_helper<Packet,Packet,NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(ConjugateLhs, IsRowMajor), ConjugateRhs> pcj0; conj_helper<Packet,Packet,NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(ConjugateLhs, IsRowMajor), ConjugateRhs> pcj0;
conj_helper<Packet,Packet,NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(ConjugateLhs, !IsRowMajor), ConjugateRhs> pcj1; conj_helper<Packet,Packet,NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(ConjugateLhs, !IsRowMajor), ConjugateRhs> pcj1;
Scalar cjAlpha = ConjugateRhs ? conj(alpha) : alpha; Scalar cjAlpha = ConjugateRhs ? numext::conj(alpha) : alpha;
// FIXME this copy is now handled outside product_selfadjoint_vector, so it could probably be removed.
// if the rhs is not sequentially stored in memory we copy it to a temporary buffer,
// this is because we need to extract packets
ei_declare_aligned_stack_constructed_variable(Scalar,rhs,size,rhsIncr==1 ? const_cast<Scalar*>(_rhs) : 0);
if (rhsIncr!=1)
{
const Scalar* it = _rhs;
for (Index i=0; i<size; ++i, it+=rhsIncr)
rhs[i] = *it;
}
Index bound = (std::max)(Index(0),size-8) & 0xfffffffe; Index bound = (std::max)(Index(0),size-8) & 0xfffffffe;
if (FirstTriangular) if (FirstTriangular)
...@@ -71,8 +70,8 @@ static EIGEN_DONT_INLINE void run( ...@@ -71,8 +70,8 @@ static EIGEN_DONT_INLINE void run(
for (Index j=FirstTriangular ? bound : 0; for (Index j=FirstTriangular ? bound : 0;
j<(FirstTriangular ? size : bound);j+=2) j<(FirstTriangular ? size : bound);j+=2)
{ {
register const Scalar* EIGEN_RESTRICT A0 = lhs + j*lhsStride; const Scalar* EIGEN_RESTRICT A0 = lhs + j*lhsStride;
register const Scalar* EIGEN_RESTRICT A1 = lhs + (j+1)*lhsStride; const Scalar* EIGEN_RESTRICT A1 = lhs + (j+1)*lhsStride;
Scalar t0 = cjAlpha * rhs[j]; Scalar t0 = cjAlpha * rhs[j];
Packet ptmp0 = pset1<Packet>(t0); Packet ptmp0 = pset1<Packet>(t0);
...@@ -84,14 +83,13 @@ static EIGEN_DONT_INLINE void run( ...@@ -84,14 +83,13 @@ static EIGEN_DONT_INLINE void run(
Scalar t3(0); Scalar t3(0);
Packet ptmp3 = pset1<Packet>(t3); Packet ptmp3 = pset1<Packet>(t3);
size_t starti = FirstTriangular ? 0 : j+2; Index starti = FirstTriangular ? 0 : j+2;
size_t endi = FirstTriangular ? j : size; Index endi = FirstTriangular ? j : size;
size_t alignedStart = (starti) + internal::first_aligned(&res[starti], endi-starti); Index alignedStart = (starti) + internal::first_default_aligned(&res[starti], endi-starti);
size_t alignedEnd = alignedStart + ((endi-alignedStart)/(PacketSize))*(PacketSize); Index alignedEnd = alignedStart + ((endi-alignedStart)/(PacketSize))*(PacketSize);
// TODO make sure this product is a real * complex and that the rhs is properly conjugated if needed res[j] += cjd.pmul(numext::real(A0[j]), t0);
res[j] += cjd.pmul(internal::real(A0[j]), t0); res[j+1] += cjd.pmul(numext::real(A1[j+1]), t1);
res[j+1] += cjd.pmul(internal::real(A1[j+1]), t1);
if(FirstTriangular) if(FirstTriangular)
{ {
res[j] += cj0.pmul(A1[j], t1); res[j] += cj0.pmul(A1[j], t1);
...@@ -103,11 +101,11 @@ static EIGEN_DONT_INLINE void run( ...@@ -103,11 +101,11 @@ static EIGEN_DONT_INLINE void run(
t2 += cj1.pmul(A0[j+1], rhs[j+1]); t2 += cj1.pmul(A0[j+1], rhs[j+1]);
} }
for (size_t i=starti; i<alignedStart; ++i) for (Index i=starti; i<alignedStart; ++i)
{ {
res[i] += t0 * A0[i] + t1 * A1[i]; res[i] += cj0.pmul(A0[i], t0) + cj0.pmul(A1[i],t1);
t2 += conj(A0[i]) * rhs[i]; t2 += cj1.pmul(A0[i], rhs[i]);
t3 += conj(A1[i]) * rhs[i]; t3 += cj1.pmul(A1[i], rhs[i]);
} }
// Yes this an optimization for gcc 4.3 and 4.4 (=> huge speed up) // Yes this an optimization for gcc 4.3 and 4.4 (=> huge speed up)
// gcc 4.2 does this optimization automatically. // gcc 4.2 does this optimization automatically.
...@@ -115,7 +113,7 @@ static EIGEN_DONT_INLINE void run( ...@@ -115,7 +113,7 @@ static EIGEN_DONT_INLINE void run(
const Scalar* EIGEN_RESTRICT a1It = A1 + alignedStart; const Scalar* EIGEN_RESTRICT a1It = A1 + alignedStart;
const Scalar* EIGEN_RESTRICT rhsIt = rhs + alignedStart; const Scalar* EIGEN_RESTRICT rhsIt = rhs + alignedStart;
Scalar* EIGEN_RESTRICT resIt = res + alignedStart; Scalar* EIGEN_RESTRICT resIt = res + alignedStart;
for (size_t i=alignedStart; i<alignedEnd; i+=PacketSize) for (Index i=alignedStart; i<alignedEnd; i+=PacketSize)
{ {
Packet A0i = ploadu<Packet>(a0It); a0It += PacketSize; Packet A0i = ploadu<Packet>(a0It); a0It += PacketSize;
Packet A1i = ploadu<Packet>(a1It); a1It += PacketSize; Packet A1i = ploadu<Packet>(a1It); a1It += PacketSize;
...@@ -127,7 +125,7 @@ static EIGEN_DONT_INLINE void run( ...@@ -127,7 +125,7 @@ static EIGEN_DONT_INLINE void run(
ptmp3 = pcj1.pmadd(A1i, Bi, ptmp3); ptmp3 = pcj1.pmadd(A1i, Bi, ptmp3);
pstore(resIt,Xi); resIt += PacketSize; pstore(resIt,Xi); resIt += PacketSize;
} }
for (size_t i=alignedEnd; i<endi; i++) for (Index i=alignedEnd; i<endi; i++)
{ {
res[i] += cj0.pmul(A0[i], t0) + cj0.pmul(A1[i],t1); res[i] += cj0.pmul(A0[i], t0) + cj0.pmul(A1[i],t1);
t2 += cj1.pmul(A0[i], rhs[i]); t2 += cj1.pmul(A0[i], rhs[i]);
...@@ -139,12 +137,11 @@ static EIGEN_DONT_INLINE void run( ...@@ -139,12 +137,11 @@ static EIGEN_DONT_INLINE void run(
} }
for (Index j=FirstTriangular ? 0 : bound;j<(FirstTriangular ? bound : size);j++) for (Index j=FirstTriangular ? 0 : bound;j<(FirstTriangular ? bound : size);j++)
{ {
register const Scalar* EIGEN_RESTRICT A0 = lhs + j*lhsStride; const Scalar* EIGEN_RESTRICT A0 = lhs + j*lhsStride;
Scalar t1 = cjAlpha * rhs[j]; Scalar t1 = cjAlpha * rhs[j];
Scalar t2(0); Scalar t2(0);
// TODO make sure this product is a real * complex and that the rhs is properly conjugated if needed res[j] += cjd.pmul(numext::real(A0[j]), t1);
res[j] += cjd.pmul(internal::real(A0[j]), t1);
for (Index i=FirstTriangular ? 0 : j+1; i<(FirstTriangular ? j : size); i++) for (Index i=FirstTriangular ? 0 : j+1; i<(FirstTriangular ? j : size); i++)
{ {
res[i] += cj0.pmul(A0[i], t1); res[i] += cj0.pmul(A0[i], t1);
...@@ -153,7 +150,6 @@ static EIGEN_DONT_INLINE void run( ...@@ -153,7 +150,6 @@ static EIGEN_DONT_INLINE void run(
res[j] += alpha * t2; res[j] += alpha * t2;
} }
} }
};
} // end namespace internal } // end namespace internal
...@@ -162,45 +158,44 @@ static EIGEN_DONT_INLINE void run( ...@@ -162,45 +158,44 @@ static EIGEN_DONT_INLINE void run(
***************************************************************************/ ***************************************************************************/
namespace internal { namespace internal {
template<typename Lhs, int LhsMode, typename Rhs>
struct traits<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true> >
: traits<ProductBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>, Lhs, Rhs> >
{};
}
template<typename Lhs, int LhsMode, typename Rhs> template<typename Lhs, int LhsMode, typename Rhs>
struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true> struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,0,true>
: public ProductBase<SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>, Lhs, Rhs >
{ {
EIGEN_PRODUCT_PUBLIC_INTERFACE(SelfadjointProductMatrix) typedef typename Product<Lhs,Rhs>::Scalar Scalar;
enum { typedef internal::blas_traits<Lhs> LhsBlasTraits;
LhsUpLo = LhsMode&(Upper|Lower) typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
}; typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {} typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
enum { LhsUpLo = LhsMode&(Upper|Lower) };
template<typename Dest>
static void run(Dest& dest, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
{ {
typedef typename Dest::Scalar ResScalar; typedef typename Dest::Scalar ResScalar;
typedef typename Base::RhsScalar RhsScalar; typedef typename Rhs::Scalar RhsScalar;
typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest; typedef Map<Matrix<ResScalar,Dynamic,1>, EIGEN_PLAIN_ENUM_MIN(AlignedMax,internal::packet_traits<ResScalar>::size)> MappedDest;
eigen_assert(dest.rows()==m_lhs.rows() && dest.cols()==m_rhs.cols()); eigen_assert(dest.rows()==a_lhs.rows() && dest.cols()==a_rhs.cols());
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(m_lhs); typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(m_rhs); typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
* RhsBlasTraits::extractScalarFactor(m_rhs); * RhsBlasTraits::extractScalarFactor(a_rhs);
enum { enum {
EvalToDest = (Dest::InnerStrideAtCompileTime==1), EvalToDest = (Dest::InnerStrideAtCompileTime==1),
UseRhs = (_ActualRhsType::InnerStrideAtCompileTime==1) UseRhs = (ActualRhsTypeCleaned::InnerStrideAtCompileTime==1)
}; };
internal::gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,!EvalToDest> static_dest; internal::gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,!EvalToDest> static_dest;
internal::gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!UseRhs> static_rhs; internal::gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!UseRhs> static_rhs;
ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(), ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
EvalToDest ? dest.data() : static_dest.data()); EvalToDest ? dest.data() : static_dest.data());
...@@ -211,7 +206,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true> ...@@ -211,7 +206,7 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>
if(!EvalToDest) if(!EvalToDest)
{ {
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
int size = dest.size(); Index size = dest.size();
EIGEN_DENSE_STORAGE_CTOR_PLUGIN EIGEN_DENSE_STORAGE_CTOR_PLUGIN
#endif #endif
MappedDest(actualDestPtr, dest.size()) = dest; MappedDest(actualDestPtr, dest.size()) = dest;
...@@ -220,18 +215,19 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true> ...@@ -220,18 +215,19 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>
if(!UseRhs) if(!UseRhs)
{ {
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
int size = rhs.size(); Index size = rhs.size();
EIGEN_DENSE_STORAGE_CTOR_PLUGIN EIGEN_DENSE_STORAGE_CTOR_PLUGIN
#endif #endif
Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, rhs.size()) = rhs; Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, rhs.size()) = rhs;
} }
internal::selfadjoint_matrix_vector_product<Scalar, Index, (internal::traits<_ActualLhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, int(LhsUpLo), bool(LhsBlasTraits::NeedToConjugate), bool(RhsBlasTraits::NeedToConjugate)>::run internal::selfadjoint_matrix_vector_product<Scalar, Index, (internal::traits<ActualLhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor,
int(LhsUpLo), bool(LhsBlasTraits::NeedToConjugate), bool(RhsBlasTraits::NeedToConjugate)>::run
( (
lhs.rows(), // size lhs.rows(), // size
&lhs.coeffRef(0,0), lhs.outerStride(), // lhs info &lhs.coeffRef(0,0), lhs.outerStride(), // lhs info
actualRhsPtr, 1, // rhs info actualRhsPtr, // rhs info
actualDestPtr, // result info actualDestPtr, // result info
actualAlpha // scale factor actualAlpha // scale factor
); );
...@@ -241,34 +237,24 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true> ...@@ -241,34 +237,24 @@ struct SelfadjointProductMatrix<Lhs,LhsMode,false,Rhs,0,true>
} }
}; };
namespace internal {
template<typename Lhs, typename Rhs, int RhsMode>
struct traits<SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false> >
: traits<ProductBase<SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false>, Lhs, Rhs> >
{};
}
template<typename Lhs, typename Rhs, int RhsMode> template<typename Lhs, typename Rhs, int RhsMode>
struct SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false> struct selfadjoint_product_impl<Lhs,0,true,Rhs,RhsMode,false>
: public ProductBase<SelfadjointProductMatrix<Lhs,0,true,Rhs,RhsMode,false>, Lhs, Rhs >
{ {
EIGEN_PRODUCT_PUBLIC_INTERFACE(SelfadjointProductMatrix) typedef typename Product<Lhs,Rhs>::Scalar Scalar;
enum { RhsUpLo = RhsMode&(Upper|Lower) };
enum { template<typename Dest>
RhsUpLo = RhsMode&(Upper|Lower) static void run(Dest& dest, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar& alpha)
};
SelfadjointProductMatrix(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
{ {
// let's simply transpose the product // let's simply transpose the product
Transpose<Dest> destT(dest); Transpose<Dest> destT(dest);
SelfadjointProductMatrix<Transpose<const Rhs>, int(RhsUpLo)==Upper ? Lower : Upper, false, selfadjoint_product_impl<Transpose<const Rhs>, int(RhsUpLo)==Upper ? Lower : Upper, false,
Transpose<const Lhs>, 0, true>(m_rhs.transpose(), m_lhs.transpose()).scaleAndAddTo(destT, alpha); Transpose<const Lhs>, 0, true>::run(destT, a_rhs.transpose(), a_lhs.transpose(), alpha);
} }
}; };
} // end namespace internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_H #endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_H
/*
Copyright (c) 2011, Intel Corporation. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of Intel Corporation nor the names of its contributors may
be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
********************************************************************************
* Content : Eigen bindings to BLAS F77
* Selfadjoint matrix-vector product functionality based on ?SYMV/?HEMV.
********************************************************************************
*/
#ifndef EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
#define EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
namespace Eigen {
namespace internal {
/**********************************************************************
* This file implements selfadjoint matrix-vector multiplication using BLAS
**********************************************************************/
// symv/hemv specialization
// Primary template of the BLAS-backed kernel: it adds nothing of its own and
// just inherits the BuiltIn version of selfadjoint_matrix_vector_product.
// Scalar-specific specializations elsewhere in this file replace it.
template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs>
struct selfadjoint_matrix_vector_product_symv
  : public selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,BuiltIn>
{
};
#define EIGEN_BLAS_SYMV_SPECIALIZE(Scalar) \
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
struct selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,Specialized> { \
static void run( \
Index size, const Scalar* lhs, Index lhsStride, \
const Scalar* _rhs, Scalar* res, Scalar alpha) { \
enum {\
IsColMajor = StorageOrder==ColMajor \
}; \
if (IsColMajor == ConjugateLhs) {\
selfadjoint_matrix_vector_product<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs,BuiltIn>::run( \
size, lhs, lhsStride, _rhs, res, alpha); \
} else {\
selfadjoint_matrix_vector_product_symv<Scalar,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs>::run( \
size, lhs, lhsStride, _rhs, res, alpha); \
}\
} \
}; \
EIGEN_BLAS_SYMV_SPECIALIZE(double)
EIGEN_BLAS_SYMV_SPECIALIZE(float)
EIGEN_BLAS_SYMV_SPECIALIZE(dcomplex)
EIGEN_BLAS_SYMV_SPECIALIZE(scomplex)
/*
 * EIGEN_BLAS_SYMV_SPECIALIZATION: emits the specialization of
 * selfadjoint_matrix_vector_product_symv for one scalar type; its run()
 * forwards the selfadjoint matrix * vector product to BLASFUNC.
 *
 *   EIGTYPE  : Eigen scalar type of matrix, vectors and alpha.
 *   BLASTYPE : scalar type the operand pointers are cast to for BLASFUNC.
 *   BLASFUNC : name of the BLAS routine to call.
 *
 * Both vectors are passed with unit increment and beta is fixed to 1.
 * ConjugateLhs is not consulted here.
 */
#define EIGEN_BLAS_SYMV_SPECIALIZATION(EIGTYPE,BLASTYPE,BLASFUNC) \
template<typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs> \
struct selfadjoint_matrix_vector_product_symv<EIGTYPE,Index,StorageOrder,UpLo,ConjugateLhs,ConjugateRhs> \
{ \
  typedef Matrix<EIGTYPE,Dynamic,1,ColMajor> SYMVVector; \
\
  static void run( \
    Index size, const EIGTYPE* lhs, Index lhsStride, \
    const EIGTYPE* _rhs, EIGTYPE* res, EIGTYPE alpha) \
  { \
    enum { \
      IsRowMajor = StorageOrder==RowMajor ? 1 : 0, \
      IsLower = UpLo == Lower ? 1 : 0 \
    }; \
\
    BlasIndex n    = convert_index<BlasIndex>(size); \
    BlasIndex lda  = convert_index<BlasIndex>(lhsStride); \
    BlasIndex incx = 1; \
    BlasIndex incy = 1; \
    EIGTYPE beta(1); \
\
    /* 'L' exactly when one of (row-major, lower) holds but not both; 'U' otherwise. */ \
    char uplo = (int(IsRowMajor) != int(IsLower)) ? 'L' : 'U'; \
\
    /* Use the rhs in place unless it has to be conjugated; the copy must outlive the BLAS call. */ \
    SYMVVector rhsConj; \
    const EIGTYPE* x_ptr = _rhs; \
    if (ConjugateRhs) { \
      Map<const SYMVVector, 0 > rhsView(_rhs,size,1); \
      rhsConj = rhsView.conjugate(); \
      x_ptr = rhsConj.data(); \
    } \
\
    BLASFUNC(&uplo, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)lhs, &lda, (const BLASTYPE*)x_ptr, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &incy); \
  } \
};
#ifdef EIGEN_USE_MKL
EIGEN_BLAS_SYMV_SPECIALIZATION(double, double, dsymv)
EIGEN_BLAS_SYMV_SPECIALIZATION(float, float, ssymv)
EIGEN_BLAS_SYMV_SPECIALIZATION(dcomplex, MKL_Complex16, zhemv)
EIGEN_BLAS_SYMV_SPECIALIZATION(scomplex, MKL_Complex8, chemv)
#else
EIGEN_BLAS_SYMV_SPECIALIZATION(double, double, dsymv_)
EIGEN_BLAS_SYMV_SPECIALIZATION(float, float, ssymv_)
EIGEN_BLAS_SYMV_SPECIALIZATION(dcomplex, double, zhemv_)
EIGEN_BLAS_SYMV_SPECIALIZATION(scomplex, float, chemv_)
#endif
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_BLAS_H
...@@ -18,21 +18,19 @@ ...@@ -18,21 +18,19 @@
namespace Eigen { namespace Eigen {
template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update;
template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs> template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs> struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs>
{ {
static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha) static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, const Scalar& alpha)
{ {
internal::conj_if<ConjRhs> cj; internal::conj_if<ConjRhs> cj;
typedef Map<const Matrix<Scalar,Dynamic,1> > OtherMap; typedef Map<const Matrix<Scalar,Dynamic,1> > OtherMap;
typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjRhsType; typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjLhsType;
for (Index i=0; i<size; ++i) for (Index i=0; i<size; ++i)
{ {
Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1))) Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1)))
+= (alpha * cj(vec[i])) * ConjRhsType(OtherMap(vec+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1))); += (alpha * cj(vecY[i])) * ConjLhsType(OtherMap(vecX+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1)));
} }
} }
}; };
...@@ -40,9 +38,9 @@ struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs> ...@@ -40,9 +38,9 @@ struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs>
template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs> template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update<Scalar,Index,RowMajor,UpLo,ConjLhs,ConjRhs> struct selfadjoint_rank1_update<Scalar,Index,RowMajor,UpLo,ConjLhs,ConjRhs>
{ {
static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha) static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, const Scalar& alpha)
{ {
selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo==Lower?Upper:Lower,ConjRhs,ConjLhs>::run(size,mat,stride,vec,alpha); selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo==Lower?Upper:Lower,ConjRhs,ConjLhs>::run(size,mat,stride,vecY,vecX,alpha);
} }
}; };
...@@ -52,10 +50,9 @@ struct selfadjoint_product_selector; ...@@ -52,10 +50,9 @@ struct selfadjoint_product_selector;
template<typename MatrixType, typename OtherType, int UpLo> template<typename MatrixType, typename OtherType, int UpLo>
struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true> struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true>
{ {
static void run(MatrixType& mat, const OtherType& other, typename MatrixType::Scalar alpha) static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha)
{ {
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::Index Index;
typedef internal::blas_traits<OtherType> OtherBlasTraits; typedef internal::blas_traits<OtherType> OtherBlasTraits;
typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType; typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType; typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType;
...@@ -78,17 +75,16 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true> ...@@ -78,17 +75,16 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true>
selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo, selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo,
OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
(!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex> (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex>
::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualAlpha); ::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualOtherPtr, actualAlpha);
} }
}; };
template<typename MatrixType, typename OtherType, int UpLo> template<typename MatrixType, typename OtherType, int UpLo>
struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false> struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false>
{ {
static void run(MatrixType& mat, const OtherType& other, typename MatrixType::Scalar alpha) static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha)
{ {
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::Index Index;
typedef internal::blas_traits<OtherType> OtherBlasTraits; typedef internal::blas_traits<OtherType> OtherBlasTraits;
typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType; typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType; typedef typename internal::remove_all<ActualOtherType>::type _ActualOtherType;
...@@ -96,15 +92,27 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false> ...@@ -96,15 +92,27 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false>
Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived()); Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());
enum { IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0 }; enum {
IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0,
OtherIsRowMajor = _ActualOtherType::Flags&RowMajorBit ? 1 : 0
};
Index size = mat.cols();
Index depth = actualOther.cols();
typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor,Scalar,Scalar,
MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime, _ActualOtherType::MaxColsAtCompileTime> BlockingType;
BlockingType blocking(size, size, depth, 1, false);
internal::general_matrix_matrix_triangular_product<Index, internal::general_matrix_matrix_triangular_product<Index,
Scalar, _ActualOtherType::Flags&RowMajorBit ? RowMajor : ColMajor, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex, Scalar, OtherIsRowMajor ? RowMajor : ColMajor, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
Scalar, _ActualOtherType::Flags&RowMajorBit ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex, Scalar, OtherIsRowMajor ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex,
MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo> IsRowMajor ? RowMajor : ColMajor, MatrixType::InnerStrideAtCompileTime, UpLo>
::run(mat.cols(), actualOther.cols(), ::run(size, depth,
&actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(),
mat.data(), mat.outerStride(), actualAlpha); mat.data(), mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
} }
}; };
...@@ -113,7 +121,7 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false> ...@@ -113,7 +121,7 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false>
template<typename MatrixType, unsigned int UpLo> template<typename MatrixType, unsigned int UpLo>
template<typename DerivedU> template<typename DerivedU>
SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo> SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo>
::rankUpdate(const MatrixBase<DerivedU>& u, Scalar alpha) ::rankUpdate(const MatrixBase<DerivedU>& u, const Scalar& alpha)
{ {
selfadjoint_product_selector<MatrixType,DerivedU,UpLo>::run(_expression().const_cast_derived(), u.derived(), alpha); selfadjoint_product_selector<MatrixType,DerivedU,UpLo>::run(_expression().const_cast_derived(), u.derived(), alpha);
......
...@@ -24,14 +24,14 @@ struct selfadjoint_rank2_update_selector; ...@@ -24,14 +24,14 @@ struct selfadjoint_rank2_update_selector;
template<typename Scalar, typename Index, typename UType, typename VType> template<typename Scalar, typename Index, typename UType, typename VType>
struct selfadjoint_rank2_update_selector<Scalar,Index,UType,VType,Lower> struct selfadjoint_rank2_update_selector<Scalar,Index,UType,VType,Lower>
{ {
static void run(Scalar* mat, Index stride, const UType& u, const VType& v, Scalar alpha) static void run(Scalar* mat, Index stride, const UType& u, const VType& v, const Scalar& alpha)
{ {
const Index size = u.size(); const Index size = u.size();
for (Index i=0; i<size; ++i) for (Index i=0; i<size; ++i)
{ {
Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i+i, size-i) += Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i+i, size-i) +=
(conj(alpha) * conj(u.coeff(i))) * v.tail(size-i) (numext::conj(alpha) * numext::conj(u.coeff(i))) * v.tail(size-i)
+ (alpha * conj(v.coeff(i))) * u.tail(size-i); + (alpha * numext::conj(v.coeff(i))) * u.tail(size-i);
} }
} }
}; };
...@@ -39,13 +39,13 @@ struct selfadjoint_rank2_update_selector<Scalar,Index,UType,VType,Lower> ...@@ -39,13 +39,13 @@ struct selfadjoint_rank2_update_selector<Scalar,Index,UType,VType,Lower>
template<typename Scalar, typename Index, typename UType, typename VType> template<typename Scalar, typename Index, typename UType, typename VType>
struct selfadjoint_rank2_update_selector<Scalar,Index,UType,VType,Upper> struct selfadjoint_rank2_update_selector<Scalar,Index,UType,VType,Upper>
{ {
static void run(Scalar* mat, Index stride, const UType& u, const VType& v, Scalar alpha) static void run(Scalar* mat, Index stride, const UType& u, const VType& v, const Scalar& alpha)
{ {
const Index size = u.size(); const Index size = u.size();
for (Index i=0; i<size; ++i) for (Index i=0; i<size; ++i)
Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i, i+1) += Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i, i+1) +=
(conj(alpha) * conj(u.coeff(i))) * v.head(i+1) (numext::conj(alpha) * numext::conj(u.coeff(i))) * v.head(i+1)
+ (alpha * conj(v.coeff(i))) * u.head(i+1); + (alpha * numext::conj(v.coeff(i))) * u.head(i+1);
} }
}; };
...@@ -58,7 +58,7 @@ template<bool Cond, typename T> struct conj_expr_if ...@@ -58,7 +58,7 @@ template<bool Cond, typename T> struct conj_expr_if
template<typename MatrixType, unsigned int UpLo> template<typename MatrixType, unsigned int UpLo>
template<typename DerivedU, typename DerivedV> template<typename DerivedU, typename DerivedV>
SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo> SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo>
::rankUpdate(const MatrixBase<DerivedU>& u, const MatrixBase<DerivedV>& v, Scalar alpha) ::rankUpdate(const MatrixBase<DerivedU>& u, const MatrixBase<DerivedV>& v, const Scalar& alpha)
{ {
typedef internal::blas_traits<DerivedU> UBlasTraits; typedef internal::blas_traits<DerivedU> UBlasTraits;
typedef typename UBlasTraits::DirectLinearAccessType ActualUType; typedef typename UBlasTraits::DirectLinearAccessType ActualUType;
...@@ -75,15 +75,15 @@ SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo> ...@@ -75,15 +75,15 @@ SelfAdjointView<MatrixType,UpLo>& SelfAdjointView<MatrixType,UpLo>
enum { IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0 }; enum { IsRowMajor = (internal::traits<MatrixType>::Flags&RowMajorBit) ? 1 : 0 };
Scalar actualAlpha = alpha * UBlasTraits::extractScalarFactor(u.derived()) Scalar actualAlpha = alpha * UBlasTraits::extractScalarFactor(u.derived())
* internal::conj(VBlasTraits::extractScalarFactor(v.derived())); * numext::conj(VBlasTraits::extractScalarFactor(v.derived()));
if (IsRowMajor) if (IsRowMajor)
actualAlpha = internal::conj(actualAlpha); actualAlpha = numext::conj(actualAlpha);
internal::selfadjoint_rank2_update_selector<Scalar, Index, typedef typename internal::remove_all<typename internal::conj_expr_if<IsRowMajor ^ UBlasTraits::NeedToConjugate,_ActualUType>::type>::type UType;
typename internal::remove_all<typename internal::conj_expr_if<IsRowMajor ^ UBlasTraits::NeedToConjugate,_ActualUType>::type>::type, typedef typename internal::remove_all<typename internal::conj_expr_if<IsRowMajor ^ VBlasTraits::NeedToConjugate,_ActualVType>::type>::type VType;
typename internal::remove_all<typename internal::conj_expr_if<IsRowMajor ^ VBlasTraits::NeedToConjugate,_ActualVType>::type>::type, internal::selfadjoint_rank2_update_selector<Scalar, Index, UType, VType,
(IsRowMajor ? int(UpLo==Upper ? Lower : Upper) : UpLo)> (IsRowMajor ? int(UpLo==Upper ? Lower : Upper) : UpLo)>
::run(_expression().const_cast_derived().data(),_expression().outerStride(),actualU,actualV,actualAlpha); ::run(_expression().const_cast_derived().data(),_expression().outerStride(),UType(actualU),VType(actualV),actualAlpha);
return *this; return *this;
} }
......
...@@ -45,23 +45,25 @@ template <typename Scalar, typename Index, ...@@ -45,23 +45,25 @@ template <typename Scalar, typename Index,
int Mode, bool LhsIsTriangular, int Mode, bool LhsIsTriangular,
int LhsStorageOrder, bool ConjugateLhs, int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs, int RhsStorageOrder, bool ConjugateRhs,
int ResStorageOrder, int Version = Specialized> int ResStorageOrder, int ResInnerStride,
int Version = Specialized>
struct product_triangular_matrix_matrix; struct product_triangular_matrix_matrix;
template <typename Scalar, typename Index, template <typename Scalar, typename Index,
int Mode, bool LhsIsTriangular, int Mode, bool LhsIsTriangular,
int LhsStorageOrder, bool ConjugateLhs, int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs, int Version> int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int Version>
struct product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular, struct product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular,
LhsStorageOrder,ConjugateLhs, LhsStorageOrder,ConjugateLhs,
RhsStorageOrder,ConjugateRhs,RowMajor,Version> RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride,Version>
{ {
static EIGEN_STRONG_INLINE void run( static EIGEN_STRONG_INLINE void run(
Index rows, Index cols, Index depth, Index rows, Index cols, Index depth,
const Scalar* lhs, Index lhsStride, const Scalar* lhs, Index lhsStride,
const Scalar* rhs, Index rhsStride, const Scalar* rhs, Index rhsStride,
Scalar* res, Index resStride, Scalar* res, Index resIncr, Index resStride,
Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{ {
product_triangular_matrix_matrix<Scalar, Index, product_triangular_matrix_matrix<Scalar, Index,
(Mode&(UnitDiag|ZeroDiag)) | ((Mode&Upper) ? Lower : Upper), (Mode&(UnitDiag|ZeroDiag)) | ((Mode&Upper) ? Lower : Upper),
...@@ -70,18 +72,19 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular, ...@@ -70,18 +72,19 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular,
ConjugateRhs, ConjugateRhs,
LhsStorageOrder==RowMajor ? ColMajor : RowMajor, LhsStorageOrder==RowMajor ? ColMajor : RowMajor,
ConjugateLhs, ConjugateLhs,
ColMajor> ColMajor, ResInnerStride>
::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha, blocking); ::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride, res, resIncr, resStride, alpha, blocking);
} }
}; };
// implements col-major += alpha * op(triangular) * op(general) // implements col-major += alpha * op(triangular) * op(general)
template <typename Scalar, typename Index, int Mode, template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs, int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs, int Version> int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int Version>
struct product_triangular_matrix_matrix<Scalar,Index,Mode,true, struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
LhsStorageOrder,ConjugateLhs, LhsStorageOrder,ConjugateLhs,
RhsStorageOrder,ConjugateRhs,ColMajor,Version> RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>
{ {
typedef gebp_traits<Scalar,Scalar> Traits; typedef gebp_traits<Scalar,Scalar> Traits;
...@@ -95,8 +98,22 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true, ...@@ -95,8 +98,22 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
Index _rows, Index _cols, Index _depth, Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride, const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride, const Scalar* _rhs, Index rhsStride,
Scalar* res, Index resStride, Scalar* res, Index resIncr, Index resStride,
Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int Version>
EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,true,
LhsStorageOrder,ConjugateLhs,
RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>::run(
Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{ {
// strip zeros // strip zeros
Index diagSize = (std::min)(_rows,_depth); Index diagSize = (std::min)(_rows,_depth);
...@@ -104,30 +121,42 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true, ...@@ -104,30 +121,42 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
Index depth = IsLower ? diagSize : _depth; Index depth = IsLower ? diagSize : _depth;
Index cols = _cols; Index cols = _cols;
const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
ResMapper res(_res, resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
// The small panel size must not be larger than blocking size.
// Usually this should never be the case because SmallPanelWidth^2 is very small
// compared to L2 cache size, but let's be safe:
Index panelWidth = (std::min)(Index(SmallPanelWidth),(std::min)(kc,mc));
std::size_t sizeA = kc*mc; std::size_t sizeA = kc*mc;
std::size_t sizeB = kc*cols; std::size_t sizeB = kc*cols;
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());
Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,LhsStorageOrder> triangularBuffer; // To work around an "error: member reference base type 'Matrix<...>
// (Eigen::internal::constructor_without_unaligned_array_assert (*)())' is
// not a structure or union" compilation error in nvcc (tested V8.0.61),
// create a dummy internal::constructor_without_unaligned_array_assert
// object to pass to the Matrix constructor.
internal::constructor_without_unaligned_array_assert a;
Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,LhsStorageOrder> triangularBuffer(a);
triangularBuffer.setZero(); triangularBuffer.setZero();
if((Mode&ZeroDiag)==ZeroDiag) if((Mode&ZeroDiag)==ZeroDiag)
triangularBuffer.diagonal().setZero(); triangularBuffer.diagonal().setZero();
else else
triangularBuffer.diagonal().setOnes(); triangularBuffer.diagonal().setOnes();
gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs; gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs; gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr,RhsStorageOrder> pack_rhs;
for(Index k2=IsLower ? depth : 0; for(Index k2=IsLower ? depth : 0;
IsLower ? k2>0 : k2<depth; IsLower ? k2>0 : k2<depth;
...@@ -143,7 +172,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true, ...@@ -143,7 +172,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
k2 = k2+actual_kc-kc; k2 = k2+actual_kc-kc;
} }
pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, actual_kc, cols); pack_rhs(blockB, rhs.getSubMapper(actual_k2,0), actual_kc, cols);
// the selected lhs's panel has to be split in three different parts: // the selected lhs's panel has to be split in three different parts:
// 1 - the part which is zero => skip it // 1 - the part which is zero => skip it
...@@ -154,9 +183,9 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true, ...@@ -154,9 +183,9 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
if(IsLower || actual_k2<rows) if(IsLower || actual_k2<rows)
{ {
// for each small vertical panels of lhs // for each small vertical panels of lhs
for (Index k1=0; k1<actual_kc; k1+=SmallPanelWidth) for (Index k1=0; k1<actual_kc; k1+=panelWidth)
{ {
Index actualPanelWidth = std::min<Index>(actual_kc-k1, SmallPanelWidth); Index actualPanelWidth = std::min<Index>(actual_kc-k1, panelWidth);
Index lengthTarget = IsLower ? actual_kc-k1-actualPanelWidth : k1; Index lengthTarget = IsLower ? actual_kc-k1-actualPanelWidth : k1;
Index startBlock = actual_k2+k1; Index startBlock = actual_k2+k1;
Index blockBOffset = k1; Index blockBOffset = k1;
...@@ -171,20 +200,22 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true, ...@@ -171,20 +200,22 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
for (Index i=IsLower ? k+1 : 0; IsLower ? i<actualPanelWidth : i<k; ++i) for (Index i=IsLower ? k+1 : 0; IsLower ? i<actualPanelWidth : i<k; ++i)
triangularBuffer.coeffRef(i,k) = lhs(startBlock+i,startBlock+k); triangularBuffer.coeffRef(i,k) = lhs(startBlock+i,startBlock+k);
} }
pack_lhs(blockA, triangularBuffer.data(), triangularBuffer.outerStride(), actualPanelWidth, actualPanelWidth); pack_lhs(blockA, LhsMapper(triangularBuffer.data(), triangularBuffer.outerStride()), actualPanelWidth, actualPanelWidth);
gebp_kernel(res+startBlock, resStride, blockA, blockB, actualPanelWidth, actualPanelWidth, cols, alpha, gebp_kernel(res.getSubMapper(startBlock, 0), blockA, blockB,
actualPanelWidth, actual_kc, 0, blockBOffset, blockW); actualPanelWidth, actualPanelWidth, cols, alpha,
actualPanelWidth, actual_kc, 0, blockBOffset);
// GEBP with remaining micro panel // GEBP with remaining micro panel
if (lengthTarget>0) if (lengthTarget>0)
{ {
Index startTarget = IsLower ? actual_k2+k1+actualPanelWidth : actual_k2; Index startTarget = IsLower ? actual_k2+k1+actualPanelWidth : actual_k2;
pack_lhs(blockA, &lhs(startTarget,startBlock), lhsStride, actualPanelWidth, lengthTarget); pack_lhs(blockA, lhs.getSubMapper(startTarget,startBlock), actualPanelWidth, lengthTarget);
gebp_kernel(res+startTarget, resStride, blockA, blockB, lengthTarget, actualPanelWidth, cols, alpha, gebp_kernel(res.getSubMapper(startTarget, 0), blockA, blockB,
actualPanelWidth, actual_kc, 0, blockBOffset, blockW); lengthTarget, actualPanelWidth, cols, alpha,
actualPanelWidth, actual_kc, 0, blockBOffset);
} }
} }
} }
...@@ -195,23 +226,24 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true, ...@@ -195,23 +226,24 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
for(Index i2=start; i2<end; i2+=mc) for(Index i2=start; i2<end; i2+=mc)
{ {
const Index actual_mc = (std::min)(i2+mc,end)-i2; const Index actual_mc = (std::min)(i2+mc,end)-i2;
gemm_pack_lhs<Scalar, Index, Traits::mr,Traits::LhsProgress, LhsStorageOrder,false>() gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr,Traits::LhsProgress, LhsStorageOrder,false>()
(blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc); (blockA, lhs.getSubMapper(i2, actual_k2), actual_kc, actual_mc);
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha, -1, -1, 0, 0, blockW); gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc,
actual_kc, cols, alpha, -1, -1, 0, 0);
} }
} }
} }
} }
};
// implements col-major += alpha * op(general) * op(triangular) // implements col-major += alpha * op(general) * op(triangular)
template <typename Scalar, typename Index, int Mode, template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs, int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs, int Version> int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int Version>
struct product_triangular_matrix_matrix<Scalar,Index,Mode,false, struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
LhsStorageOrder,ConjugateLhs, LhsStorageOrder,ConjugateLhs,
RhsStorageOrder,ConjugateRhs,ColMajor,Version> RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>
{ {
typedef gebp_traits<Scalar,Scalar> Traits; typedef gebp_traits<Scalar,Scalar> Traits;
enum { enum {
...@@ -224,40 +256,58 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false, ...@@ -224,40 +256,58 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
Index _rows, Index _cols, Index _depth, Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride, const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride, const Scalar* _rhs, Index rhsStride,
Scalar* res, Index resStride, Scalar* res, Index resIncr, Index resStride,
Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int Version>
EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,false,
LhsStorageOrder,ConjugateLhs,
RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>::run(
Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{ {
const Index PacketBytes = packet_traits<Scalar>::size*sizeof(Scalar);
// strip zeros // strip zeros
Index diagSize = (std::min)(_cols,_depth); Index diagSize = (std::min)(_cols,_depth);
Index rows = _rows; Index rows = _rows;
Index depth = IsLower ? _depth : diagSize; Index depth = IsLower ? _depth : diagSize;
Index cols = IsLower ? diagSize : _cols; Index cols = IsLower ? diagSize : _cols;
const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
ResMapper res(_res, resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
std::size_t sizeA = kc*mc; std::size_t sizeA = kc*mc;
std::size_t sizeB = kc*cols; std::size_t sizeB = kc*cols+EIGEN_MAX_ALIGN_BYTES/sizeof(Scalar);
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());
Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,RhsStorageOrder> triangularBuffer; internal::constructor_without_unaligned_array_assert a;
Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,RhsStorageOrder> triangularBuffer(a);
triangularBuffer.setZero(); triangularBuffer.setZero();
if((Mode&ZeroDiag)==ZeroDiag) if((Mode&ZeroDiag)==ZeroDiag)
triangularBuffer.diagonal().setZero(); triangularBuffer.diagonal().setZero();
else else
triangularBuffer.diagonal().setOnes(); triangularBuffer.diagonal().setOnes();
gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel; gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs; gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs; gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr,RhsStorageOrder> pack_rhs;
gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder,false,true> pack_rhs_panel; gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr,RhsStorageOrder,false,true> pack_rhs_panel;
for(Index k2=IsLower ? 0 : depth; for(Index k2=IsLower ? 0 : depth;
IsLower ? k2<depth : k2>0; IsLower ? k2<depth : k2>0;
...@@ -279,8 +329,9 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false, ...@@ -279,8 +329,9 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
Index ts = (IsLower && actual_k2>=cols) ? 0 : actual_kc; Index ts = (IsLower && actual_k2>=cols) ? 0 : actual_kc;
Scalar* geb = blockB+ts*ts; Scalar* geb = blockB+ts*ts;
geb = geb + internal::first_aligned<PacketBytes>(geb,PacketBytes/sizeof(Scalar));
pack_rhs(geb, &rhs(actual_k2,IsLower ? 0 : k2), rhsStride, actual_kc, rs); pack_rhs(geb, rhs.getSubMapper(actual_k2,IsLower ? 0 : k2), actual_kc, rs);
// pack the triangular part of the rhs padding the unrolled blocks with zeros // pack the triangular part of the rhs padding the unrolled blocks with zeros
if(ts>0) if(ts>0)
...@@ -293,7 +344,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false, ...@@ -293,7 +344,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
Index panelLength = IsLower ? actual_kc-j2-actualPanelWidth : j2; Index panelLength = IsLower ? actual_kc-j2-actualPanelWidth : j2;
// general part // general part
pack_rhs_panel(blockB+j2*actual_kc, pack_rhs_panel(blockB+j2*actual_kc,
&rhs(actual_k2+panelOffset, actual_j2), rhsStride, rhs.getSubMapper(actual_k2+panelOffset, actual_j2),
panelLength, actualPanelWidth, panelLength, actualPanelWidth,
actual_kc, panelOffset); actual_kc, panelOffset);
...@@ -307,7 +358,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false, ...@@ -307,7 +358,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
} }
pack_rhs_panel(blockB+j2*actual_kc, pack_rhs_panel(blockB+j2*actual_kc,
triangularBuffer.data(), triangularBuffer.outerStride(), RhsMapper(triangularBuffer.data(), triangularBuffer.outerStride()),
actualPanelWidth, actualPanelWidth, actualPanelWidth, actualPanelWidth,
actual_kc, j2); actual_kc, j2);
} }
...@@ -316,7 +367,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false, ...@@ -316,7 +367,7 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
for (Index i2=0; i2<rows; i2+=mc) for (Index i2=0; i2<rows; i2+=mc)
{ {
const Index actual_mc = (std::min)(mc,rows-i2); const Index actual_mc = (std::min)(mc,rows-i2);
pack_lhs(blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc); pack_lhs(blockA, lhs.getSubMapper(i2, actual_k2), actual_kc, actual_mc);
// triangular kernel // triangular kernel
if(ts>0) if(ts>0)
...@@ -327,50 +378,51 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false, ...@@ -327,50 +378,51 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
Index panelLength = IsLower ? actual_kc-j2 : j2+actualPanelWidth; Index panelLength = IsLower ? actual_kc-j2 : j2+actualPanelWidth;
Index blockOffset = IsLower ? j2 : 0; Index blockOffset = IsLower ? j2 : 0;
gebp_kernel(res+i2+(actual_k2+j2)*resStride, resStride, gebp_kernel(res.getSubMapper(i2, actual_k2 + j2),
blockA, blockB+j2*actual_kc, blockA, blockB+j2*actual_kc,
actual_mc, panelLength, actualPanelWidth, actual_mc, panelLength, actualPanelWidth,
alpha, alpha,
actual_kc, actual_kc, // strides actual_kc, actual_kc, // strides
blockOffset, blockOffset,// offsets blockOffset, blockOffset);// offsets
blockW); // workspace
} }
} }
gebp_kernel(res+i2+(IsLower ? 0 : k2)*resStride, resStride, gebp_kernel(res.getSubMapper(i2, IsLower ? 0 : k2),
blockA, geb, actual_mc, actual_kc, rs, blockA, geb, actual_mc, actual_kc, rs,
alpha, alpha,
-1, -1, 0, 0, blockW); -1, -1, 0, 0);
} }
} }
} }
};
/*************************************************************************** /***************************************************************************
* Wrapper to product_triangular_matrix_matrix * Wrapper to product_triangular_matrix_matrix
***************************************************************************/ ***************************************************************************/
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> >
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs> >
{};
} // end namespace internal } // end namespace internal
namespace internal {
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs> template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
: public ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs >
{ {
EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct) template<typename Dest> static void run(Dest& dst, const Lhs &a_lhs, const Rhs &a_rhs, const typename Dest::Scalar& alpha)
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{ {
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(m_lhs); typedef typename Lhs::Scalar LhsScalar;
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(m_rhs); typedef typename Rhs::Scalar RhsScalar;
typedef typename Dest::Scalar Scalar;
typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(a_lhs);
* RhsBlasTraits::extractScalarFactor(m_rhs); RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(a_rhs);
Scalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar, typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar,
Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,4> BlockingType; Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,4> BlockingType;
...@@ -381,23 +433,40 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> ...@@ -381,23 +433,40 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
Index stripedDepth = LhsIsTriangular ? ((!IsLower) ? lhs.cols() : (std::min)(lhs.cols(),lhs.rows())) Index stripedDepth = LhsIsTriangular ? ((!IsLower) ? lhs.cols() : (std::min)(lhs.cols(),lhs.rows()))
: ((IsLower) ? rhs.rows() : (std::min)(rhs.rows(),rhs.cols())); : ((IsLower) ? rhs.rows() : (std::min)(rhs.rows(),rhs.cols()));
BlockingType blocking(stripedRows, stripedCols, stripedDepth); BlockingType blocking(stripedRows, stripedCols, stripedDepth, 1, false);
internal::product_triangular_matrix_matrix<Scalar, Index, internal::product_triangular_matrix_matrix<Scalar, Index,
Mode, LhsIsTriangular, Mode, LhsIsTriangular,
(internal::traits<_ActualLhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, (internal::traits<ActualLhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
(internal::traits<_ActualRhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, (internal::traits<ActualRhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
(internal::traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor> (internal::traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor, Dest::InnerStrideAtCompileTime>
::run( ::run(
stripedRows, stripedCols, stripedDepth, // sizes stripedRows, stripedCols, stripedDepth, // sizes
&lhs.coeffRef(0,0), lhs.outerStride(), // lhs info &lhs.coeffRef(0,0), lhs.outerStride(), // lhs info
&rhs.coeffRef(0,0), rhs.outerStride(), // rhs info &rhs.coeffRef(0,0), rhs.outerStride(), // rhs info
&dst.coeffRef(0,0), dst.outerStride(), // result info &dst.coeffRef(0,0), dst.innerStride(), dst.outerStride(), // result info
actualAlpha, blocking actualAlpha, blocking
); );
// Apply correction if the diagonal is unit and a scalar factor was nested:
if ((Mode&UnitDiag)==UnitDiag)
{
if (LhsIsTriangular && lhs_alpha!=LhsScalar(1))
{
Index diagSize = (std::min)(lhs.rows(),lhs.cols());
dst.topRows(diagSize) -= ((lhs_alpha-LhsScalar(1))*a_rhs).topRows(diagSize);
}
else if ((!LhsIsTriangular) && rhs_alpha!=RhsScalar(1))
{
Index diagSize = (std::min)(rhs.rows(),rhs.cols());
dst.leftCols(diagSize) -= (rhs_alpha-RhsScalar(1))*a_lhs.leftCols(diagSize);
}
}
} }
}; };
} // end namespace internal
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_H #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_H
/*
Copyright (c) 2011, Intel Corporation. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of Intel Corporation nor the names of its contributors may
be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
********************************************************************************
* Content : Eigen bindings to BLAS F77
* Triangular matrix * matrix product functionality based on ?TRMM.
********************************************************************************
*/
#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
#define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
namespace Eigen {
namespace internal {
// Intermediate dispatch type used by the BLAS bindings in this file.
// The primary template adds nothing of its own: it publicly derives from the
// `BuiltIn` flavour of product_triangular_matrix_matrix (with the integer
// argument that follows the result storage order fixed to 1), so every
// parameter combination without a dedicated partial specialization simply
// inherits that implementation.
template <typename Scalar, typename Index,
          int Mode, bool LhsIsTriangular,
          int LhsStorageOrder, bool ConjugateLhs,
          int RhsStorageOrder, bool ConjugateRhs,
          int ResStorageOrder>
struct product_triangular_matrix_matrix_trmm
  : public product_triangular_matrix_matrix<
        Scalar, Index, Mode, LhsIsTriangular,
        LhsStorageOrder, ConjugateLhs,
        RhsStorageOrder, ConjugateRhs,
        ResStorageOrder, 1, BuiltIn>
{
};
// try to go to BLAS specialization
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,1,Specialized> { \
static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\
const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
eigen_assert(resIncr == 1); \
product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \
LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \
RhsStorageOrder, ConjugateRhs, ColMajor>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
} \
};
EIGEN_BLAS_TRMM_SPECIALIZE(double, true)
EIGEN_BLAS_TRMM_SPECIALIZE(double, false)
EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true)
EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false)
EIGEN_BLAS_TRMM_SPECIALIZE(float, true)
EIGEN_BLAS_TRMM_SPECIALIZE(float, false)
EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
// implements col-major += alpha * op(triangular) * op(general)
//
// EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) specializes
// product_triangular_matrix_matrix_trmm for a triangular left-hand side and a
// column-major result:
//   EIGTYPE   - Eigen scalar type the specialization is generated for
//   BLASTYPE  - type the alpha/a/b pointers are cast to for the BLAS call
//   EIGPREFIX - suffix pasted onto MatrixX to name the dense temporary type
//   BLASFUNC  - the ?trmm routine to call
//
// run() takes one of three routes:
//   1. trimmed lhs not square, small excess -> Eigen's BuiltIn triangular kernel
//   2. trimmed lhs not square, larger excess -> dense copy of the triangular
//      view, then general_matrix_matrix_product
//   3. trimmed lhs square                   -> BLASFUNC on a copy of the rhs,
//      whose contents are then added into res
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
/* Compile-time flags derived from Mode and the lhs layout.                 */ \
/* SetDiag is 0 when the diagonal must be overwritten (zero or unit diag).  */ \
/* conjA is 1 when the lhs must be conjugated explicitly: for a row-major   */ \
/* lhs the conjugation is requested through transa = 'C' instead.           */ \
enum { \
IsLower = (Mode&Lower) == Lower, \
SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
LowUp = IsLower ? Lower : Upper, \
conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \
}; \
\
static void run( \
Index _rows, Index _cols, Index _depth, \
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
{ \
/* Trim the lhs: a lower factor keeps all rows, an upper one keeps all columns. */ \
Index diagSize = (std::min)(_rows,_depth); \
Index rows = IsLower ? _rows : diagSize; \
Index depth = IsLower ? diagSize : _depth; \
Index cols = _cols; \
\
typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
/* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \
if (rows != depth) { \
\
/* FIXME handle mkl_domain_get_max_threads */ \
/*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1;\
\
/* Excess beyond the square part is below half of diagSize: use the BuiltIn kernel. */ \
if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
/* Most likely no benefit to call TRMM or GEMM from BLAS */ \
product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
/*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \
} else { \
/* Make sense to call GEMM */ \
/* aa_tmp is a dense copy of the triangular view of the lhs; it is then multiplied as a general matrix. */ \
Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, 1, resStride, alpha, gemm_blocking, 0); \
\
/*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
} \
return; \
} \
/* Square case: set up the ?trmm call. diag stays 'N'; unit/zero diagonals are written into a_tmp below. */ \
char side = 'L', transa, uplo, diag = 'N'; \
EIGTYPE *b; \
const EIGTYPE *a; \
BlasIndex m, n, lda, ldb; \
\
/* Set m, n */ \
m = convert_index<BlasIndex>(diagSize); \
n = convert_index<BlasIndex>(cols); \
\
/* Set trans */ \
transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
\
/* Set b, ldb */ \
/* b_tmp is a (possibly conjugated) copy of the rhs; it is the b argument of the BLAS call and is added to res afterwards. */ \
Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \
MatrixX##EIGPREFIX b_tmp; \
\
if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
b = b_tmp.data(); \
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\
/* Set uplo */ \
/* A row-major lhs is passed with transa = 'T'/'C', so the triangle flag is swapped. */ \
uplo = IsLower ? 'L' : 'U'; \
if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
/* Set a, lda */ \
Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
MatrixLhs a_tmp; \
\
/* Copy the lhs only when it must be conjugated or its diagonal replaced; otherwise use _lhs directly. */ \
if ((conjA!=0) || (SetDiag==0)) { \
if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \
if (IsZeroDiag) \
a_tmp.diagonal().setZero(); \
else if (IsUnitDiag) \
a_tmp.diagonal().setOnes();\
a = a_tmp.data(); \
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \
a = _lhs; \
lda = convert_index<BlasIndex>(lhsStride); \
} \
/*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \
/* call ?trmm*/ \
/* alpha is passed as the address of its real part, reinterpreted as BLASTYPE*. */ \
BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
/* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
res_tmp=res_tmp+b_tmp; \
} \
};
// With EIGEN_USE_MKL the routines are named without a trailing underscore and
// complex pointers are cast to MKL_Complex16 / MKL_Complex8; otherwise the
// underscore-suffixed names are used and complex pointers are cast to pointers
// to the underlying real type.
#ifdef EIGEN_USE_MKL
EIGEN_BLAS_TRMM_L(double, double, d, dtrmm)
EIGEN_BLAS_TRMM_L(dcomplex, MKL_Complex16, cd, ztrmm)
EIGEN_BLAS_TRMM_L(float, float, f, strmm)
EIGEN_BLAS_TRMM_L(scomplex, MKL_Complex8, cf, ctrmm)
#else
EIGEN_BLAS_TRMM_L(double, double, d, dtrmm_)
EIGEN_BLAS_TRMM_L(dcomplex, double, cd, ztrmm_)
EIGEN_BLAS_TRMM_L(float, float, f, strmm_)
EIGEN_BLAS_TRMM_L(scomplex, float, cf, ctrmm_)
#endif
// implements col-major += alpha * op(general) * op(triangular)
//
// EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) specializes
// product_triangular_matrix_matrix_trmm for a triangular right-hand side and a
// column-major result:
//   EIGTYPE   - Eigen scalar type the specialization is generated for
//   BLASTYPE  - type the alpha/a/b pointers are cast to for the BLAS call
//   EIGPREFIX - suffix pasted onto MatrixX to name the dense temporary type
//   BLASFUNC  - the ?trmm routine to call
//
// run() takes one of three routes:
//   1. trimmed rhs not square, small excess -> Eigen's BuiltIn triangular kernel
//   2. trimmed rhs not square, larger excess -> dense copy of the triangular
//      view, then general_matrix_matrix_product
//   3. trimmed rhs square                   -> BLASFUNC on a copy of the lhs,
//      whose contents are then added into res
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
{ \
/* Compile-time flags derived from Mode and the rhs layout.                 */ \
/* SetDiag is 0 when the diagonal must be overwritten (zero or unit diag).  */ \
/* conjA is 1 when the rhs must be conjugated explicitly: for a row-major   */ \
/* rhs the conjugation is requested through transa = 'C' instead.           */ \
enum { \
IsLower = (Mode&Lower) == Lower, \
SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
LowUp = IsLower ? Lower : Upper, \
conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \
}; \
\
static void run( \
Index _rows, Index _cols, Index _depth, \
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
{ \
/* Trim the rhs: a lower factor keeps all rows (depth), an upper one keeps all columns. */ \
Index diagSize = (std::min)(_cols,_depth); \
Index rows = _rows; \
Index depth = IsLower ? _depth : diagSize; \
Index cols = IsLower ? diagSize : _cols; \
\
typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
\
/* Non-square case - doesn't fit to BLAS ?TRMM. Fall to default triangular product or call BLAS ?GEMM*/ \
if (cols != depth) { \
\
int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \
\
/* Excess beyond the square part is below half of diagSize: use the BuiltIn kernel. */ \
if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
/* Most likely no benefit to call TRMM or GEMM from BLAS*/ \
product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
/*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \
} else { \
/* Make sense to call GEMM */ \
/* aa_tmp is a dense copy of the triangular view of the rhs; it is then multiplied as a general matrix. */ \
Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, 1, resStride, alpha, gemm_blocking, 0); \
\
/*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
} \
return; \
} \
/* Square case: set up the ?trmm call. diag stays 'N'; unit/zero diagonals are written into a_tmp below. */ \
char side = 'R', transa, uplo, diag = 'N'; \
EIGTYPE *b; \
const EIGTYPE *a; \
BlasIndex m, n, lda, ldb; \
\
/* Set m, n */ \
m = convert_index<BlasIndex>(rows); \
n = convert_index<BlasIndex>(diagSize); \
\
/* Set trans */ \
transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
\
/* Set b, ldb */ \
/* b_tmp is a (possibly conjugated) copy of the lhs; it is the b argument of the BLAS call and is added to res afterwards. */ \
Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
MatrixX##EIGPREFIX b_tmp; \
\
if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
b = b_tmp.data(); \
ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
\
/* Set uplo */ \
/* A row-major rhs is passed with transa = 'T'/'C', so the triangle flag is swapped. */ \
uplo = IsLower ? 'L' : 'U'; \
if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
/* Set a, lda */ \
Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \
MatrixRhs a_tmp; \
\
/* Copy the rhs only when it must be conjugated or its diagonal replaced; otherwise use _rhs directly. */ \
if ((conjA!=0) || (SetDiag==0)) { \
if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \
if (IsZeroDiag) \
a_tmp.diagonal().setZero(); \
else if (IsUnitDiag) \
a_tmp.diagonal().setOnes();\
a = a_tmp.data(); \
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \
a = _rhs; \
lda = convert_index<BlasIndex>(rhsStride); \
} \
/*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \
/* call ?trmm*/ \
/* alpha is passed as the address of its real part, reinterpreted as BLASTYPE*. */ \
BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
\
/* Add op(a_triangular)*b into res*/ \
Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
res_tmp=res_tmp+b_tmp; \
} \
};
// With EIGEN_USE_MKL the routines are named without a trailing underscore and
// complex pointers are cast to MKL_Complex16 / MKL_Complex8; otherwise the
// underscore-suffixed names are used and complex pointers are cast to pointers
// to the underlying real type.
#ifdef EIGEN_USE_MKL
EIGEN_BLAS_TRMM_R(double, double, d, dtrmm)
EIGEN_BLAS_TRMM_R(dcomplex, MKL_Complex16, cd, ztrmm)
EIGEN_BLAS_TRMM_R(float, float, f, strmm)
EIGEN_BLAS_TRMM_R(scomplex, MKL_Complex8, cf, ctrmm)
#else
EIGEN_BLAS_TRMM_R(double, double, d, dtrmm_)
EIGEN_BLAS_TRMM_R(dcomplex, double, cd, ztrmm_)
EIGEN_BLAS_TRMM_R(float, float, f, strmm_)
EIGEN_BLAS_TRMM_R(scomplex, float, cf, ctrmm_)
#endif
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#ifndef EIGEN_TRIANGULARMATRIXVECTOR_H #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
#define EIGEN_TRIANGULARMATRIXVECTOR_H #define EIGEN_TRIANGULARMATRIXVECTOR_H
namespace Eigen { namespace Eigen {
namespace internal { namespace internal {
...@@ -20,14 +20,20 @@ struct triangular_matrix_vector_product; ...@@ -20,14 +20,20 @@ struct triangular_matrix_vector_product;
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version> template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version> struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
{ {
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
enum { enum {
IsLower = ((Mode&Lower)==Lower), IsLower = ((Mode&Lower)==Lower),
HasUnitDiag = (Mode & UnitDiag)==UnitDiag, HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
}; };
static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha) const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha);
};
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha)
{ {
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
Index size = (std::min)(_rows,_cols); Index size = (std::min)(_rows,_cols);
...@@ -37,7 +43,7 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C ...@@ -37,7 +43,7 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap; typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride)); const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs); typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap; typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr)); const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs); typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
...@@ -45,6 +51,9 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C ...@@ -45,6 +51,9 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap; typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
ResMap res(_res,rows); ResMap res(_res,rows);
typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
for (Index pi=0; pi<size; pi+=PanelWidth) for (Index pi=0; pi<size; pi+=PanelWidth)
{ {
Index actualPanelWidth = (std::min)(PanelWidth, size-pi); Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
...@@ -62,35 +71,40 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C ...@@ -62,35 +71,40 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
if (r>0) if (r>0)
{ {
Index s = IsLower ? pi+actualPanelWidth : 0; Index s = IsLower ? pi+actualPanelWidth : 0;
general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run( general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
r, actualPanelWidth, r, actualPanelWidth,
&lhs.coeffRef(s,pi), lhsStride, LhsMapper(&lhs.coeffRef(s,pi), lhsStride),
&rhs.coeffRef(pi), rhsIncr, RhsMapper(&rhs.coeffRef(pi), rhsIncr),
&res.coeffRef(s), resIncr, alpha); &res.coeffRef(s), resIncr, alpha);
} }
} }
if((!IsLower) && cols>size) if((!IsLower) && cols>size)
{ {
general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run( general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
rows, cols-size, rows, cols-size,
&lhs.coeffRef(0,size), lhsStride, LhsMapper(&lhs.coeffRef(0,size), lhsStride),
&rhs.coeffRef(size), rhsIncr, RhsMapper(&rhs.coeffRef(size), rhsIncr),
_res, resIncr, alpha); _res, resIncr, alpha);
} }
} }
};
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version> template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version> struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
{ {
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
enum { enum {
IsLower = ((Mode&Lower)==Lower), IsLower = ((Mode&Lower)==Lower),
HasUnitDiag = (Mode & UnitDiag)==UnitDiag, HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
}; };
static void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha) const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
};
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
{ {
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
Index diagSize = (std::min)(_rows,_cols); Index diagSize = (std::min)(_rows,_cols);
...@@ -107,7 +121,10 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C ...@@ -107,7 +121,10 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap; typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
ResMap res(_res,rows,InnerStride<>(resIncr)); ResMap res(_res,rows,InnerStride<>(resIncr));
typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
for (Index pi=0; pi<diagSize; pi+=PanelWidth) for (Index pi=0; pi<diagSize; pi+=PanelWidth)
{ {
Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi); Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
...@@ -125,105 +142,88 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C ...@@ -125,105 +142,88 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
if (r>0) if (r>0)
{ {
Index s = IsLower ? 0 : pi + actualPanelWidth; Index s = IsLower ? 0 : pi + actualPanelWidth;
general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run( general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
actualPanelWidth, r, actualPanelWidth, r,
&lhs.coeffRef(pi,s), lhsStride, LhsMapper(&lhs.coeffRef(pi,s), lhsStride),
&rhs.coeffRef(s), rhsIncr, RhsMapper(&rhs.coeffRef(s), rhsIncr),
&res.coeffRef(pi), resIncr, alpha); &res.coeffRef(pi), resIncr, alpha);
} }
} }
if(IsLower && rows>diagSize) if(IsLower && rows>diagSize)
{ {
general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run( general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
rows-diagSize, cols, rows-diagSize, cols,
&lhs.coeffRef(diagSize,0), lhsStride, LhsMapper(&lhs.coeffRef(diagSize,0), lhsStride),
&rhs.coeffRef(0), rhsIncr, RhsMapper(&rhs.coeffRef(0), rhsIncr),
&res.coeffRef(diagSize), resIncr, alpha); &res.coeffRef(diagSize), resIncr, alpha);
} }
} }
};
/*************************************************************************** /***************************************************************************
* Wrapper to product_triangular_vector * Wrapper to product_triangular_vector
***************************************************************************/ ***************************************************************************/
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs> template<int Mode,int StorageOrder>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
{};
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
{};
template<int StorageOrder>
struct trmv_selector; struct trmv_selector;
} // end namespace internal } // end namespace internal
namespace internal {
template<int Mode, typename Lhs, typename Rhs> template<int Mode, typename Lhs, typename Rhs>
struct TriangularProduct<Mode,true,Lhs,false,Rhs,true> struct triangular_product_impl<Mode,true,Lhs,false,Rhs,true>
: public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
{ {
EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct) template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{ {
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha); internal::trmv_selector<Mode,(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(lhs, rhs, dst, alpha);
} }
}; };
template<int Mode, typename Lhs, typename Rhs> template<int Mode, typename Lhs, typename Rhs>
struct TriangularProduct<Mode,false,Lhs,true,Rhs,false> struct triangular_product_impl<Mode,false,Lhs,true,Rhs,false>
: public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
{ {
EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct) template<typename Dest> static void run(Dest& dst, const Lhs &lhs, const Rhs &rhs, const typename Dest::Scalar& alpha)
TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{ {
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); eigen_assert(dst.rows()==lhs.rows() && dst.cols()==rhs.cols());
typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
Transpose<Dest> dstT(dst); Transpose<Dest> dstT(dst);
internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run( internal::trmv_selector<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha); (int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
::run(rhs.transpose(),lhs.transpose(), dstT, alpha);
} }
}; };
} // end namespace internal
namespace internal { namespace internal {
// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same. // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
template<> struct trmv_selector<ColMajor> template<int Mode> struct trmv_selector<Mode,ColMajor>
{ {
template<int Mode, typename Lhs, typename Rhs, typename Dest> template<typename Lhs, typename Rhs, typename Dest>
static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha) static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
{ {
typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType; typedef typename Lhs::Scalar LhsScalar;
typedef typename ProductType::Index Index; typedef typename Rhs::Scalar RhsScalar;
typedef typename ProductType::LhsScalar LhsScalar; typedef typename Dest::Scalar ResScalar;
typedef typename ProductType::RhsScalar RhsScalar; typedef typename Dest::RealScalar RealScalar;
typedef typename ProductType::Scalar ResScalar;
typedef typename ProductType::RealScalar RealScalar; typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename ProductType::ActualLhsType ActualLhsType; typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename ProductType::ActualRhsType ActualRhsType; typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename ProductType::LhsBlasTraits LhsBlasTraits; typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest; typedef Map<Matrix<ResScalar,Dynamic,1>, EIGEN_PLAIN_ENUM_MIN(AlignedMax,internal::packet_traits<ResScalar>::size)> MappedDest;
typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
* RhsBlasTraits::extractScalarFactor(prod.rhs()); RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
enum { enum {
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1 // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
...@@ -235,9 +235,9 @@ template<> struct trmv_selector<ColMajor> ...@@ -235,9 +235,9 @@ template<> struct trmv_selector<ColMajor>
gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest; gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0)); bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible; bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha); RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(), ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
...@@ -246,7 +246,7 @@ template<> struct trmv_selector<ColMajor> ...@@ -246,7 +246,7 @@ template<> struct trmv_selector<ColMajor>
if(!evalToDest) if(!evalToDest)
{ {
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
int size = dest.size(); Index size = dest.size();
EIGEN_DENSE_STORAGE_CTOR_PLUGIN EIGEN_DENSE_STORAGE_CTOR_PLUGIN
#endif #endif
if(!alphaIsCompatible) if(!alphaIsCompatible)
...@@ -257,7 +257,7 @@ template<> struct trmv_selector<ColMajor> ...@@ -257,7 +257,7 @@ template<> struct trmv_selector<ColMajor>
else else
MappedDest(actualDestPtr, dest.size()) = dest; MappedDest(actualDestPtr, dest.size()) = dest;
} }
internal::triangular_matrix_vector_product internal::triangular_matrix_vector_product
<Index,Mode, <Index,Mode,
LhsScalar, LhsBlasTraits::NeedToConjugate, LhsScalar, LhsBlasTraits::NeedToConjugate,
...@@ -275,36 +275,42 @@ template<> struct trmv_selector<ColMajor> ...@@ -275,36 +275,42 @@ template<> struct trmv_selector<ColMajor>
else else
dest = MappedDest(actualDestPtr, dest.size()); dest = MappedDest(actualDestPtr, dest.size());
} }
if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
{
Index diagSize = (std::min)(lhs.rows(),lhs.cols());
dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
}
} }
}; };
template<> struct trmv_selector<RowMajor> template<int Mode> struct trmv_selector<Mode,RowMajor>
{ {
template<int Mode, typename Lhs, typename Rhs, typename Dest> template<typename Lhs, typename Rhs, typename Dest>
static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha) static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
{ {
typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType; typedef typename Lhs::Scalar LhsScalar;
typedef typename ProductType::LhsScalar LhsScalar; typedef typename Rhs::Scalar RhsScalar;
typedef typename ProductType::RhsScalar RhsScalar; typedef typename Dest::Scalar ResScalar;
typedef typename ProductType::Scalar ResScalar;
typedef typename ProductType::Index Index; typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename ProductType::ActualLhsType ActualLhsType; typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename ProductType::ActualRhsType ActualRhsType; typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename ProductType::_ActualRhsType _ActualRhsType; typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ProductType::LhsBlasTraits LhsBlasTraits; typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
* RhsBlasTraits::extractScalarFactor(prod.rhs()); ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
enum { enum {
DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1 DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
}; };
gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs; gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(), ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data()); DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
...@@ -312,12 +318,12 @@ template<> struct trmv_selector<RowMajor> ...@@ -312,12 +318,12 @@ template<> struct trmv_selector<RowMajor>
if(!DirectlyUseRhs) if(!DirectlyUseRhs)
{ {
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
int size = actualRhs.size(); Index size = actualRhs.size();
EIGEN_DENSE_STORAGE_CTOR_PLUGIN EIGEN_DENSE_STORAGE_CTOR_PLUGIN
#endif #endif
Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs; Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
} }
internal::triangular_matrix_vector_product internal::triangular_matrix_vector_product
<Index,Mode, <Index,Mode,
LhsScalar, LhsBlasTraits::NeedToConjugate, LhsScalar, LhsBlasTraits::NeedToConjugate,
...@@ -328,6 +334,12 @@ template<> struct trmv_selector<RowMajor> ...@@ -328,6 +334,12 @@ template<> struct trmv_selector<RowMajor>
actualRhsPtr,1, actualRhsPtr,1,
dest.data(),dest.innerStride(), dest.data(),dest.innerStride(),
actualAlpha); actualAlpha);
if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
{
Index diagSize = (std::min)(lhs.rows(),lhs.cols());
dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
}
} }
}; };
......
/*
Copyright (c) 2011, Intel Corporation. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of Intel Corporation nor the names of its contributors may
be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
********************************************************************************
* Content : Eigen bindings to BLAS F77
* Triangular matrix-vector product functionality based on ?TRMV.
********************************************************************************
*/
#ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
#define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
namespace Eigen {
namespace internal {
/**********************************************************************
* This file implements triangular matrix-vector multiplication using BLAS
**********************************************************************/
// Dispatch point for the BLAS ?trmv code path.
// The primary template adds nothing of its own: it inherits the BuiltIn
// triangular matrix-vector product, so any scalar type / storage order that
// has no dedicated specialization below keeps using the built-in kernel.
template<typename Index, int Mode,
         typename LhsScalar, bool ConjLhs,
         typename RhsScalar, bool ConjRhs,
         int StorageOrder>
struct triangular_matrix_vector_product_trmv
  : triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, StorageOrder, BuiltIn>
{};
// Routes the `Specialized` triangular matrix-vector product of a given scalar
// type to triangular_matrix_vector_product_trmv, for both storage orders.
//
// EIGEN_BLAS_TRMV_SPECIALIZE_ORDER emits the forwarding struct for a single
// (Scalar, storage order) pair: its run() passes every argument through
// unchanged. EIGEN_BLAS_TRMV_SPECIALIZE stamps it out for ColMajor and RowMajor.
#define EIGEN_BLAS_TRMV_SPECIALIZE_ORDER(Scalar, Order) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,Order,Specialized> { \
  static void run(Index rows, Index cols, const Scalar* lhs, Index lhsStride, \
                  const Scalar* rhs, Index rhsIncr, Scalar* res, Index resIncr, Scalar alpha) { \
    typedef triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,Order> Impl; \
    Impl::run(rows, cols, lhs, lhsStride, rhs, rhsIncr, res, resIncr, alpha); \
  } \
};

#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \
EIGEN_BLAS_TRMV_SPECIALIZE_ORDER(Scalar, ColMajor) \
EIGEN_BLAS_TRMV_SPECIALIZE_ORDER(Scalar, RowMajor)
// Emit the Specialized -> triangular_matrix_vector_product_trmv forwarding
// structs for each scalar type handled by this file.
// NOTE(review): dcomplex / scomplex are declared outside this file; presumably
// the double- and single-precision complex types - confirm against the BLAS
// support header.
EIGEN_BLAS_TRMV_SPECIALIZE(double)
EIGEN_BLAS_TRMV_SPECIALIZE(float)
EIGEN_BLAS_TRMV_SPECIALIZE(dcomplex)
EIGEN_BLAS_TRMV_SPECIALIZE(scomplex)
// implements col-major: res += alpha * op(triangular) * vector
#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
enum { \
IsLower = (Mode&Lower) == Lower, \
SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
LowUp = IsLower ? Lower : Upper \
}; \
static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
{ \
if (ConjLhs || IsZeroDiag) { \
triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \
_rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
return; \
}\
Index size = (std::min)(_rows,_cols); \
Index rows = IsLower ? _rows : size; \
Index cols = IsLower ? size : _cols; \
\
typedef VectorX##EIGPREFIX VectorRhs; \
EIGTYPE *x, *y;\
\
/* Set x*/ \
Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
VectorRhs x_tmp; \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \
\
/* Square part handling */\
\
char trans, uplo, diag; \
BlasIndex m, n, lda, incx, incy; \
EIGTYPE const *a; \
EIGTYPE beta(1); \
\
/* Set m, n */ \
n = convert_index<BlasIndex>(size); \
lda = convert_index<BlasIndex>(lhsStride); \
incx = 1; \
incy = convert_index<BlasIndex>(resIncr); \
\
/* Set uplo, trans and diag*/ \
trans = 'N'; \
uplo = IsLower ? 'L' : 'U'; \
diag = IsUnitDiag ? 'U' : 'N'; \
\
/* call ?TRMV*/ \
BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
\
/* Add op(a_tr)rhs into res*/ \
BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \
if (size<rows) { \
y = _res + size*resIncr; \
a = _lhs + size; \
m = convert_index<BlasIndex>(rows-size); \
n = convert_index<BlasIndex>(size); \
} \
else { \
x += size; \
y = _res; \
a = _lhs + size*lda; \
m = convert_index<BlasIndex>(size); \
n = convert_index<BlasIndex>(cols-size); \
} \
BLASPREFIX##gemv##BLASPOSTFIX(&trans, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)y, &incy); \
} \
} \
};
// Generate the column-major BLAS specializations for the four scalar types.
// With EIGEN_USE_MKL the routine names get no postfix and complex data is cast
// to MKL_Complex16 / MKL_Complex8; otherwise the names get a trailing
// underscore and complex data is cast to double* / float*.
#ifdef EIGEN_USE_MKL
EIGEN_BLAS_TRMV_CM(double, double, d, d,)
EIGEN_BLAS_TRMV_CM(dcomplex, MKL_Complex16, cd, z,)
EIGEN_BLAS_TRMV_CM(float, float, f, s,)
EIGEN_BLAS_TRMV_CM(scomplex, MKL_Complex8, cf, c,)
#else
EIGEN_BLAS_TRMV_CM(double, double, d, d, _)
EIGEN_BLAS_TRMV_CM(dcomplex, double, cd, z, _)
EIGEN_BLAS_TRMV_CM(float, float, f, s, _)
EIGEN_BLAS_TRMV_CM(scomplex, float, cf, c, _)
#endif
// implements row-major: res += alpha * op(triangular) * vector
#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
enum { \
IsLower = (Mode&Lower) == Lower, \
SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
LowUp = IsLower ? Lower : Upper \
}; \
static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
{ \
if (IsZeroDiag) { \
triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \
_rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
return; \
}\
Index size = (std::min)(_rows,_cols); \
Index rows = IsLower ? _rows : size; \
Index cols = IsLower ? size : _cols; \
\
typedef VectorX##EIGPREFIX VectorRhs; \
EIGTYPE *x, *y;\
\
/* Set x*/ \
Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
VectorRhs x_tmp; \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \
\
/* Square part handling */\
\
char trans, uplo, diag; \
BlasIndex m, n, lda, incx, incy; \
EIGTYPE const *a; \
EIGTYPE beta(1); \
\
/* Set m, n */ \
n = convert_index<BlasIndex>(size); \
lda = convert_index<BlasIndex>(lhsStride); \
incx = 1; \
incy = convert_index<BlasIndex>(resIncr); \
\
/* Set uplo, trans and diag*/ \
trans = ConjLhs ? 'C' : 'T'; \
uplo = IsLower ? 'U' : 'L'; \
diag = IsUnitDiag ? 'U' : 'N'; \
\
/* call ?TRMV*/ \
BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \
\
/* Add op(a_tr)rhs into res*/ \
BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \
/* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \
if (size<rows) { \
y = _res + size*resIncr; \
a = _lhs + size*lda; \
m = convert_index<BlasIndex>(rows-size); \
n = convert_index<BlasIndex>(size); \
} \
else { \
x += size; \
y = _res; \
a = _lhs + size; \
m = convert_index<BlasIndex>(size); \
n = convert_index<BlasIndex>(cols-size); \
} \
BLASPREFIX##gemv##BLASPOSTFIX(&trans, &n, &m, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)y, &incy); \
} \
} \
};
// Generate the row-major BLAS specializations for the four scalar types.
// With EIGEN_USE_MKL the routine names get no postfix and complex data is cast
// to MKL_Complex16 / MKL_Complex8; otherwise the names get a trailing
// underscore and complex data is cast to double* / float*.
#ifdef EIGEN_USE_MKL
EIGEN_BLAS_TRMV_RM(double, double, d, d,)
EIGEN_BLAS_TRMV_RM(dcomplex, MKL_Complex16, cd, z,)
EIGEN_BLAS_TRMV_RM(float, float, f, s,)
EIGEN_BLAS_TRMV_RM(scomplex, MKL_Complex8, cf, c,)
#else
EIGEN_BLAS_TRMV_RM(double, double, d, d,_)
EIGEN_BLAS_TRMV_RM(dcomplex, double, cd, z,_)
EIGEN_BLAS_TRMV_RM(float, float, f, s,_)
EIGEN_BLAS_TRMV_RM(scomplex, float, cf, c,_)
#endif
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
...@@ -15,40 +15,51 @@ namespace Eigen { ...@@ -15,40 +15,51 @@ namespace Eigen {
namespace internal { namespace internal {
// if the rhs is row major, let's transpose the product // if the rhs is row major, let's transpose the product
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder> template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor> struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor,OtherInnerStride>
{ {
static EIGEN_DONT_INLINE void run( static void run(
Index size, Index cols, Index size, Index cols,
const Scalar* tri, Index triStride, const Scalar* tri, Index triStride,
Scalar* _other, Index otherStride, Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking) level3_blocking<Scalar,Scalar>& blocking)
{ {
triangular_solve_matrix< triangular_solve_matrix<
Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft, Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft,
(Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper), (Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper),
NumTraits<Scalar>::IsComplex && Conjugate, NumTraits<Scalar>::IsComplex && Conjugate,
TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor> TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor, OtherInnerStride>
::run(size, cols, tri, triStride, _other, otherStride, blocking); ::run(size, cols, tri, triStride, _other, otherIncr, otherStride, blocking);
} }
}; };
/* Optimized triangular solver with multiple right hand side and the triangular matrix on the left /* Optimized triangular solver with multiple right hand side and the triangular matrix on the left
*/ */
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder> template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>
{ {
static EIGEN_DONT_INLINE void run( static EIGEN_DONT_INLINE void run(
Index size, Index otherSize, Index size, Index otherSize,
const Scalar* _tri, Index triStride, const Scalar* _tri, Index triStride,
Scalar* _other, Index otherStride, Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking) level3_blocking<Scalar,Scalar>& blocking)
{ {
Index cols = otherSize; Index cols = otherSize;
const_blas_data_mapper<Scalar, Index, TriStorageOrder> tri(_tri,triStride);
blas_data_mapper<Scalar, Index, ColMajor> other(_other,otherStride); typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
TriMapper tri(_tri, triStride);
OtherMapper other(_other, otherStride, otherIncr);
typedef gebp_traits<Scalar,Scalar> Traits; typedef gebp_traits<Scalar,Scalar> Traits;
enum { enum {
SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr), SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr),
IsLower = (Mode&Lower) == Lower IsLower = (Mode&Lower) == Lower
...@@ -59,22 +70,20 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO ...@@ -59,22 +70,20 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
std::size_t sizeA = kc*mc; std::size_t sizeA = kc*mc;
std::size_t sizeB = kc*cols; std::size_t sizeB = kc*cols;
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());
conj_if<Conjugate> conj; conj_if<Conjugate> conj;
gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel; gebp_kernel<Scalar, Scalar, Index, OtherMapper, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel;
gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, TriStorageOrder> pack_lhs; gemm_pack_lhs<Scalar, Index, TriMapper, Traits::mr, Traits::LhsProgress, TriStorageOrder> pack_lhs;
gemm_pack_rhs<Scalar, Index, Traits::nr, ColMajor, false, true> pack_rhs; gemm_pack_rhs<Scalar, Index, OtherMapper, Traits::nr, ColMajor, false, true> pack_rhs;
// the goal here is to subdivise the Rhs panels such that we keep some cache // the goal here is to subdivise the Rhs panels such that we keep some cache
// coherence when accessing the rhs elements // coherence when accessing the rhs elements
std::ptrdiff_t l1, l2; std::ptrdiff_t l1, l2, l3;
manage_caching_sizes(GetAction, &l1, &l2); manage_caching_sizes(GetAction, &l1, &l2, &l3);
Index subcols = cols>0 ? l2/(4 * sizeof(Scalar) * otherStride) : 0; Index subcols = cols>0 ? l2/(4 * sizeof(Scalar) * std::max<Index>(otherStride,size)) : 0;
subcols = std::max<Index>((subcols/Traits::nr)*Traits::nr, Traits::nr); subcols = std::max<Index>((subcols/Traits::nr)*Traits::nr, Traits::nr);
for(Index k2=IsLower ? 0 : size; for(Index k2=IsLower ? 0 : size;
...@@ -108,8 +117,9 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO ...@@ -108,8 +117,9 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
{ {
// TODO write a small kernel handling this (can be shared with trsv) // TODO write a small kernel handling this (can be shared with trsv)
Index i = IsLower ? k2+k1+k : k2-k1-k-1; Index i = IsLower ? k2+k1+k : k2-k1-k-1;
Index s = IsLower ? k2+k1 : i+1;
Index rs = actualPanelWidth - k - 1; // remaining size Index rs = actualPanelWidth - k - 1; // remaining size
Index s = TriStorageOrder==RowMajor ? (IsLower ? k2+k1 : i+1)
: IsLower ? i+1 : i-rs;
Scalar a = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(tri(i,i)); Scalar a = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(tri(i,i));
for (Index j=j2; j<j2+actual_cols; ++j) for (Index j=j2; j<j2+actual_cols; ++j)
...@@ -118,20 +128,19 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO ...@@ -118,20 +128,19 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
{ {
Scalar b(0); Scalar b(0);
const Scalar* l = &tri(i,s); const Scalar* l = &tri(i,s);
Scalar* r = &other(s,j); typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
for (Index i3=0; i3<k; ++i3) for (Index i3=0; i3<k; ++i3)
b += conj(l[i3]) * r[i3]; b += conj(l[i3]) * r(i3);
other(i,j) = (other(i,j) - b)*a; other(i,j) = (other(i,j) - b)*a;
} }
else else
{ {
Index s = IsLower ? i+1 : i-rs;
Scalar b = (other(i,j) *= a); Scalar b = (other(i,j) *= a);
Scalar* r = &other(s,j); typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
const Scalar* l = &tri(s,i); typename TriMapper::LinearMapper l = tri.getLinearMapper(s,i);
for (Index i3=0;i3<rs;++i3) for (Index i3=0;i3<rs;++i3)
r[i3] -= b * conj(l[i3]); r(i3) -= b * conj(l(i3));
} }
} }
} }
...@@ -141,17 +150,17 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO ...@@ -141,17 +150,17 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
Index blockBOffset = IsLower ? k1 : lengthTarget; Index blockBOffset = IsLower ? k1 : lengthTarget;
// update the respective rows of B from other // update the respective rows of B from other
pack_rhs(blockB+actual_kc*j2, &other(startBlock,j2), otherStride, actualPanelWidth, actual_cols, actual_kc, blockBOffset); pack_rhs(blockB+actual_kc*j2, other.getSubMapper(startBlock,j2), actualPanelWidth, actual_cols, actual_kc, blockBOffset);
// GEBP // GEBP
if (lengthTarget>0) if (lengthTarget>0)
{ {
Index startTarget = IsLower ? k2+k1+actualPanelWidth : k2-actual_kc; Index startTarget = IsLower ? k2+k1+actualPanelWidth : k2-actual_kc;
pack_lhs(blockA, &tri(startTarget,startBlock), triStride, actualPanelWidth, lengthTarget); pack_lhs(blockA, tri.getSubMapper(startTarget,startBlock), actualPanelWidth, lengthTarget);
gebp_kernel(&other(startTarget,j2), otherStride, blockA, blockB+actual_kc*j2, lengthTarget, actualPanelWidth, actual_cols, Scalar(-1), gebp_kernel(other.getSubMapper(startTarget,j2), blockA, blockB+actual_kc*j2, lengthTarget, actualPanelWidth, actual_cols, Scalar(-1),
actualPanelWidth, actual_kc, 0, blockBOffset, blockW); actualPanelWidth, actual_kc, 0, blockBOffset);
} }
} }
} }
...@@ -165,30 +174,40 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO ...@@ -165,30 +174,40 @@ struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageO
const Index actual_mc = (std::min)(mc,end-i2); const Index actual_mc = (std::min)(mc,end-i2);
if (actual_mc>0) if (actual_mc>0)
{ {
pack_lhs(blockA, &tri(i2, IsLower ? k2 : k2-kc), triStride, actual_kc, actual_mc); pack_lhs(blockA, tri.getSubMapper(i2, IsLower ? k2 : k2-kc), actual_kc, actual_mc);
gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1), -1, -1, 0, 0, blockW); gebp_kernel(other.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1), -1, -1, 0, 0);
} }
} }
} }
} }
} }
};
/* Optimized triangular solver with multiple left hand sides and the trinagular matrix on the right /* Optimized triangular solver with multiple left hand sides and the triangular matrix on the right
*/ */
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder> template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>
{ {
static EIGEN_DONT_INLINE void run( static EIGEN_DONT_INLINE void run(
Index size, Index otherSize, Index size, Index otherSize,
const Scalar* _tri, Index triStride, const Scalar* _tri, Index triStride,
Scalar* _other, Index otherStride, Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking) level3_blocking<Scalar,Scalar>& blocking)
{ {
Index rows = otherSize; Index rows = otherSize;
const_blas_data_mapper<Scalar, Index, TriStorageOrder> rhs(_tri,triStride); typedef typename NumTraits<Scalar>::Real RealScalar;
blas_data_mapper<Scalar, Index, ColMajor> lhs(_other,otherStride);
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
LhsMapper lhs(_other, otherStride, otherIncr);
RhsMapper rhs(_tri, triStride);
typedef gebp_traits<Scalar,Scalar> Traits; typedef gebp_traits<Scalar,Scalar> Traits;
enum { enum {
...@@ -202,17 +221,15 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage ...@@ -202,17 +221,15 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
std::size_t sizeA = kc*mc; std::size_t sizeA = kc*mc;
std::size_t sizeB = kc*size; std::size_t sizeB = kc*size;
std::size_t sizeW = kc*Traits::WorkSpaceFactor;
ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA()); ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB()); ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());
conj_if<Conjugate> conj; conj_if<Conjugate> conj;
gebp_kernel<Scalar,Scalar, Index, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel; gebp_kernel<Scalar, Scalar, Index, LhsMapper, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel;
gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs; gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder,false,true> pack_rhs_panel; gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder,false,true> pack_rhs_panel;
gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, ColMajor, false, true> pack_lhs_panel; gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, ColMajor, false, true> pack_lhs_panel;
for(Index k2=IsLower ? size : 0; for(Index k2=IsLower ? size : 0;
IsLower ? k2>0 : k2<size; IsLower ? k2>0 : k2<size;
...@@ -225,7 +242,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage ...@@ -225,7 +242,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
Index rs = IsLower ? actual_k2 : size - actual_k2 - actual_kc; Index rs = IsLower ? actual_k2 : size - actual_k2 - actual_kc;
Scalar* geb = blockB+actual_kc*actual_kc; Scalar* geb = blockB+actual_kc*actual_kc;
if (rs>0) pack_rhs(geb, &rhs(actual_k2,startPanel), triStride, actual_kc, rs); if (rs>0) pack_rhs(geb, rhs.getSubMapper(actual_k2,startPanel), actual_kc, rs);
// triangular packing (we only pack the panels off the diagonal, // triangular packing (we only pack the panels off the diagonal,
// neglecting the blocks overlapping the diagonal // neglecting the blocks overlapping the diagonal
...@@ -239,7 +256,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage ...@@ -239,7 +256,7 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
if (panelLength>0) if (panelLength>0)
pack_rhs_panel(blockB+j2*actual_kc, pack_rhs_panel(blockB+j2*actual_kc,
&rhs(actual_k2+panelOffset, actual_j2), triStride, rhs.getSubMapper(actual_k2+panelOffset, actual_j2),
panelLength, actualPanelWidth, panelLength, actualPanelWidth,
actual_kc, panelOffset); actual_kc, panelOffset);
} }
...@@ -267,13 +284,12 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage ...@@ -267,13 +284,12 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
// GEBP // GEBP
if(panelLength>0) if(panelLength>0)
{ {
gebp_kernel(&lhs(i2,absolute_j2), otherStride, gebp_kernel(lhs.getSubMapper(i2,absolute_j2),
blockA, blockB+j2*actual_kc, blockA, blockB+j2*actual_kc,
actual_mc, panelLength, actualPanelWidth, actual_mc, panelLength, actualPanelWidth,
Scalar(-1), Scalar(-1),
actual_kc, actual_kc, // strides actual_kc, actual_kc, // strides
panelOffset, panelOffset, // offsets panelOffset, panelOffset); // offsets
blockW); // workspace
} }
// unblocked triangular solve // unblocked triangular solve
...@@ -281,34 +297,36 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage ...@@ -281,34 +297,36 @@ struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorage
{ {
Index j = IsLower ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k; Index j = IsLower ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
Scalar* r = &lhs(i2,j); typename LhsMapper::LinearMapper r = lhs.getLinearMapper(i2,j);
for (Index k3=0; k3<k; ++k3) for (Index k3=0; k3<k; ++k3)
{ {
Scalar b = conj(rhs(IsLower ? j+1+k3 : absolute_j2+k3,j)); Scalar b = conj(rhs(IsLower ? j+1+k3 : absolute_j2+k3,j));
Scalar* a = &lhs(i2,IsLower ? j+1+k3 : absolute_j2+k3); typename LhsMapper::LinearMapper a = lhs.getLinearMapper(i2,IsLower ? j+1+k3 : absolute_j2+k3);
for (Index i=0; i<actual_mc; ++i) for (Index i=0; i<actual_mc; ++i)
r[i] -= a[i] * b; r(i) -= a(i) * b;
}
if((Mode & UnitDiag)==0)
{
Scalar inv_rjj = RealScalar(1)/conj(rhs(j,j));
for (Index i=0; i<actual_mc; ++i)
r(i) *= inv_rjj;
} }
Scalar b = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
for (Index i=0; i<actual_mc; ++i)
r[i] *= b;
} }
// pack the just computed part of lhs to A // pack the just computed part of lhs to A
pack_lhs_panel(blockA, _other+absolute_j2*otherStride+i2, otherStride, pack_lhs_panel(blockA, lhs.getSubMapper(i2,absolute_j2),
actualPanelWidth, actual_mc, actualPanelWidth, actual_mc,
actual_kc, j2); actual_kc, j2);
} }
} }
if (rs>0) if (rs>0)
gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb, gebp_kernel(lhs.getSubMapper(i2, startPanel), blockA, geb,
actual_mc, actual_kc, rs, Scalar(-1), actual_mc, actual_kc, rs, Scalar(-1),
-1, -1, 0, 0, blockW); -1, -1, 0, 0);
} }
} }
} }
};
} // end namespace internal } // end namespace internal
......
/*
Copyright (c) 2011, Intel Corporation. All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of Intel Corporation nor the names of its contributors may
be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
********************************************************************************
* Content : Eigen bindings to BLAS F77
* Triangular solve with a matrix right-hand side, based on ?TRSM.
********************************************************************************
*/
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H
namespace Eigen {
namespace internal {
// implements LeftSide op(triangular)^-1 * general
//
// EIGEN_BLAS_TRSM_L(EIGTYPE, BLASTYPE, BLASFUNC) generates a partial specialization of
// triangular_solve_matrix for scalar type EIGTYPE, with the triangular factor on the left
// (OnTheLeft) and a column-major `other` operand. Its run() forwards the whole solve to the
// BLAS routine BLASFUNC with side='L' and alpha=1.
//   EIGTYPE  : Eigen scalar type the specialization is generated for
//   BLASTYPE : scalar type used in the pointer casts of the BLASFUNC arguments
//   BLASFUNC : name of the ?trsm entry point to call
// Notes:
//   - run() asserts otherIncr == 1 and ignores its `blocking` argument.
//   - IsZeroDiag is computed but never read inside this specialization.
//   - NOTE(review): the trailing template argument `1` presumably is the inner stride of
//     `other`; confirm against the primary template declaration.
#define EIGEN_BLAS_TRSM_L(EIGTYPE, BLASTYPE, BLASFUNC) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,1> \
{ \
enum { \
IsLower = (Mode&Lower) == Lower, \
IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
/* conjA: conjugation is requested on a column-major factor, i.e. without any transposition */ \
conjA = ((TriStorageOrder==ColMajor) && Conjugate) ? 1 : 0 \
}; \
static void run( \
Index size, Index otherSize, \
const EIGTYPE* _tri, Index triStride, \
EIGTYPE* _other, Index otherIncr, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
{ \
EIGEN_ONLY_USED_FOR_DEBUG(otherIncr); \
eigen_assert(otherIncr == 1); \
/* Left side: the triangular factor is size x size, so m = size and n = otherSize */ \
BlasIndex m = convert_index<BlasIndex>(size), n = convert_index<BlasIndex>(otherSize), lda, ldb; \
char side = 'L', uplo, diag='N', transa; \
/* Set alpha: fixed to 1, no extra scaling is requested from BLAS */ \
EIGTYPE alpha(1); \
ldb = convert_index<BlasIndex>(otherStride);\
\
const EIGTYPE *a; \
/* Set trans: a row-major factor is handed over as a transposed ('T') or conjugate-transposed ('C') matrix */ \
transa = (TriStorageOrder==RowMajor) ? ((Conjugate) ? 'C' : 'T') : 'N'; \
/* Set uplo: the triangle flag is swapped for a row-major factor, matching the transposition above */ \
uplo = IsLower ? 'L' : 'U'; \
if (TriStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
/* Set a, lda */ \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, TriStorageOrder> MatrixTri; \
Map<const MatrixTri, 0, OuterStride<> > tri(_tri,size,size,OuterStride<>(triStride)); \
MatrixTri a_tmp; \
\
/* transa above only takes 'N', 'T' or 'C', so a conjugate-only request is served by */ \
/* passing a conjugated temporary copy of the factor; otherwise _tri is used in place */ \
if (conjA) { \
a_tmp = tri.conjugate(); \
a = a_tmp.data(); \
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \
a = _tri; \
lda = convert_index<BlasIndex>(triStride); \
} \
if (IsUnitDiag) diag='U'; \
/* call ?trsm: _other is passed through a non-const pointer as the right-hand side argument */ \
BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
} \
};
// Instantiate the left-side specialization for the four BLAS scalar types.
// With EIGEN_USE_MKL the unsuffixed ?trsm symbols and MKL's complex types are used;
// otherwise the trailing-underscore symbols are called and complex data is passed
// through pointers to the underlying real type.
#ifdef EIGEN_USE_MKL
EIGEN_BLAS_TRSM_L(double, double, dtrsm)
EIGEN_BLAS_TRSM_L(dcomplex, MKL_Complex16, ztrsm)
EIGEN_BLAS_TRSM_L(float, float, strsm)
EIGEN_BLAS_TRSM_L(scomplex, MKL_Complex8, ctrsm)
#else
EIGEN_BLAS_TRSM_L(double, double, dtrsm_)
EIGEN_BLAS_TRSM_L(dcomplex, double, ztrsm_)
EIGEN_BLAS_TRSM_L(float, float, strsm_)
EIGEN_BLAS_TRSM_L(scomplex, float, ctrsm_)
#endif
// implements RightSide general * op(triangular)^-1
//
// EIGEN_BLAS_TRSM_R(EIGTYPE, BLASTYPE, BLASFUNC) generates a partial specialization of
// triangular_solve_matrix for scalar type EIGTYPE, with the triangular factor on the right
// (OnTheRight) and a column-major `other` operand. Its run() forwards the whole solve to the
// BLAS routine BLASFUNC with side='R' and alpha=1.
//   EIGTYPE  : Eigen scalar type the specialization is generated for
//   BLASTYPE : scalar type used in the pointer casts of the BLASFUNC arguments
//   BLASFUNC : name of the ?trsm entry point to call
// Notes:
//   - run() asserts otherIncr == 1 and ignores its `blocking` argument.
//   - IsZeroDiag is computed but never read inside this specialization.
//   - NOTE(review): the trailing template argument `1` presumably is the inner stride of
//     `other`; confirm against the primary template declaration.
#define EIGEN_BLAS_TRSM_R(EIGTYPE, BLASTYPE, BLASFUNC) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,1> \
{ \
enum { \
IsLower = (Mode&Lower) == Lower, \
IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
/* conjA: conjugation is requested on a column-major factor, i.e. without any transposition */ \
conjA = ((TriStorageOrder==ColMajor) && Conjugate) ? 1 : 0 \
}; \
static void run( \
Index size, Index otherSize, \
const EIGTYPE* _tri, Index triStride, \
EIGTYPE* _other, Index otherIncr, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
{ \
EIGEN_ONLY_USED_FOR_DEBUG(otherIncr); \
eigen_assert(otherIncr == 1); \
/* Right side: the triangular factor is size x size, so n = size and m = otherSize */ \
BlasIndex m = convert_index<BlasIndex>(otherSize), n = convert_index<BlasIndex>(size), lda, ldb; \
char side = 'R', uplo, diag='N', transa; \
/* Set alpha: fixed to 1, no extra scaling is requested from BLAS */ \
EIGTYPE alpha(1); \
ldb = convert_index<BlasIndex>(otherStride);\
\
const EIGTYPE *a; \
/* Set trans: a row-major factor is handed over as a transposed ('T') or conjugate-transposed ('C') matrix */ \
transa = (TriStorageOrder==RowMajor) ? ((Conjugate) ? 'C' : 'T') : 'N'; \
/* Set uplo: the triangle flag is swapped for a row-major factor, matching the transposition above */ \
uplo = IsLower ? 'L' : 'U'; \
if (TriStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
/* Set a, lda */ \
typedef Matrix<EIGTYPE, Dynamic, Dynamic, TriStorageOrder> MatrixTri; \
Map<const MatrixTri, 0, OuterStride<> > tri(_tri,size,size,OuterStride<>(triStride)); \
MatrixTri a_tmp; \
\
/* transa above only takes 'N', 'T' or 'C', so a conjugate-only request is served by */ \
/* passing a conjugated temporary copy of the factor; otherwise _tri is used in place */ \
if (conjA) { \
a_tmp = tri.conjugate(); \
a = a_tmp.data(); \
lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
} else { \
a = _tri; \
lda = convert_index<BlasIndex>(triStride); \
} \
if (IsUnitDiag) diag='U'; \
/* call ?trsm: _other is passed through a non-const pointer as the right-hand side argument */ \
BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)_other, &ldb); \
} \
};
// Instantiate the right-side specialization for the four BLAS scalar types.
// With EIGEN_USE_MKL the unsuffixed ?trsm symbols and MKL's complex types are used;
// otherwise the trailing-underscore symbols are called and complex data is passed
// through pointers to the underlying real type.
#ifdef EIGEN_USE_MKL
EIGEN_BLAS_TRSM_R(double, double, dtrsm)
EIGEN_BLAS_TRSM_R(dcomplex, MKL_Complex16, ztrsm)
EIGEN_BLAS_TRSM_R(float, float, strsm)
EIGEN_BLAS_TRSM_R(scomplex, MKL_Complex8, ctrsm)
#else
EIGEN_BLAS_TRSM_R(double, double, dtrsm_)
EIGEN_BLAS_TRSM_R(dcomplex, double, ztrsm_)
EIGEN_BLAS_TRSM_R(float, float, strsm_)
EIGEN_BLAS_TRSM_R(scomplex, float, ctrsm_)
#endif
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_BLAS_H
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#ifndef EIGEN_TRIANGULAR_SOLVER_VECTOR_H #ifndef EIGEN_TRIANGULAR_SOLVER_VECTOR_H
#define EIGEN_TRIANGULAR_SOLVER_VECTOR_H #define EIGEN_TRIANGULAR_SOLVER_VECTOR_H
namespace Eigen { namespace Eigen {
namespace internal { namespace internal {
...@@ -25,7 +25,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheRight, Mode, Co ...@@ -25,7 +25,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheRight, Mode, Co
>::run(size, _lhs, lhsStride, rhs); >::run(size, _lhs, lhsStride, rhs);
} }
}; };
// forward and backward substitution, row-major, rhs is a vector // forward and backward substitution, row-major, rhs is a vector
template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate> template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, RowMajor> struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Conjugate, RowMajor>
...@@ -37,6 +37,10 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con ...@@ -37,6 +37,10 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con
{ {
typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap; typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride)); const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
typename internal::conditional< typename internal::conditional<
Conjugate, Conjugate,
const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>, const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
...@@ -58,10 +62,10 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con ...@@ -58,10 +62,10 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con
Index startRow = IsLower ? pi : pi-actualPanelWidth; Index startRow = IsLower ? pi : pi-actualPanelWidth;
Index startCol = IsLower ? 0 : pi; Index startCol = IsLower ? 0 : pi;
general_matrix_vector_product<Index,LhsScalar,RowMajor,Conjugate,RhsScalar,false>::run( general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,Conjugate,RhsScalar,RhsMapper,false>::run(
actualPanelWidth, r, actualPanelWidth, r,
&lhs.coeffRef(startRow,startCol), lhsStride, LhsMapper(&lhs.coeffRef(startRow,startCol), lhsStride),
rhs + startCol, 1, RhsMapper(rhs + startCol, 1),
rhs + startRow, 1, rhs + startRow, 1,
RhsScalar(-1)); RhsScalar(-1));
} }
...@@ -72,7 +76,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con ...@@ -72,7 +76,7 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con
Index s = IsLower ? pi : i+1; Index s = IsLower ? pi : i+1;
if (k>0) if (k>0)
rhs[i] -= (cjLhs.row(i).segment(s,k).transpose().cwiseProduct(Map<const Matrix<RhsScalar,Dynamic,1> >(rhs+s,k))).sum(); rhs[i] -= (cjLhs.row(i).segment(s,k).transpose().cwiseProduct(Map<const Matrix<RhsScalar,Dynamic,1> >(rhs+s,k))).sum();
if(!(Mode & UnitDiag)) if(!(Mode & UnitDiag))
rhs[i] /= cjLhs(i,i); rhs[i] /= cjLhs(i,i);
} }
...@@ -91,6 +95,8 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con ...@@ -91,6 +95,8 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con
{ {
typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap; typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride)); const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
typename internal::conditional<Conjugate, typename internal::conditional<Conjugate,
const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>, const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
const LhsMap& const LhsMap&
...@@ -122,10 +128,10 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con ...@@ -122,10 +128,10 @@ struct triangular_solve_vector<LhsScalar, RhsScalar, Index, OnTheLeft, Mode, Con
// let's directly call the low level product function because: // let's directly call the low level product function because:
// 1 - it is faster to compile // 1 - it is faster to compile
// 2 - it is slightly faster at runtime // 2 - it is slightly faster at runtime
general_matrix_vector_product<Index,LhsScalar,ColMajor,Conjugate,RhsScalar,false>::run( general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,Conjugate,RhsScalar,RhsMapper,false>::run(
r, actualPanelWidth, r, actualPanelWidth,
&lhs.coeffRef(endBlock,startBlock), lhsStride, LhsMapper(&lhs.coeffRef(endBlock,startBlock), lhsStride),
rhs+startBlock, 1, RhsMapper(rhs+startBlock, 1),
rhs+endBlock, 1, RhsScalar(-1)); rhs+endBlock, 1, RhsScalar(-1));
} }
} }
......
...@@ -18,23 +18,25 @@ namespace Eigen { ...@@ -18,23 +18,25 @@ namespace Eigen {
namespace internal { namespace internal {
// forward declarations // forward declarations
template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false> template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
struct gebp_kernel; struct gebp_kernel;
template<typename Scalar, typename Index, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false> template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
struct gemm_pack_rhs; struct gemm_pack_rhs;
template<typename Scalar, typename Index, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false> template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
struct gemm_pack_lhs; struct gemm_pack_lhs;
template< template<
typename Index, typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
int ResStorageOrder> int ResStorageOrder, int ResInnerStride>
struct general_matrix_matrix_product; struct general_matrix_matrix_product;
template<typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version=Specialized> template<typename Index,
typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized>
struct general_matrix_vector_product; struct general_matrix_vector_product;
...@@ -42,22 +44,35 @@ template<bool Conjugate> struct conj_if; ...@@ -42,22 +44,35 @@ template<bool Conjugate> struct conj_if;
template<> struct conj_if<true> { template<> struct conj_if<true> {
template<typename T> template<typename T>
inline T operator()(const T& x) { return conj(x); } inline T operator()(const T& x) const { return numext::conj(x); }
template<typename T> template<typename T>
inline T pconj(const T& x) { return internal::pconj(x); } inline T pconj(const T& x) const { return internal::pconj(x); }
}; };
template<> struct conj_if<false> { template<> struct conj_if<false> {
template<typename T> template<typename T>
inline const T& operator()(const T& x) { return x; } inline const T& operator()(const T& x) const { return x; }
template<typename T> template<typename T>
inline const T& pconj(const T& x) { return x; } inline const T& pconj(const T& x) const { return x; }
};
// Generic implementation for custom complex types.
template<typename LhsScalar, typename RhsScalar, bool ConjLhs, bool ConjRhs>
struct conj_helper
{
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType Scalar;
EIGEN_STRONG_INLINE Scalar pmadd(const LhsScalar& x, const RhsScalar& y, const Scalar& c) const
{ return padd(c, pmul(x,y)); }
EIGEN_STRONG_INLINE Scalar pmul(const LhsScalar& x, const RhsScalar& y) const
{ return conj_if<ConjLhs>()(x) * conj_if<ConjRhs>()(y); }
}; };
template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false> template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
{ {
EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
}; };
template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true> template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
...@@ -67,7 +82,7 @@ template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std:: ...@@ -67,7 +82,7 @@ template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::
{ return c + pmul(x,y); } { return c + pmul(x,y); }
EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
{ return Scalar(real(x)*real(y) + imag(x)*imag(y), imag(x)*real(y) - real(x)*imag(y)); } { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); }
}; };
template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false> template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
...@@ -77,7 +92,7 @@ template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std:: ...@@ -77,7 +92,7 @@ template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::
{ return c + pmul(x,y); } { return c + pmul(x,y); }
EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
{ return Scalar(real(x)*real(y) + imag(x)*imag(y), real(x)*imag(y) - imag(x)*real(y)); } { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
}; };
template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true> template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
...@@ -87,7 +102,7 @@ template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std:: ...@@ -87,7 +102,7 @@ template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::
{ return c + pmul(x,y); } { return c + pmul(x,y); }
EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
{ return Scalar(real(x)*real(y) - imag(x)*imag(y), - real(x)*imag(y) - imag(x)*real(y)); } { return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
}; };
template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false> template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
...@@ -109,39 +124,243 @@ template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::comp ...@@ -109,39 +124,243 @@ template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::comp
}; };
template<typename From,typename To> struct get_factor { template<typename From,typename To> struct get_factor {
static EIGEN_STRONG_INLINE To run(const From& x) { return x; } EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); }
}; };
template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> { template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return real(x); } EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); }
};
// Thin wrapper around a raw Scalar pointer that exposes element reads, packet
// loads and an alignment query at a given linear offset. It offers no store
// operation and performs no bounds checking.
template<typename Scalar, typename Index>
class BlasVectorMapper {
  public:
    // Wraps `data` without taking ownership; the pointer is only stored.
    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {}

    // Returns a copy of the element located `offset` entries past the base pointer.
    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index offset) const {
      return *(m_data + offset);
    }

    // Loads a Packet starting at element `offset`, using the requested alignment mode.
    template <typename Packet, int AlignmentType>
    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index offset) const {
      return ploadt<Packet, AlignmentType>(m_data + offset);
    }

    // True when the address of element `offset` is a multiple of sizeof(Packet).
    template <typename Packet>
    EIGEN_DEVICE_FUNC bool aligned(Index offset) const {
      return 0 == (UIntPtr(m_data + offset) % sizeof(Packet));
    }

  protected:
    Scalar* m_data;
};
template<typename Scalar, typename Index, int AlignmentType, int Incr=1>
class BlasLinearMapper;
template<typename Scalar, typename Index, int AlignmentType>
class BlasLinearMapper<Scalar,Index,AlignmentType,1> {
public:
typedef typename packet_traits<Scalar>::type Packet;
typedef typename packet_traits<Scalar>::half HalfPacket;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data, Index incr=1)
: m_data(data)
{
EIGEN_ONLY_USED_FOR_DEBUG(incr);
eigen_assert(incr==1);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
internal::prefetch(&operator()(i));
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const {
return m_data[i];
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
return ploadt<Packet, AlignmentType>(m_data + i);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
return ploadt<HalfPacket, AlignmentType>(m_data + i);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const Packet &p) const {
pstoret<Scalar, Packet, AlignmentType>(m_data + i, p);
}
protected:
Scalar *m_data;
}; };
// Lightweight helper class to access matrix coefficients. // Lightweight helper class to access matrix coefficients.
// Yes, this is somehow redundant with Map<>, but this version is much much lighter, template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned, int Incr = 1>
// and so I hope better compilation performance (time and code quality). class blas_data_mapper;
template<typename Scalar, typename Index, int StorageOrder>
class blas_data_mapper template<typename Scalar, typename Index, int StorageOrder, int AlignmentType>
class blas_data_mapper<Scalar,Index,StorageOrder,AlignmentType,1>
{ {
public: public:
blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {} typedef typename packet_traits<Scalar>::type Packet;
EIGEN_STRONG_INLINE Scalar& operator()(Index i, Index j) typedef typename packet_traits<Scalar>::half HalfPacket;
{ return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; }
typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
typedef BlasVectorMapper<Scalar, Index> VectorMapper;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1)
: m_data(data), m_stride(stride)
{
EIGEN_ONLY_USED_FOR_DEBUG(incr);
eigen_assert(incr==1);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
getSubMapper(Index i, Index j) const {
return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
return LinearMapper(&operator()(i, j));
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
return VectorMapper(&operator()(i, j));
}
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride];
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
return ploadt<Packet, AlignmentType>(&operator()(i, j));
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
return ploadt<HalfPacket, AlignmentType>(&operator()(i, j));
}
template<typename SubPacket>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
}
template<typename SubPacket>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
}
EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; }
EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {
if (UIntPtr(m_data)%sizeof(Scalar)) {
return -1;
}
return internal::first_default_aligned(m_data, size);
}
protected: protected:
Scalar* EIGEN_RESTRICT m_data; Scalar* EIGEN_RESTRICT m_data;
Index m_stride; const Index m_stride;
};
// Implementation of non-natural increment (i.e. inner-stride != 1)
// The exposed API is not complete yet compared to the Incr==1 case
// because some features makes less sense in this case.
//
// Linear view over strided storage: logical element k lives at
// m_data[k * increment]. Packets are moved with gather/scatter using that
// same increment.
template<typename Scalar, typename Index, int AlignmentType, int Incr>
class BlasLinearMapper
{
  public:
    typedef typename packet_traits<Scalar>::type Packet;
    typedef typename packet_traits<Scalar>::half HalfPacket;

    // Stores the base pointer together with the element increment.
    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data,Index incr)
      : m_data(data), m_incr(incr)
    {}

    // Issues a prefetch for the address of logical element `k`.
    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int k) const {
      internal::prefetch(&(*this)(k));
    }

    // Reference to logical element `k`.
    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index k) const {
      return *(m_data + k*m_incr.value());
    }

    // Gathers a full Packet whose first lane is logical element `k`.
    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index k) const {
      return pgather<Scalar,Packet>(m_data + k*m_incr.value(), m_incr.value());
    }

    // Scatters `pkt` so that its first lane lands on logical element `k`.
    template<typename PacketType>
    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index k, const PacketType &pkt) const {
      pscatter<Scalar, PacketType>(m_data + k*m_incr.value(), pkt, m_incr.value());
    }

  protected:
    Scalar *m_data;
    const internal::variable_if_dynamic<Index,Incr> m_incr;
};
template<typename Scalar, typename Index, int StorageOrder, int AlignmentType,int Incr>
class blas_data_mapper
{
public:
typedef typename packet_traits<Scalar>::type Packet;
typedef typename packet_traits<Scalar>::half HalfPacket;
typedef BlasLinearMapper<Scalar, Index, AlignmentType,Incr> LinearMapper;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr) : m_data(data), m_stride(stride), m_incr(incr) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper
getSubMapper(Index i, Index j) const {
return blas_data_mapper(&operator()(i, j), m_stride, m_incr.value());
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
return LinearMapper(&operator()(i, j), m_incr.value());
}
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
return m_data[StorageOrder==RowMajor ? j*m_incr.value() + i*m_stride : i*m_incr.value() + j*m_stride];
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
return pgather<Scalar,Packet>(&operator()(i, j),m_incr.value());
}
template <typename PacketT, int AlignmentT>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const {
return pgather<Scalar,PacketT>(&operator()(i, j),m_incr.value());
}
template<typename SubPacket>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
}
template<typename SubPacket>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
}
protected:
Scalar* EIGEN_RESTRICT m_data;
const Index m_stride;
const internal::variable_if_dynamic<Index,Incr> m_incr;
}; };
// lightweight helper class to access matrix coefficients (const version) // lightweight helper class to access matrix coefficients (const version)
template<typename Scalar, typename Index, int StorageOrder> template<typename Scalar, typename Index, int StorageOrder>
class const_blas_data_mapper class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
{
public: public:
const_blas_data_mapper(const Scalar* data, Index stride) : m_data(data), m_stride(stride) {} EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
EIGEN_STRONG_INLINE const Scalar& operator()(Index i, Index j) const
{ return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride]; } EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const {
protected: return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride);
const Scalar* EIGEN_RESTRICT m_data; }
Index m_stride;
}; };
...@@ -188,17 +407,33 @@ struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> > ...@@ -188,17 +407,33 @@ struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
}; };
// pop scalar multiple // pop scalar multiple
template<typename Scalar, typename NestedXpr> template<typename Scalar, typename NestedXpr, typename Plain>
struct blas_traits<CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> > struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
: blas_traits<NestedXpr> : blas_traits<NestedXpr>
{ {
typedef blas_traits<NestedXpr> Base; typedef blas_traits<NestedXpr> Base;
typedef CwiseUnaryOp<scalar_multiple_op<Scalar>, NestedXpr> XprType; typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
typedef typename Base::ExtractType ExtractType; typedef typename Base::ExtractType ExtractType;
static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); } static inline ExtractType extract(const XprType& x) { return Base::extract(x.rhs()); }
static inline Scalar extractScalarFactor(const XprType& x)
{ return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); }
};
template<typename Scalar, typename NestedXpr, typename Plain>
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
: blas_traits<NestedXpr>
{
typedef blas_traits<NestedXpr> Base;
typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
typedef typename Base::ExtractType ExtractType;
static inline ExtractType extract(const XprType& x) { return Base::extract(x.lhs()); }
static inline Scalar extractScalarFactor(const XprType& x) static inline Scalar extractScalarFactor(const XprType& x)
{ return x.functor().m_other * Base::extractScalarFactor(x.nestedExpression()); } { return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other; }
}; };
// Product of two constant nullary expressions: this specialization simply reuses the
// blas_traits of the left-hand constant expression.
// NOTE(review): presumably needed because both "constant * xpr" and "xpr * constant"
// specializations would otherwise match this type equally well -- confirm.
template<typename Scalar, typename Plain1, typename Plain2>
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1>,
const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain2> > >
: blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1> >
{};
// pop opposite // pop opposite
template<typename Scalar, typename NestedXpr> template<typename Scalar, typename NestedXpr>
...@@ -230,7 +465,7 @@ struct blas_traits<Transpose<NestedXpr> > ...@@ -230,7 +465,7 @@ struct blas_traits<Transpose<NestedXpr> >
enum { enum {
IsTransposed = Base::IsTransposed ? 0 : 1 IsTransposed = Base::IsTransposed ? 0 : 1
}; };
static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); } static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); }
static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); } static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment