OpenDAS / lietorch · Commits

Commit 266d4fd9, authored Jun 03, 2025 by zhanggzh

    add lietorch src code and eigen src code, update readme

parent e7df8655

Changes: 148 files in this commit. Shown below: 8 changed files with 7776 additions and 0 deletions (+7776, -0).
eigen-master/Eigen/src/Core/arch/AVX512/MathFunctions.h      +141   -0
eigen-master/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h  +75    -0
eigen-master/Eigen/src/Core/arch/AVX512/PacketMath.h         +3354  -0
eigen-master/Eigen/src/Core/arch/AVX512/PacketMathFP16.h     +1413  -0
eigen-master/Eigen/src/Core/arch/AVX512/TrsmKernel.h         +1167  -0
eigen-master/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc      +1219  -0
eigen-master/Eigen/src/Core/arch/AVX512/TypeCasting.h        +277   -0
eigen-master/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h    +130   -0
eigen-master/Eigen/src/Core/arch/AVX512/MathFunctions.h  (new file, mode 0 → 100644)
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Pedro Gonnet (pedro.gonnet@gmail.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_
#define THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {

EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_FLOAT(Packet16f)
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_DOUBLE(Packet8d)

template <>
EIGEN_STRONG_INLINE Packet16h pfrexp(const Packet16h& a, Packet16h& exponent) {
  Packet16f fexponent;
  const Packet16h out = float2half(pfrexp<Packet16f>(half2float(a), fexponent));
  exponent = float2half(fexponent);
  return out;
}

template <>
EIGEN_STRONG_INLINE Packet16h pldexp(const Packet16h& a, const Packet16h& exponent) {
  return float2half(pldexp<Packet16f>(half2float(a), half2float(exponent)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pfrexp(const Packet16bf& a, Packet16bf& exponent) {
  Packet16f fexponent;
  const Packet16bf out = F32ToBf16(pfrexp<Packet16f>(Bf16ToF32(a), fexponent));
  exponent = F32ToBf16(fexponent);
  return out;
}

template <>
EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exponent) {
  return F32ToBf16(pldexp<Packet16f>(Bf16ToF32(a), Bf16ToF32(exponent)));
}

#if EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f psqrt<Packet16f>(const Packet16f& x) {
  return generic_sqrt_newton_step<Packet16f>::run(x, _mm512_rsqrt14_ps(x));
}

template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d psqrt<Packet8d>(const Packet8d& x) {
#ifdef EIGEN_VECTORIZE_AVX512ER
  return generic_sqrt_newton_step<Packet8d, /*Steps=*/1>::run(x, _mm512_rsqrt28_pd(x));
#else
  return generic_sqrt_newton_step<Packet8d, /*Steps=*/2>::run(x, _mm512_rsqrt14_pd(x));
#endif
}
#else
template <>
EIGEN_STRONG_INLINE Packet16f psqrt<Packet16f>(const Packet16f& x) {
  return _mm512_sqrt_ps(x);
}
template <>
EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
  return _mm512_sqrt_pd(x);
}
#endif

// prsqrt for float.
#if defined(EIGEN_VECTORIZE_AVX512ER)
template <>
EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
  return _mm512_rsqrt28_ps(x);
}
#elif EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f prsqrt<Packet16f>(const Packet16f& x) {
  return generic_rsqrt_newton_step<Packet16f, /*Steps=*/1>::run(x, _mm512_rsqrt14_ps(x));
}
#endif

// prsqrt for double.
#if EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d prsqrt<Packet8d>(const Packet8d& x) {
#ifdef EIGEN_VECTORIZE_AVX512ER
  return generic_rsqrt_newton_step<Packet8d, /*Steps=*/1>::run(x, _mm512_rsqrt28_pd(x));
#else
  return generic_rsqrt_newton_step<Packet8d, /*Steps=*/2>::run(x, _mm512_rsqrt14_pd(x));
#endif
}

template <>
EIGEN_STRONG_INLINE Packet16f preciprocal<Packet16f>(const Packet16f& a) {
#ifdef EIGEN_VECTORIZE_AVX512ER
  return _mm512_rcp28_ps(a);
#else
  return generic_reciprocal_newton_step<Packet16f, /*Steps=*/1>::run(a, _mm512_rcp14_ps(a));
#endif
}
#endif

BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp2)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog2)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, preciprocal)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)

#ifndef EIGEN_VECTORIZE_AVX512FP16
F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp2)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexpm1)
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog)
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog1p)
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog2)
F16_PACKET_FUNCTION(Packet16f, Packet16h, preciprocal)
F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt)
F16_PACKET_FUNCTION(Packet16f, Packet16h, psin)
F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt)
F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh)
#endif  // EIGEN_VECTORIZE_AVX512FP16

}  // end namespace internal
}  // end namespace Eigen

#endif  // THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_
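Aside on the fast-math psqrt/prsqrt specializations above: they start from a low-precision hardware estimate (_mm512_rsqrt14_* is accurate to roughly 14 bits, _mm512_rsqrt28_pd to roughly 28) and refine it with Newton-Raphson steps via generic_sqrt_newton_step / generic_rsqrt_newton_step. A minimal scalar sketch of that refinement, for illustration only and not part of the committed sources:

#include <cstdio>

// One Newton-Raphson refinement of an initial reciprocal-square-root estimate y0:
//   y1 = y0 * (1.5 - 0.5 * x * y0 * y0)
// Each step roughly doubles the number of correct bits of the estimate.
static double rsqrt_newton_step(double x, double y0) { return y0 * (1.5 - 0.5 * x * y0 * y0); }

int main() {
  double x = 2.0;
  double y = 0.7;  // crude initial guess for 1/sqrt(2) ~= 0.7071
  for (int i = 0; i < 3; ++i) y = rsqrt_newton_step(x, y);
  std::printf("rsqrt(2) ~= %.12f, sqrt(2) ~= x*rsqrt(x) = %.12f\n", y, x * y);
  return 0;
}

Because one step roughly doubles the correct bits, the double-precision path above takes two steps from the 14-bit estimate but only one step when the 28-bit AVX512ER estimate is available.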
eigen-master/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h  (new file, mode 0 → 100644)
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 The Eigen Authors.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
#define EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {

EIGEN_STRONG_INLINE Packet32h combine2Packet16h(const Packet16h& a, const Packet16h& b) {
  __m512i result = _mm512_castsi256_si512(_mm256_castph_si256(a));
  result = _mm512_inserti64x4(result, _mm256_castph_si256(b), 1);
  return _mm512_castsi512_ph(result);
}

EIGEN_STRONG_INLINE void extract2Packet16h(const Packet32h& x, Packet16h& a, Packet16h& b) {
  a = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_castph_si512(x)));
  b = _mm256_castsi256_ph(_mm512_extracti64x4_epi64(_mm512_castph_si512(x), 1));
}
#define _EIGEN_GENERATE_FP16_MATH_FUNCTION(func) \
template <> \
EIGEN_STRONG_INLINE Packet8h func<Packet8h>(const Packet8h& a) { \
return float2half(func(half2float(a))); \
} \
\
template <> \
EIGEN_STRONG_INLINE Packet16h func<Packet16h>(const Packet16h& a) { \
return float2half(func(half2float(a))); \
} \
\
template <> \
EIGEN_STRONG_INLINE Packet32h func<Packet32h>(const Packet32h& a) { \
Packet16h low; \
Packet16h high; \
extract2Packet16h(a, low, high); \
return combine2Packet16h(func(low), func(high)); \
}
_EIGEN_GENERATE_FP16_MATH_FUNCTION(psin)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pcos)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog2)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog1p)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexpm1)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp2)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(ptanh)
#undef _EIGEN_GENERATE_FP16_MATH_FUNCTION

// pfrexp
template <>
EIGEN_STRONG_INLINE Packet32h pfrexp<Packet32h>(const Packet32h& a, Packet32h& exponent) {
  return pfrexp_generic(a, exponent);
}

// pldexp
template <>
EIGEN_STRONG_INLINE Packet32h pldexp<Packet32h>(const Packet32h& a, const Packet32h& exponent) {
  return pldexp_generic(a, exponent);
}

}  // end namespace internal
}  // end namespace Eigen

#endif  // EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
\ No newline at end of file
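Aside on the pattern used throughout MathFunctionsFP16.h: each half-precision transcendental is computed by promoting to float, running the existing float kernel, and narrowing back, and the 32-lane Packet32h case is handled by splitting the packet into two 16-lane halves with extract2Packet16h and recombining with combine2Packet16h. A rough scalar analogue of that split/compute/combine structure, using float arrays in place of real half packets (illustration only, not Eigen code):

#include <array>
#include <cmath>
#include <cstdio>

// 16-lane "kernel", standing in for a Packet16h math function that internally
// promotes to float (float2half(func(half2float(a))) in the file above).
static std::array<float, 16> sin16(const std::array<float, 16>& a) {
  std::array<float, 16> r{};
  for (int i = 0; i < 16; ++i) r[i] = std::sin(a[i]);
  return r;
}

// 32-lane version built like the Packet32h branch of
// _EIGEN_GENERATE_FP16_MATH_FUNCTION: extract low/high halves, apply the
// 16-lane kernel to each, then recombine.
static std::array<float, 32> sin32(const std::array<float, 32>& a) {
  std::array<float, 16> lo{}, hi{};
  for (int i = 0; i < 16; ++i) { lo[i] = a[i]; hi[i] = a[16 + i]; }
  const std::array<float, 16> rlo = sin16(lo), rhi = sin16(hi);
  std::array<float, 32> r{};
  for (int i = 0; i < 16; ++i) { r[i] = rlo[i]; r[16 + i] = rhi[i]; }
  return r;
}

int main() {
  std::array<float, 32> x{};
  for (int i = 0; i < 32; ++i) x[i] = 0.1f * static_cast<float>(i);
  const std::array<float, 32> y = sin32(x);
  std::printf("sin(%g) ~= %g\n", x[31], y[31]);
  return 0;
}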
eigen-master/Eigen/src/Core/arch/AVX512/PacketMath.h  (new file, mode 0 → 100644)
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Benoit Steiner (benoit.steiner.goog@gmail.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_PACKET_MATH_AVX512_H
#define EIGEN_PACKET_MATH_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
#endif
#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
#endif
#ifdef EIGEN_VECTORIZE_FMA
#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
#endif
typedef __m512 Packet16f;
typedef __m512i Packet16i;
typedef __m512d Packet8d;
typedef eigen_packet_wrapper<__m512i, 1> Packet8l;
#ifndef EIGEN_VECTORIZE_AVX512FP16
typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
#endif
typedef eigen_packet_wrapper<__m256i, 2> Packet16bf;

typedef eigen_packet_wrapper<__m512i, 6> Packet32s;
typedef eigen_packet_wrapper<__m256i, 6> Packet16s;
typedef eigen_packet_wrapper<__m128i, 6> Packet8s;

template <>
struct is_arithmetic<__m512> {
  enum { value = true };
};
template <>
struct is_arithmetic<__m512i> {
  enum { value = true };
};
template <>
struct is_arithmetic<__m512d> {
  enum { value = true };
};
template <>
struct is_arithmetic<Packet8l> {
  enum { value = true };
};

#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
struct is_arithmetic<Packet16h> {
  enum { value = true };
};

template <>
struct packet_traits<half> : default_packet_traits {
  typedef Packet16h type;
  // There is no half-size packet for Packet16h.
  typedef Packet16h half;
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = 16,
    HasCmp = 1,
    HasAdd = 1,
    HasSub = 1,
    HasMul = 1,
    HasDiv = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasAbs2 = 0,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasSqrt = 1,
    HasRsqrt = 1,
    HasLog = 1,
    HasLog1p = 1,
    HasExp = 1,
    HasExpm1 = 1,
    HasBessel = 1,
    HasNdtri = 1,
    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = EIGEN_FAST_MATH,
    HasBlend = 0
  };
};
#endif

template <>
struct packet_traits<float> : default_packet_traits {
  typedef Packet16f type;
  typedef Packet8f half;
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = 16,
    HasAbs = 1,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasBlend = 1,
    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasACos = 1,
    HasASin = 1,
    HasATan = 1,
    HasATanh = 1,
    HasSqrt = 1,
    HasRsqrt = 1,
    HasCbrt = 1,
    HasLog = 1,
    HasLog1p = 1,
    HasExpm1 = 1,
    HasNdtri = 1,
    HasBessel = 1,
    HasExp = 1,
    HasReciprocal = EIGEN_FAST_MATH,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = EIGEN_FAST_MATH,
    HasErfc = EIGEN_FAST_MATH,
    HasCmp = 1,
    HasDiv = 1
  };
};

template <>
struct packet_traits<double> : default_packet_traits {
  typedef Packet8d type;
  typedef Packet4d half;
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = 8,
    HasBlend = 1,
    HasSqrt = 1,
    HasRsqrt = 1,
    HasCbrt = 1,
    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasLog = 1,
    HasExp = 1,
    HasATan = 1,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = EIGEN_FAST_MATH,
    HasErfc = EIGEN_FAST_MATH,
    HasATanh = 1,
    HasCmp = 1,
    HasDiv = 1
  };
};

template <>
struct packet_traits<int> : default_packet_traits {
  typedef Packet16i type;
  typedef Packet8i half;
  enum { Vectorizable = 1, AlignedOnScalar = 1, HasBlend = 0, HasCmp = 1, HasDiv = 1, size = 16 };
};

template <>
struct packet_traits<int64_t> : default_packet_traits {
  typedef Packet8l type;
  typedef Packet4l half;
  enum { Vectorizable = 1, AlignedOnScalar = 1, HasCmp = 1, size = 8 };
};

template <>
struct unpacket_traits<Packet16f> {
  typedef float type;
  typedef Packet8f half;
  typedef Packet16i integer_packet;
  typedef uint16_t mask_t;
  enum {
    size = 16,
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = true,
    masked_store_available = true,
    masked_fpops_available = true
  };
};
template <>
struct unpacket_traits<Packet8d> {
  typedef double type;
  typedef Packet4d half;
  typedef Packet8l integer_packet;
  typedef uint8_t mask_t;
  enum {
    size = 8,
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = true,
    masked_store_available = true,
    masked_fpops_available = true
  };
};
template <>
struct unpacket_traits<Packet16i> {
  typedef int type;
  typedef Packet8i half;
  enum {
    size = 16,
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};
template <>
struct unpacket_traits<Packet8l> {
  typedef int64_t type;
  typedef Packet4l half;
  enum {
    size = 8,
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
struct unpacket_traits<Packet16h> {
  typedef Eigen::half type;
  typedef Packet8h half;
  enum {
    size = 16,
    alignment = Aligned32,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};
#endif

template <>
struct unpacket_traits<Packet32s> {
  typedef numext::int16_t type;
  typedef Packet16s half;
  enum {
    size = 32,
    alignment = Aligned64,
    vectorizable = false,
  };
};
template <>
struct unpacket_traits<Packet16s> {
  typedef numext::int16_t type;
  typedef Packet8s half;
  enum {
    size = 16,
    alignment = Aligned32,
    vectorizable = false,
  };
};
template <>
struct unpacket_traits<Packet8s> {
  typedef numext::int16_t type;
  typedef Packet8s half;
  enum {
    size = 8,
    alignment = Aligned16,
    vectorizable = false,
  };
};

template <>
EIGEN_STRONG_INLINE Packet16f pset1<Packet16f>(const float& from) {
  return _mm512_set1_ps(from);
}
template <>
EIGEN_STRONG_INLINE Packet8d pset1<Packet8d>(const double& from) {
  return _mm512_set1_pd(from);
}
template <>
EIGEN_STRONG_INLINE Packet16i pset1<Packet16i>(const int& from) {
  return _mm512_set1_epi32(from);
}
template <>
EIGEN_STRONG_INLINE Packet8l pset1<Packet8l>(const int64_t& from) {
  return _mm512_set1_epi64(from);
}

template <>
EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) {
  return _mm512_castsi512_ps(_mm512_set1_epi32(from));
}
template <>
EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(const numext::uint64_t from) {
  return _mm512_castsi512_pd(_mm512_set1_epi64(from));
}

template <>
EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) {
  return _mm512_setzero_ps();
}
template <>
EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) {
  return _mm512_setzero_pd();
}
template <>
EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) {
  return _mm512_setzero_si512();
}
template <>
EIGEN_STRONG_INLINE Packet8l pzero(const Packet8l& /*a*/) {
  return _mm512_setzero_si512();
}

template <>
EIGEN_STRONG_INLINE Packet16f peven_mask(const Packet16f& /*a*/) {
  return _mm512_castsi512_ps(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1));
}
template <>
EIGEN_STRONG_INLINE Packet16i peven_mask(const Packet16i& /*a*/) {
  return _mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1);
}
template <>
EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) {
  return _mm512_castsi512_pd(_mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1));
}
template <>
EIGEN_STRONG_INLINE Packet8l peven_mask(const Packet8l& /*a*/) {
  return _mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1);
}
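// Note: peven_mask sets every bit of the even-indexed lanes (0, 2, 4, ...) and clears
// the odd ones. _mm512_set_epi32 lists the highest lane first, which is why each pair
// above reads (0, -1). The 64-bit variants build the same pattern from 32-bit halves,
// so each 64-bit lane is described by two identical entries (0, 0) or (-1, -1).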
template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
  // Inline asm here helps reduce some register spilling in TRSM kernels.
  // See note in unrolls::gemm::microKernel in TrsmKernel.h
  Packet16f ret;
  __asm__("vbroadcastss %[mem], %[dst]" : [dst] "=v"(ret) : [mem] "m"(*from));
  return ret;
#else
  return _mm512_broadcastss_ps(_mm_load_ps1(from));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pload1<Packet8d>(const double* from) {
#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
  Packet8d ret;
  __asm__("vbroadcastsd %[mem], %[dst]" : [dst] "=v"(ret) : [mem] "m"(*from));
  return ret;
#else
  return _mm512_set1_pd(*from);
#endif
}
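// The GCC/Clang path above keeps the broadcast as a single vbroadcastss/vbroadcastsd
// with a memory source operand, rather than loading the scalar into a register and
// then broadcasting it, which is what helps with the register pressure mentioned in
// the comment.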
template <>
EIGEN_STRONG_INLINE Packet16f plset<Packet16f>(const float& a) {
  return _mm512_add_ps(_mm512_set1_ps(a), _mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
                                                        6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f));
}
template <>
EIGEN_STRONG_INLINE Packet8d plset<Packet8d>(const double& a) {
  return _mm512_add_pd(_mm512_set1_pd(a), _mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0));
}
template <>
EIGEN_STRONG_INLINE Packet16i plset<Packet16i>(const int& a) {
  return _mm512_add_epi32(_mm512_set1_epi32(a),
                          _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
}
template <>
EIGEN_STRONG_INLINE Packet8l plset<Packet8l>(const int64_t& a) {
  return _mm512_add_epi64(_mm512_set1_epi64(a), _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0));
}

template <>
EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a, const Packet16f& b) {
  return _mm512_add_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a, const Packet8d& b) {
  return _mm512_add_pd(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16i padd<Packet16i>(const Packet16i& a, const Packet16i& b) {
  return _mm512_add_epi32(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l padd<Packet8l>(const Packet8l& a, const Packet8l& b) {
  return _mm512_add_epi64(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a, const Packet16f& b, uint16_t umask) {
  __mmask16 mask = static_cast<__mmask16>(umask);
  return _mm512_maskz_add_ps(mask, a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a, const Packet8d& b, uint8_t umask) {
  __mmask8 mask = static_cast<__mmask8>(umask);
  return _mm512_maskz_add_pd(mask, a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16f psub<Packet16f>(const Packet16f& a, const Packet16f& b) {
  return _mm512_sub_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d psub<Packet8d>(const Packet8d& a, const Packet8d& b) {
  return _mm512_sub_pd(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16i psub<Packet16i>(const Packet16i& a, const Packet16i& b) {
  return _mm512_sub_epi32(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l psub<Packet8l>(const Packet8l& a, const Packet8l& b) {
  return _mm512_sub_epi64(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16f pnegate(const Packet16f& a) {
  // NOTE: MSVC seems to struggle with _mm512_set1_epi32, leading to random results.
  // The intel docs give it a relatively high latency as well, so we're probably
  // better off with using _mm512_set_epi32 directly anyways.
  const __m512i mask =
      _mm512_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
                       0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000);
  return _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(a), mask));
}
template <>
EIGEN_STRONG_INLINE Packet8d pnegate(const Packet8d& a) {
  const __m512i mask =
      _mm512_set_epi64(0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL,
                       0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL);
  return _mm512_castsi512_pd(_mm512_xor_epi64(_mm512_castpd_si512(a), mask));
}
template <>
EIGEN_STRONG_INLINE Packet16i pnegate(const Packet16i& a) {
  return _mm512_sub_epi32(_mm512_setzero_si512(), a);
}
template <>
EIGEN_STRONG_INLINE Packet8l pnegate(const Packet8l& a) {
  return _mm512_sub_epi64(_mm512_setzero_si512(), a);
}

template <>
EIGEN_STRONG_INLINE Packet16f pconj(const Packet16f& a) {
  return a;
}
template <>
EIGEN_STRONG_INLINE Packet8d pconj(const Packet8d& a) {
  return a;
}
template <>
EIGEN_STRONG_INLINE Packet16i pconj(const Packet16i& a) {
  return a;
}
template <>
EIGEN_STRONG_INLINE Packet8l pconj(const Packet8l& a) {
  return a;
}

template <>
EIGEN_STRONG_INLINE Packet16f pmul<Packet16f>(const Packet16f& a, const Packet16f& b) {
  return _mm512_mul_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a, const Packet8d& b) {
  return _mm512_mul_pd(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a, const Packet16i& b) {
  return _mm512_mullo_epi32(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l pmul<Packet8l>(const Packet8l& a, const Packet8l& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_mullo_epi64(a, b);
#else
  return _mm512_mullox_epi64(a, b);
#endif
}

template <>
EIGEN_STRONG_INLINE Packet16f pdiv<Packet16f>(const Packet16f& a, const Packet16f& b) {
  return _mm512_div_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d pdiv<Packet8d>(const Packet8d& a, const Packet8d& b) {
  return _mm512_div_pd(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16i pdiv<Packet16i>(const Packet16i& a, const Packet16i& b) {
  Packet8i q_lo = pdiv<Packet8i>(_mm512_extracti64x4_epi64(a, 0), _mm512_extracti64x4_epi64(b, 0));
  Packet8i q_hi = pdiv<Packet8i>(_mm512_extracti64x4_epi64(a, 1), _mm512_extracti64x4_epi64(b, 1));
  return _mm512_inserti64x4(_mm512_castsi256_si512(q_lo), q_hi, 1);
}

#ifdef EIGEN_VECTORIZE_FMA
template <>
EIGEN_STRONG_INLINE Packet16f pmadd(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
  return _mm512_fmadd_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
  return _mm512_fmadd_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmsub(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
  return _mm512_fmsub_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmsub(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
  return _mm512_fmsub_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16f pnmadd(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
  return _mm512_fnmadd_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pnmadd(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
  return _mm512_fnmadd_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16f pnmsub(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
  return _mm512_fnmsub_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pnmsub(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
  return _mm512_fnmsub_pd(a, b, c);
}
#endif

template <>
EIGEN_DEVICE_FUNC inline Packet16f pselect(const Packet16f& mask, const Packet16f& a, const Packet16f& b) {
  __mmask16 mask16 = _mm512_cmpeq_epi32_mask(_mm512_castps_si512(mask), _mm512_setzero_epi32());
  return _mm512_mask_blend_ps(mask16, a, b);
}
template <>
EIGEN_DEVICE_FUNC inline Packet16i pselect(const Packet16i& mask, const Packet16i& a, const Packet16i& b) {
  __mmask16 mask16 = _mm512_cmpeq_epi32_mask(mask, _mm512_setzero_epi32());
  return _mm512_mask_blend_epi32(mask16, a, b);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8l pselect(const Packet8l& mask, const Packet8l& a, const Packet8l& b) {
  __mmask8 mask8 = _mm512_cmpeq_epi64_mask(mask, _mm512_setzero_si512());
  return _mm512_mask_blend_epi64(mask8, a, b);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8d pselect(const Packet8d& mask, const Packet8d& a, const Packet8d& b) {
  __mmask8 mask8 = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
  return _mm512_mask_blend_pd(mask8, a, b);
}

template <>
EIGEN_STRONG_INLINE Packet16f pmin<Packet16f>(const Packet16f& a, const Packet16f& b) {
  // Arguments are reversed to match NaN propagation behavior of std::min.
  return _mm512_min_ps(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmin<Packet8d>(const Packet8d& a, const Packet8d& b) {
  // Arguments are reversed to match NaN propagation behavior of std::min.
  return _mm512_min_pd(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet16i pmin<Packet16i>(const Packet16i& a, const Packet16i& b) {
  return _mm512_min_epi32(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8l pmin<Packet8l>(const Packet8l& a, const Packet8l& b) {
  return _mm512_min_epi64(b, a);
}

template <>
EIGEN_STRONG_INLINE Packet16f pmax<Packet16f>(const Packet16f& a, const Packet16f& b) {
  // Arguments are reversed to match NaN propagation behavior of std::max.
  return _mm512_max_ps(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmax<Packet8d>(const Packet8d& a, const Packet8d& b) {
  // Arguments are reversed to match NaN propagation behavior of std::max.
  return _mm512_max_pd(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet16i pmax<Packet16i>(const Packet16i& a, const Packet16i& b) {
  return _mm512_max_epi32(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8l pmax<Packet8l>(const Packet8l& a, const Packet8l& b) {
  return _mm512_max_epi64(b, a);
}
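// The operand swap above relies on the x86 MIN/MAX semantics: when either input is
// NaN, the second source operand is returned. Calling _mm512_min_ps(b, a) therefore
// yields `a` for NaN inputs, matching std::min(a, b), which returns its first
// argument when the comparison against NaN evaluates to false.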
// Add specializations for min/max with prescribed NaN propagation.
template <>
EIGEN_STRONG_INLINE Packet16f pmin<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) {
  return pminmax_propagate_numbers(a, b, pmin<Packet16f>);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmin<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) {
  return pminmax_propagate_numbers(a, b, pmin<Packet8d>);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmax<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) {
  return pminmax_propagate_numbers(a, b, pmax<Packet16f>);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmax<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) {
  return pminmax_propagate_numbers(a, b, pmax<Packet8d>);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmin<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) {
  return pminmax_propagate_nan(a, b, pmin<Packet16f>);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmin<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) {
  return pminmax_propagate_nan(a, b, pmin<Packet8d>);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmax<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) {
  return pminmax_propagate_nan(a, b, pmax<Packet16f>);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmax<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) {
  return pminmax_propagate_nan(a, b, pmax<Packet8d>);
}

#ifdef EIGEN_VECTORIZE_AVX512DQ
template <int I_>
EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) {
  return _mm512_extractf32x8_ps(x, I_);
}
template <int I_>
EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) {
  return _mm512_extractf64x2_pd(x, I_);
}
EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
  return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
}
EIGEN_STRONG_INLINE Packet16i cat256i(Packet8i a, Packet8i b) {
  return _mm512_inserti32x8(_mm512_castsi256_si512(a), b, 1);
}
#else
// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
template <int I_>
EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) {
  return _mm256_castsi256_ps(_mm512_extracti64x4_epi64(_mm512_castps_si512(x), I_));
}
// AVX512F does not define _mm512_extractf64x2_pd to extract _m128 from _m512
template <int I_>
EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) {
  return _mm_castsi128_pd(_mm512_extracti32x4_epi32(_mm512_castpd_si512(x), I_));
}
EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
  return _mm512_castsi512_ps(
      _mm512_inserti64x4(_mm512_castsi256_si512(_mm256_castps_si256(a)), _mm256_castps_si256(b), 1));
}
EIGEN_STRONG_INLINE Packet16i cat256i(Packet8i a, Packet8i b) {
  return _mm512_inserti64x4(_mm512_castsi256_si512(a), b, 1);
}
#endif

// Helper function for bit packing snippet of low precision comparison.
// It packs the flags from 32x16 to 16x16.
EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
  // Split data into small pieces and handle with AVX instructions
  // to guarantee internal order of vector.
  // Operation:
  // dst[15:0] := Saturate16(rf[31:0])
  // dst[31:16] := Saturate16(rf[63:32])
  // ...
  // dst[255:240] := Saturate16(rf[255:224])
  __m256i lo = _mm256_castps_si256(extract256<0>(rf));
  __m256i hi = _mm256_castps_si256(extract256<1>(rf));
  __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0), _mm256_extractf128_si256(lo, 1));
  __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0), _mm256_extractf128_si256(hi, 1));
  return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
}

template <>
EIGEN_STRONG_INLINE Packet16f pisnan(const Packet16f& a) {
  __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_UNORD_Q);
  return _mm512_castsi512_ps(_mm512_maskz_set1_epi32(mask, int32_t(-1)));
}

template <>
EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
  __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
}
template <>
EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) {
  __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ);
  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
}
template <>
EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) {
  __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ);
  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
}
template <>
EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) {
  __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ);
  return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
}

template <>
EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) {
  __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ);
  return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet16i pcmp_le(const Packet16i& a, const Packet16i& b) {
  __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LE);
  return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet16i pcmp_lt(const Packet16i& a, const Packet16i& b) {
  __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT);
  return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
}

template <>
EIGEN_STRONG_INLINE Packet8l pcmp_eq(const Packet8l& a, const Packet8l& b) {
  __mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_EQ);
  return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet8l pcmp_le(const Packet8l& a, const Packet8l& b) {
  __mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_LE);
  return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet8l pcmp_lt(const Packet8l& a, const Packet8l& b) {
  __mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_LT);
  return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
}

template <>
EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
  __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ);
  return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) {
  __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ);
  return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) {
  __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ);
  return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) {
  __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ);
  return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
}

template <>
EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) {
  return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION);
}
template <>
EIGEN_STRONG_INLINE Packet8d print<Packet8d>(const Packet8d& a) {
  return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION);
}

template <>
EIGEN_STRONG_INLINE Packet16f pceil<Packet16f>(const Packet16f& a) {
  return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF);
}
template <>
EIGEN_STRONG_INLINE Packet8d pceil<Packet8d>(const Packet8d& a) {
  return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF);
}

template <>
EIGEN_STRONG_INLINE Packet16f pfloor<Packet16f>(const Packet16f& a) {
  return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF);
}
template <>
EIGEN_STRONG_INLINE Packet8d pfloor<Packet8d>(const Packet8d& a) {
  return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF);
}

template <>
EIGEN_STRONG_INLINE Packet16f ptrunc<Packet16f>(const Packet16f& a) {
  return _mm512_roundscale_ps(a, _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet8d ptrunc<Packet8d>(const Packet8d& a) {
  return _mm512_roundscale_pd(a, _MM_FROUND_TO_ZERO);
}

template <>
EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {
  return _mm512_set1_epi32(int32_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet8l ptrue<Packet8l>(const Packet8l& /*a*/) {
  return _mm512_set1_epi64(int64_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet16f ptrue<Packet16f>(const Packet16f& a) {
  return _mm512_castsi512_ps(ptrue<Packet16i>(_mm512_castps_si512(a)));
}
template <>
EIGEN_STRONG_INLINE Packet8d ptrue<Packet8d>(const Packet8d& a) {
  return _mm512_castsi512_pd(ptrue<Packet16i>(_mm512_castpd_si512(a)));
}

template <>
EIGEN_STRONG_INLINE Packet16i pand<Packet16i>(const Packet16i& a, const Packet16i& b) {
  return _mm512_and_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l pand<Packet8l>(const Packet8l& a, const Packet8l& b) {
  return _mm512_and_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_and_ps(a, b);
#else
  return _mm512_castsi512_ps(pand(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pand<Packet8d>(const Packet8d& a, const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_and_pd(a, b);
#else
  Packet8d res = _mm512_undefined_pd();
  Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0);
  Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0);
  res = _mm512_insertf64x4(res, _mm256_and_pd(lane0_a, lane0_b), 0);
  Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
  Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
  return _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1);
#endif
}

template <>
EIGEN_STRONG_INLINE Packet16i por<Packet16i>(const Packet16i& a, const Packet16i& b) {
  return _mm512_or_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l por<Packet8l>(const Packet8l& a, const Packet8l& b) {
  return _mm512_or_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_or_ps(a, b);
#else
  return _mm512_castsi512_ps(por(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d por<Packet8d>(const Packet8d& a, const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_or_pd(a, b);
#else
  return _mm512_castsi512_pd(por(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
#endif
}

template <>
EIGEN_STRONG_INLINE Packet16i pxor<Packet16i>(const Packet16i& a, const Packet16i& b) {
  return _mm512_xor_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l pxor<Packet8l>(const Packet8l& a, const Packet8l& b) {
  return _mm512_xor_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_xor_ps(a, b);
#else
  return _mm512_castsi512_ps(pxor(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pxor<Packet8d>(const Packet8d& a, const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_xor_pd(a, b);
#else
  return _mm512_castsi512_pd(pxor(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
#endif
}

template <>
EIGEN_STRONG_INLINE Packet16i pandnot<Packet16i>(const Packet16i& a, const Packet16i& b) {
  return _mm512_andnot_si512(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8l pandnot<Packet8l>(const Packet8l& a, const Packet8l& b) {
  return _mm512_andnot_si512(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_andnot_ps(b, a);
#else
  return _mm512_castsi512_ps(pandnot(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a, const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_andnot_pd(b, a);
#else
  return _mm512_castsi512_pd(pandnot(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
#endif
}

template <>
EIGEN_STRONG_INLINE Packet16f pround<Packet16f>(const Packet16f& a) {
  // Work-around for default std::round rounding mode.
  const Packet16f mask = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x80000000u));
  const Packet16f prev0dot5 = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x3EFFFFFFu));
  return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet8d pround<Packet8d>(const Packet8d& a) {
  // Work-around for default std::round rounding mode.
  const Packet8d mask = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x8000000000000000ull));
  const Packet8d prev0dot5 = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull));
  return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
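// The pround work-around implements round-half-away-from-zero: it takes the largest
// value strictly below 0.5 (bit patterns 0x3EFFFFFF / 0x3FDFFFFFFFFFFFFF), ORs in the
// sign of `a`, adds that to `a`, and truncates toward zero, so exact .5 cases round
// away from zero rather than to-nearest-even as the hardware default would.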
template <int N>
EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) {
  return _mm512_srai_epi32(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16i plogical_shift_right(Packet16i a) {
  return _mm512_srli_epi32(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16i plogical_shift_left(Packet16i a) {
  return _mm512_slli_epi32(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8l parithmetic_shift_right(Packet8l a) {
  return _mm512_srai_epi64(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8l plogical_shift_right(Packet8l a) {
  return _mm512_srli_epi64(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8l plogical_shift_left(Packet8l a) {
  return _mm512_slli_epi64(a, N);
}

template <>
EIGEN_STRONG_INLINE Packet16f pload<Packet16f>(const float* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ps(from);
}
template <>
EIGEN_STRONG_INLINE Packet8d pload<Packet8d>(const double* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_pd(from);
}
template <>
EIGEN_STRONG_INLINE Packet16i pload<Packet16i>(const int* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_epi64(from);
}
template <>
EIGEN_STRONG_INLINE Packet8l pload<Packet8l>(const int64_t* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_epi64(from);
}

template <>
EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ps(from);
}
template <>
EIGEN_STRONG_INLINE Packet8d ploadu<Packet8d>(const double* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_pd(from);
}
template <>
EIGEN_STRONG_INLINE Packet16i ploadu<Packet16i>(const int* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_epi32(from);
}
template <>
EIGEN_STRONG_INLINE Packet8l ploadu<Packet8l>(const int64_t* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_epi64(from);
}

template <>
EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from, uint16_t umask) {
  __mmask16 mask = static_cast<__mmask16>(umask);
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from);
}
template <>
EIGEN_STRONG_INLINE Packet8d ploadu<Packet8d>(const double* from, uint8_t umask) {
  __mmask8 mask = static_cast<__mmask8>(umask);
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_pd(mask, from);
}

// Loads 8 floats from memory a returns the packet
// {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
template <>
EIGEN_STRONG_INLINE Packet16f ploaddup<Packet16f>(const float* from) {
  // an unaligned load is required here as there is no requirement
  // on the alignment of input pointer 'from'
  __m256i low_half = _mm256_castps_si256(_mm256_loadu_ps(from));
  __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
  __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
  return pairs;
}

// Loads 4 doubles from memory a returns the packet {a0, a0, a1, a1, a2, a2, a3,
// a3}
template <>
EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
  Packet8d tmp = _mm512_castpd256_pd512(ploadu<Packet4d>(from));
  const Packet8l scatter_mask = _mm512_set_epi64(3, 3, 2, 2, 1, 1, 0, 0);
  return _mm512_permutexvar_pd(scatter_mask, tmp);
}

// Loads 4 int64_t from memory a returns the packet {a0, a0, a1, a1, a2, a2, a3,
// a3}
template <>
EIGEN_STRONG_INLINE Packet8l ploaddup<Packet8l>(const int64_t* from) {
  Packet8l tmp = _mm512_castsi256_si512(ploadu<Packet4l>(from));
  const Packet8l scatter_mask = _mm512_set_epi64(3, 3, 2, 2, 1, 1, 0, 0);
  return _mm512_permutexvar_epi64(scatter_mask, tmp);
}

// Loads 8 integers from memory and returns the packet
// {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
template <>
EIGEN_STRONG_INLINE Packet16i ploaddup<Packet16i>(const int* from) {
  __m256i low_half = _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
  __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
  __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
  return _mm512_castps_si512(pairs);
}

// Loads 4 floats from memory a returns the packet
// {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
template <>
EIGEN_STRONG_INLINE Packet16f ploadquad<Packet16f>(const float* from) {
  Packet16f tmp = _mm512_castps128_ps512(ploadu<Packet4f>(from));
  const Packet16i scatter_mask = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
  return _mm512_permutexvar_ps(scatter_mask, tmp);
}

// Loads 2 doubles from memory a returns the packet
// {a0, a0 a0, a0, a1, a1, a1, a1}
template <>
EIGEN_STRONG_INLINE Packet8d ploadquad<Packet8d>(const double* from) {
  __m256d lane0 = _mm256_set1_pd(*from);
  __m256d lane1 = _mm256_set1_pd(*(from + 1));
  __m512d tmp = _mm512_undefined_pd();
  tmp = _mm512_insertf64x4(tmp, lane0, 0);
  return _mm512_insertf64x4(tmp, lane1, 1);
}

// Loads 2 int64_t from memory a returns the packet
// {a0, a0 a0, a0, a1, a1, a1, a1}
template <>
EIGEN_STRONG_INLINE Packet8l ploadquad<Packet8l>(const int64_t* from) {
  __m256i lane0 = _mm256_set1_epi64x(*from);
  __m256i lane1 = _mm256_set1_epi64x(*(from + 1));
  __m512i tmp = _mm512_undefined_epi32();
  tmp = _mm512_inserti64x4(tmp, lane0, 0);
  return _mm512_inserti64x4(tmp, lane1, 1);
}

// Loads 4 integers from memory and returns the packet
// {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
template <>
EIGEN_STRONG_INLINE Packet16i ploadquad<Packet16i>(const int* from) {
  Packet16i tmp = _mm512_castsi128_si512(ploadu<Packet4i>(from));
  const Packet16i scatter_mask = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
  return _mm512_permutexvar_epi32(scatter_mask, tmp);
}

template <>
EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet16f& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ps(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet8d& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_pd(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet16i& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_epi32(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<int64_t>(int64_t* to, const Packet8l& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_epi64(to, from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ps(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_pd(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet16i& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_epi32(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<int64_t>(int64_t* to, const Packet8l& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_epi64(to, from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from, uint16_t umask) {
  __mmask16 mask = static_cast<__mmask16>(umask);
  EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from, uint8_t umask) {
  __mmask8 mask = static_cast<__mmask8>(umask);
  EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_pd(to, mask, from);
}

template <typename Scalar, typename Packet>
EIGEN_DEVICE_FUNC inline Packet pgather(const Packet& src, const Scalar* from, Index stride,
                                        typename unpacket_traits<Packet>::mask_t umask);

template <>
EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const Packet16f& src, const float* from, Index stride,
                                                             uint16_t umask) {
  Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
  Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
  Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
  __mmask16 mask = static_cast<__mmask16>(umask);
  return _mm512_mask_i32gather_ps(src, mask, indices, from, 4);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const Packet8d& src, const double* from, Index stride,
                                                            uint8_t umask) {
  Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
  __mmask8 mask = static_cast<__mmask8>(umask);
  return _mm512_mask_i32gather_pd(src, mask, indices, from, 8);
}

template <>
EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from, Index stride) {
  Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
  Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
  Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
  return _mm512_i32gather_ps(indices, from, 4);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const double* from, Index stride) {
  Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
  return _mm512_i32gather_pd(indices, from, 8);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8l pgather<int64_t, Packet8l>(const int64_t* from, Index stride) {
  Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
  return _mm512_i32gather_epi64(indices, from, 8);
}
template <>
EIGEN_DEVICE_FUNC inline Packet16i pgather<int, Packet16i>(const int* from, Index stride) {
  Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
  Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
  Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
  return _mm512_i32gather_epi32(indices, from, 4);
}
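// All of the strided gathers above follow the same recipe: build an index vector
// stride * {0, 1, ..., N-1} with a 32-bit multiply, then issue a single i32gather
// with the element size in bytes (4 for float/int, 8 for double/int64_t) as the
// scale factor. The masked overloads load only the lanes selected by `umask` and
// leave the remaining lanes equal to `src`.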
template
<
typename
Scalar
,
typename
Packet
>
EIGEN_DEVICE_FUNC
inline
void
pscatter
(
Scalar
*
to
,
const
Packet
&
from
,
Index
stride
,
typename
unpacket_traits
<
Packet
>::
mask_t
umask
);
template
<
>
EIGEN_DEVICE_FUNC
inline
void
pscatter
<
float
,
Packet16f
>
(
float
*
to
,
const
Packet16f
&
from
,
Index
stride
,
uint16_t
umask
)
{
Packet16i
stride_vector
=
_mm512_set1_epi32
(
convert_index
<
int
>
(
stride
));
Packet16i
stride_multiplier
=
_mm512_set_epi32
(
15
,
14
,
13
,
12
,
11
,
10
,
9
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
0
);
Packet16i
indices
=
_mm512_mullo_epi32
(
stride_vector
,
stride_multiplier
);
__mmask16
mask
=
static_cast
<
__mmask16
>
(
umask
);
_mm512_mask_i32scatter_ps
(
to
,
mask
,
indices
,
from
,
4
);
}
template
<
>
EIGEN_DEVICE_FUNC
inline
void
pscatter
<
double
,
Packet8d
>
(
double
*
to
,
const
Packet8d
&
from
,
Index
stride
,
uint8_t
umask
)
{
Packet8i
stride_vector
=
_mm256_set1_epi32
(
convert_index
<
int
>
(
stride
));
Packet8i
stride_multiplier
=
_mm256_set_epi32
(
7
,
6
,
5
,
4
,
3
,
2
,
1
,
0
);
Packet8i
indices
=
_mm256_mullo_epi32
(
stride_vector
,
stride_multiplier
);
__mmask8
mask
=
static_cast
<
__mmask8
>
(
umask
);
_mm512_mask_i32scatter_pd
(
to
,
mask
,
indices
,
from
,
8
);
}
template
<
>
EIGEN_DEVICE_FUNC
inline
void
pscatter
<
float
,
Packet16f
>
(
float
*
to
,
const
Packet16f
&
from
,
Index
stride
)
{
Packet16i
stride_vector
=
_mm512_set1_epi32
(
convert_index
<
int
>
(
stride
));
Packet16i
stride_multiplier
=
_mm512_set_epi32
(
15
,
14
,
13
,
12
,
11
,
10
,
9
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
0
);
Packet16i
indices
=
_mm512_mullo_epi32
(
stride_vector
,
stride_multiplier
);
_mm512_i32scatter_ps
(
to
,
indices
,
from
,
4
);
}
template
<
>
EIGEN_DEVICE_FUNC
inline
void
pscatter
<
double
,
Packet8d
>
(
double
*
to
,
const
Packet8d
&
from
,
Index
stride
)
{
Packet8i
stride_vector
=
_mm256_set1_epi32
(
convert_index
<
int
>
(
stride
));
Packet8i
stride_multiplier
=
_mm256_set_epi32
(
7
,
6
,
5
,
4
,
3
,
2
,
1
,
0
);
Packet8i
indices
=
_mm256_mullo_epi32
(
stride_vector
,
stride_multiplier
);
_mm512_i32scatter_pd
(
to
,
indices
,
from
,
8
);
}
template
<
>
EIGEN_DEVICE_FUNC
inline
void
pscatter
<
int64_t
,
Packet8l
>
(
int64_t
*
to
,
const
Packet8l
&
from
,
Index
stride
)
{
Packet8i
stride_vector
=
_mm256_set1_epi32
(
convert_index
<
int
>
(
stride
));
Packet8i
stride_multiplier
=
_mm256_set_epi32
(
7
,
6
,
5
,
4
,
3
,
2
,
1
,
0
);
Packet8i
indices
=
_mm256_mullo_epi32
(
stride_vector
,
stride_multiplier
);
_mm512_i32scatter_epi64
(
to
,
indices
,
from
,
8
);
}
template
<
>
EIGEN_DEVICE_FUNC
inline
void
pscatter
<
int
,
Packet16i
>
(
int
*
to
,
const
Packet16i
&
from
,
Index
stride
)
{
Packet16i
stride_vector
=
_mm512_set1_epi32
(
convert_index
<
int
>
(
stride
));
Packet16i
stride_multiplier
=
_mm512_set_epi32
(
15
,
14
,
13
,
12
,
11
,
10
,
9
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
0
);
Packet16i
indices
=
_mm512_mullo_epi32
(
stride_vector
,
stride_multiplier
);
_mm512_i32scatter_epi32
(
to
,
indices
,
from
,
4
);
}
template
<
>
EIGEN_STRONG_INLINE
void
pstore1
<
Packet16f
>
(
float
*
to
,
const
float
&
a
)
{
Packet16f
pa
=
pset1
<
Packet16f
>
(
a
);
pstore
(
to
,
pa
);
}
template
<
>
EIGEN_STRONG_INLINE
void
pstore1
<
Packet8d
>
(
double
*
to
,
const
double
&
a
)
{
Packet8d
pa
=
pset1
<
Packet8d
>
(
a
);
pstore
(
to
,
pa
);
}
template
<
>
EIGEN_STRONG_INLINE
void
pstore1
<
Packet16i
>
(
int
*
to
,
const
int
&
a
)
{
Packet16i
pa
=
pset1
<
Packet16i
>
(
a
);
pstore
(
to
,
pa
);
}
template
<
>
EIGEN_STRONG_INLINE
void
pstore1
<
Packet8l
>
(
int64_t
*
to
,
const
int64_t
&
a
)
{
Packet8l
pa
=
pset1
<
Packet8l
>
(
a
);
pstore
(
to
,
pa
);
}
template
<
>
EIGEN_STRONG_INLINE
void
prefetch
<
float
>
(
const
float
*
addr
)
{
_mm_prefetch
((
SsePrefetchPtrType
)(
addr
),
_MM_HINT_T0
);
}
template
<
>
EIGEN_STRONG_INLINE
void
prefetch
<
double
>
(
const
double
*
addr
)
{
_mm_prefetch
((
SsePrefetchPtrType
)(
addr
),
_MM_HINT_T0
);
}
template
<
>
EIGEN_STRONG_INLINE
void
prefetch
<
int
>
(
const
int
*
addr
)
{
_mm_prefetch
((
SsePrefetchPtrType
)(
addr
),
_MM_HINT_T0
);
}
template <>
EIGEN_STRONG_INLINE float pfirst<Packet16f>(const Packet16f& a) {
  return _mm512_cvtss_f32(a);
}
template <>
EIGEN_STRONG_INLINE double pfirst<Packet8d>(const Packet8d& a) {
  return _mm512_cvtsd_f64(a);
}
template <>
EIGEN_STRONG_INLINE int64_t pfirst<Packet8l>(const Packet8l& a) {
  int64_t x = _mm_extract_epi64_0(_mm512_extracti32x4_epi32(a, 0));
  return x;
}
template <>
EIGEN_STRONG_INLINE int pfirst<Packet16i>(const Packet16i& a) {
#if EIGEN_GNUC_STRICT_LESS_THAN(11, 0, 0)
  return _mm_cvtsi128_si32(_mm512_castsi512_si128(a));
#else
  return _mm512_cvtsi512_si32(a);
#endif
}
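// preverse reverses the lane order of a packet. For Packet8d the 64-bit permute indices
// 7..0 are encoded as pairs of 32-bit words (high word 0, low word k) in the
// _mm512_set_epi32 argument below.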
template <>
EIGEN_STRONG_INLINE Packet16f preverse(const Packet16f& a) {
  return _mm512_permutexvar_ps(_mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), a);
}
template <>
EIGEN_STRONG_INLINE Packet8d preverse(const Packet8d& a) {
  return _mm512_permutexvar_pd(_mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7), a);
}
template <>
EIGEN_STRONG_INLINE Packet16i preverse(const Packet16i& a) {
  return _mm512_permutexvar_epi32(_mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), a);
}
template <>
EIGEN_STRONG_INLINE Packet8l preverse(const Packet8l& a) {
  return _mm512_permutexvar_epi64(_mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7), a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pabs(const Packet16f& a) {
  // _mm512_abs_ps intrinsic not found, so hack around it
  return _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(a), _mm512_set1_epi32(0x7fffffff)));
}
template <>
EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
  // _mm512_abs_ps intrinsic not found, so hack around it
  return _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(a), _mm512_set1_epi64(0x7fffffffffffffff)));
}
template <>
EIGEN_STRONG_INLINE Packet16i pabs(const Packet16i& a) {
  return _mm512_abs_epi32(a);
}
template <>
EIGEN_STRONG_INLINE Packet8l pabs(const Packet8l& a) {
  return _mm512_abs_epi64(a);
}
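// psignbit broadcasts the sign bit of each lane across the whole lane by an arithmetic
// right shift of the raw bit pattern, yielding all-ones for values with the sign bit set
// and all-zeros otherwise.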
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) {
  return _mm256_srai_epi16(a, 15);
}
#endif  // EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16bf psignbit(const Packet16bf& a) {
  return _mm256_srai_epi16(a, 15);
}
template <>
EIGEN_STRONG_INLINE Packet16f psignbit(const Packet16f& a) {
  return _mm512_castsi512_ps(_mm512_srai_epi32(_mm512_castps_si512(a), 31));
}
template <>
EIGEN_STRONG_INLINE Packet8d psignbit(const Packet8d& a) {
  return _mm512_castsi512_pd(_mm512_srai_epi64(_mm512_castpd_si512(a), 63));
}
template <>
EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent) {
  return pfrexp_generic(a, exponent);
}

// Extract exponent without existence of Packet8l.
template <>
EIGEN_STRONG_INLINE Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) {
  const Packet8d cst_exp_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(0x7ff0000000000000ull));
#ifdef EIGEN_VECTORIZE_AVX512DQ
  return _mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52));
#else
  return _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52)));
#endif
}

template <>
EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent) {
  return pfrexp_generic(a, exponent);
}
template <>
EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) {
  return pldexp_generic(a, exponent);
}
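// The Packet8d overload below has no single-instruction 64-bit ldexp to lean on: it clamps
// the exponent, computes b = floor(e/4), builds the power 2^b directly in the double
// exponent field, and applies a * 2^b * 2^b * 2^b * 2^(e - 3b), which stays exact over the
// whole clamped range.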
template <>
EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) {
  // Clamp exponent to [-2099, 2099]
  const Packet8d max_exponent = pset1<Packet8d>(2099.0);
  const Packet8i e = _mm512_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));

  // Split 2^e into four factors and multiply.
  const Packet8i bias = pset1<Packet8i>(1023);
  Packet8i b = parithmetic_shift_right<2>(e);  // floor(e/4)

  // 2^b
  const Packet8i permute_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
  Packet8i hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
  Packet8i lo = _mm256_slli_epi64(hi, 52);
  hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52);
  Packet8d c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
  Packet8d out = pmul(pmul(pmul(a, c), c), c);  // a * 2^(3b)

  // 2^(e - 3b)
  b = psub(psub(psub(e, b), b), b);  // e - 3b
  hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
  lo = _mm256_slli_epi64(hi, 52);
  hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52);
  c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
  out = pmul(out, c);  // a * 2^e
  return out;
}
#ifdef EIGEN_VECTORIZE_AVX512DQ
// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
__m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0); \
__m256 OUTPUT##_1 = _mm512_extractf32x8_ps(INPUT, 1)
// AVX512F does not define _mm512_extracti32x8_epi32 to extract _m256i from _m512i
#define EIGEN_EXTRACT_8i_FROM_16i(INPUT, OUTPUT) \
__m256i OUTPUT##_0 = _mm512_extracti32x8_epi32(INPUT, 0); \
__m256i OUTPUT##_1 = _mm512_extracti32x8_epi32(INPUT, 1)
#else
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
__m256 OUTPUT##_0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 0)), \
_mm512_extractf32x4_ps(INPUT, 1), 1); \
__m256 OUTPUT##_1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 2)), \
_mm512_extractf32x4_ps(INPUT, 3), 1)
#define EIGEN_EXTRACT_8i_FROM_16i(INPUT, OUTPUT) \
__m256i OUTPUT##_0 = _mm256_insertf128_si256(_mm256_castsi128_si256(_mm512_extracti32x4_epi32(INPUT, 0)), \
_mm512_extracti32x4_epi32(INPUT, 1), 1); \
__m256i OUTPUT##_1 = _mm256_insertf128_si256(_mm256_castsi128_si256(_mm512_extracti32x4_epi32(INPUT, 2)), \
_mm512_extracti32x4_epi32(INPUT, 3), 1)
#endif
#ifdef EIGEN_VECTORIZE_AVX512DQ
#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
OUTPUT = _mm512_insertf32x8(_mm512_castps256_ps512(INPUTA), INPUTB, 1);
#define EIGEN_INSERT_8i_INTO_16i(OUTPUT, INPUTA, INPUTB) \
OUTPUT = _mm512_inserti32x8(_mm512_castsi256_si512(INPUTA), INPUTB, 1);
#else
#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
OUTPUT = _mm512_undefined_ps(); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 0), 0); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 1), 1); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 0), 2); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 1), 3);
#define EIGEN_INSERT_8i_INTO_16i(OUTPUT, INPUTA, INPUTB) \
OUTPUT = _mm512_undefined_epi32(); \
OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTA, 0), 0); \
OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTA, 1), 1); \
OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTB, 0), 2); \
OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTB, 1), 3);
#endif
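// Horizontal reductions. predux sums all lanes of a packet; the float and double variants
// fold the 512-bit register into progressively narrower halves and defer to the AVX/SSE
// reductions, while the integer variants can use the _mm512_reduce_* intrinsics directly.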
template <>
EIGEN_STRONG_INLINE float predux<Packet16f>(const Packet16f& a) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  __m256 lane0 = _mm512_extractf32x8_ps(a, 0);
  __m256 lane1 = _mm512_extractf32x8_ps(a, 1);
  Packet8f x = _mm256_add_ps(lane0, lane1);
  return predux<Packet8f>(x);
#else
  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
  __m128 sum = _mm_add_ps(_mm_add_ps(lane0, lane1), _mm_add_ps(lane2, lane3));
  return predux<Packet4f>(sum);
#endif
}
template <>
EIGEN_STRONG_INLINE double predux<Packet8d>(const Packet8d& a) {
  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
  __m256d sum = _mm256_add_pd(lane0, lane1);
  return predux<Packet4d>(sum);
}
template <>
EIGEN_STRONG_INLINE int64_t predux<Packet8l>(const Packet8l& a) {
  return _mm512_reduce_add_epi64(a);
}
template <>
EIGEN_STRONG_INLINE int predux<Packet16i>(const Packet16i& a) {
  return _mm512_reduce_add_epi32(a);
}
template <>
EIGEN_STRONG_INLINE Packet8f predux_half_dowto4<Packet16f>(const Packet16f& a) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  __m256 lane0 = _mm512_extractf32x8_ps(a, 0);
  __m256 lane1 = _mm512_extractf32x8_ps(a, 1);
  return _mm256_add_ps(lane0, lane1);
#else
  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
  __m128 sum0 = _mm_add_ps(lane0, lane2);
  __m128 sum1 = _mm_add_ps(lane1, lane3);
  return _mm256_insertf128_ps(_mm256_castps128_ps256(sum0), sum1, 1);
#endif
}
template <>
EIGEN_STRONG_INLINE Packet4d predux_half_dowto4<Packet8d>(const Packet8d& a) {
  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
  return _mm256_add_pd(lane0, lane1);
}
template <>
EIGEN_STRONG_INLINE Packet8i predux_half_dowto4<Packet16i>(const Packet16i& a) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
  __m256i lane0 = _mm512_extracti32x8_epi32(a, 0);
  __m256i lane1 = _mm512_extracti32x8_epi32(a, 1);
  return _mm256_add_epi32(lane0, lane1);
#else
  __m128i lane0 = _mm512_extracti32x4_epi32(a, 0);
  __m128i lane1 = _mm512_extracti32x4_epi32(a, 1);
  __m128i lane2 = _mm512_extracti32x4_epi32(a, 2);
  __m128i lane3 = _mm512_extracti32x4_epi32(a, 3);
  __m128i sum0 = _mm_add_epi32(lane0, lane2);
  __m128i sum1 = _mm_add_epi32(lane1, lane3);
  return _mm256_inserti128_si256(_mm256_castsi128_si256(sum0), sum1, 1);
#endif
}
template <>
EIGEN_STRONG_INLINE Packet4l predux_half_dowto4<Packet8l>(const Packet8l& a) {
  __m256i lane0 = _mm512_extracti64x4_epi64(a, 0);
  __m256i lane1 = _mm512_extracti64x4_epi64(a, 1);
  return _mm256_add_epi64(lane0, lane1);
}
template <>
EIGEN_STRONG_INLINE float predux_mul<Packet16f>(const Packet16f& a) {
// #ifdef EIGEN_VECTORIZE_AVX512DQ
#if 0
  Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
  Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
  Packet8f res = pmul(lane0, lane1);
  res = pmul(res, _mm256_permute2f128_ps(res, res, 1));
  res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
  return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
#else
  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
  __m128 res = pmul(pmul(lane0, lane1), pmul(lane2, lane3));
  res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
  return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
#endif
}
template <>
EIGEN_STRONG_INLINE double predux_mul<Packet8d>(const Packet8d& a) {
  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
  __m256d res = pmul(lane0, lane1);
  res = pmul(res, _mm256_permute2f128_pd(res, res, 1));
  return pfirst(pmul(res, _mm256_shuffle_pd(res, res, 1)));
}
template <>
EIGEN_STRONG_INLINE int predux_mul<Packet16i>(const Packet16i& a) {
  return _mm512_reduce_mul_epi32(a);
}
#if EIGEN_COMP_MSVC
// MSVC's _mm512_reduce_mul_epi64 is borked, at least up to and including 1939.
//   alignas(64) int64_t data[] = { 1,1,-1,-1,1,-1,-1,-1 };
//   int64_t out = _mm512_reduce_mul_epi64(_mm512_load_epi64(data));
// produces garbage: 4294967295. It seems to happen whenever the output is supposed to be negative.
// Fall back to a manual approach:
template <>
EIGEN_STRONG_INLINE int64_t predux_mul<Packet8l>(const Packet8l& a) {
  Packet4l lane0 = _mm512_extracti64x4_epi64(a, 0);
  Packet4l lane1 = _mm512_extracti64x4_epi64(a, 1);
  Packet4l res = pmul(lane0, lane1);
  res = pmul(res, Packet4l(_mm256_permute2x128_si256(res, res, 1)));
  res = pmul(res, Packet4l(_mm256_shuffle_epi32(res, 0xE)));
  return pfirst(res);
}
#else
template <>
EIGEN_STRONG_INLINE int64_t predux_mul<Packet8l>(const Packet8l& a) {
  return _mm512_reduce_mul_epi64(a);
}
#endif
template <>
EIGEN_STRONG_INLINE float predux_min<Packet16f>(const Packet16f& a) {
  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
  __m128 res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3));
  res = _mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
  return pfirst(_mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
}
template <>
EIGEN_STRONG_INLINE double predux_min<Packet8d>(const Packet8d& a) {
  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
  __m256d res = _mm256_min_pd(lane0, lane1);
  res = _mm256_min_pd(res, _mm256_permute2f128_pd(res, res, 1));
  return pfirst(_mm256_min_pd(res, _mm256_shuffle_pd(res, res, 1)));
}
template <>
EIGEN_STRONG_INLINE int predux_min<Packet16i>(const Packet16i& a) {
  return _mm512_reduce_min_epi32(a);
}
template <>
EIGEN_STRONG_INLINE int64_t predux_min<Packet8l>(const Packet8l& a) {
  return _mm512_reduce_min_epi64(a);
}
template <>
EIGEN_STRONG_INLINE float predux_max<Packet16f>(const Packet16f& a) {
  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
  __m128 res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3));
  res = _mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
  return pfirst(_mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
}
template <>
EIGEN_STRONG_INLINE double predux_max<Packet8d>(const Packet8d& a) {
  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
  __m256d res = _mm256_max_pd(lane0, lane1);
  res = _mm256_max_pd(res, _mm256_permute2f128_pd(res, res, 1));
  return pfirst(_mm256_max_pd(res, _mm256_shuffle_pd(res, res, 1)));
}
template <>
EIGEN_STRONG_INLINE int predux_max<Packet16i>(const Packet16i& a) {
  return _mm512_reduce_max_epi32(a);
}
template <>
EIGEN_STRONG_INLINE int64_t predux_max<Packet8l>(const Packet8l& a) {
  return _mm512_reduce_max_epi64(a);
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet16f& a) {
  return _mm512_reduce_or_epi32(_mm512_castps_si512(a)) != 0;
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet16i& a) {
  return _mm512_reduce_or_epi32(a) != 0;
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8d& a) {
  return _mm512_reduce_or_epi64(_mm512_castpd_si512(a)) != 0;
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8l& a) {
  return _mm512_reduce_or_epi64(a) != 0;
}
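// In-register transpose kernels. Each ptranspose overload below transposes the rows of a
// PacketBlock using staged unpack/shuffle operations; the PACK_OUTPUT* helpers reassemble
// 256-bit halves into the final 512-bit packets.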
#define PACK_OUTPUT(OUTPUT, INPUT, INDEX, STRIDE) \
EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[INDEX], INPUT[INDEX + STRIDE]);
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 16>& kernel) {
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
  __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);
  __m512 T8 = _mm512_unpacklo_ps(kernel.packet[8], kernel.packet[9]);
  __m512 T9 = _mm512_unpackhi_ps(kernel.packet[8], kernel.packet[9]);
  __m512 T10 = _mm512_unpacklo_ps(kernel.packet[10], kernel.packet[11]);
  __m512 T11 = _mm512_unpackhi_ps(kernel.packet[10], kernel.packet[11]);
  __m512 T12 = _mm512_unpacklo_ps(kernel.packet[12], kernel.packet[13]);
  __m512 T13 = _mm512_unpackhi_ps(kernel.packet[12], kernel.packet[13]);
  __m512 T14 = _mm512_unpacklo_ps(kernel.packet[14], kernel.packet[15]);
  __m512 T15 = _mm512_unpackhi_ps(kernel.packet[14], kernel.packet[15]);
  __m512 S0 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S1 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S2 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S3 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S4 = _mm512_shuffle_ps(T4, T6, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S5 = _mm512_shuffle_ps(T4, T6, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S6 = _mm512_shuffle_ps(T5, T7, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S7 = _mm512_shuffle_ps(T5, T7, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S8 = _mm512_shuffle_ps(T8, T10, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S9 = _mm512_shuffle_ps(T8, T10, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S10 = _mm512_shuffle_ps(T9, T11, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S11 = _mm512_shuffle_ps(T9, T11, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S12 = _mm512_shuffle_ps(T12, T14, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S13 = _mm512_shuffle_ps(T12, T14, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S14 = _mm512_shuffle_ps(T13, T15, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S15 = _mm512_shuffle_ps(T13, T15, _MM_SHUFFLE(3, 2, 3, 2));
  EIGEN_EXTRACT_8f_FROM_16f(S0, S0);
  EIGEN_EXTRACT_8f_FROM_16f(S1, S1);
  EIGEN_EXTRACT_8f_FROM_16f(S2, S2);
  EIGEN_EXTRACT_8f_FROM_16f(S3, S3);
  EIGEN_EXTRACT_8f_FROM_16f(S4, S4);
  EIGEN_EXTRACT_8f_FROM_16f(S5, S5);
  EIGEN_EXTRACT_8f_FROM_16f(S6, S6);
  EIGEN_EXTRACT_8f_FROM_16f(S7, S7);
  EIGEN_EXTRACT_8f_FROM_16f(S8, S8);
  EIGEN_EXTRACT_8f_FROM_16f(S9, S9);
  EIGEN_EXTRACT_8f_FROM_16f(S10, S10);
  EIGEN_EXTRACT_8f_FROM_16f(S11, S11);
  EIGEN_EXTRACT_8f_FROM_16f(S12, S12);
  EIGEN_EXTRACT_8f_FROM_16f(S13, S13);
  EIGEN_EXTRACT_8f_FROM_16f(S14, S14);
  EIGEN_EXTRACT_8f_FROM_16f(S15, S15);
  PacketBlock<Packet8f, 32> tmp;
  tmp.packet[0] = _mm256_permute2f128_ps(S0_0, S4_0, 0x20);
  tmp.packet[1] = _mm256_permute2f128_ps(S1_0, S5_0, 0x20);
  tmp.packet[2] = _mm256_permute2f128_ps(S2_0, S6_0, 0x20);
  tmp.packet[3] = _mm256_permute2f128_ps(S3_0, S7_0, 0x20);
  tmp.packet[4] = _mm256_permute2f128_ps(S0_0, S4_0, 0x31);
  tmp.packet[5] = _mm256_permute2f128_ps(S1_0, S5_0, 0x31);
  tmp.packet[6] = _mm256_permute2f128_ps(S2_0, S6_0, 0x31);
  tmp.packet[7] = _mm256_permute2f128_ps(S3_0, S7_0, 0x31);
  tmp.packet[8] = _mm256_permute2f128_ps(S0_1, S4_1, 0x20);
  tmp.packet[9] = _mm256_permute2f128_ps(S1_1, S5_1, 0x20);
  tmp.packet[10] = _mm256_permute2f128_ps(S2_1, S6_1, 0x20);
  tmp.packet[11] = _mm256_permute2f128_ps(S3_1, S7_1, 0x20);
  tmp.packet[12] = _mm256_permute2f128_ps(S0_1, S4_1, 0x31);
  tmp.packet[13] = _mm256_permute2f128_ps(S1_1, S5_1, 0x31);
  tmp.packet[14] = _mm256_permute2f128_ps(S2_1, S6_1, 0x31);
  tmp.packet[15] = _mm256_permute2f128_ps(S3_1, S7_1, 0x31);
  // Second set of _m256 outputs
  tmp.packet[16] = _mm256_permute2f128_ps(S8_0, S12_0, 0x20);
  tmp.packet[17] = _mm256_permute2f128_ps(S9_0, S13_0, 0x20);
  tmp.packet[18] = _mm256_permute2f128_ps(S10_0, S14_0, 0x20);
  tmp.packet[19] = _mm256_permute2f128_ps(S11_0, S15_0, 0x20);
  tmp.packet[20] = _mm256_permute2f128_ps(S8_0, S12_0, 0x31);
  tmp.packet[21] = _mm256_permute2f128_ps(S9_0, S13_0, 0x31);
  tmp.packet[22] = _mm256_permute2f128_ps(S10_0, S14_0, 0x31);
  tmp.packet[23] = _mm256_permute2f128_ps(S11_0, S15_0, 0x31);
  tmp.packet[24] = _mm256_permute2f128_ps(S8_1, S12_1, 0x20);
  tmp.packet[25] = _mm256_permute2f128_ps(S9_1, S13_1, 0x20);
  tmp.packet[26] = _mm256_permute2f128_ps(S10_1, S14_1, 0x20);
  tmp.packet[27] = _mm256_permute2f128_ps(S11_1, S15_1, 0x20);
  tmp.packet[28] = _mm256_permute2f128_ps(S8_1, S12_1, 0x31);
  tmp.packet[29] = _mm256_permute2f128_ps(S9_1, S13_1, 0x31);
  tmp.packet[30] = _mm256_permute2f128_ps(S10_1, S14_1, 0x31);
  tmp.packet[31] = _mm256_permute2f128_ps(S11_1, S15_1, 0x31);
  // Pack them into the output
  PACK_OUTPUT(kernel.packet, tmp.packet, 0, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 1, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 2, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 3, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 4, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 5, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 6, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 7, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 8, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 9, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 10, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 11, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 12, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 13, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 14, 16);
  PACK_OUTPUT(kernel.packet, tmp.packet, 15, 16);
}
#define PACK_OUTPUT_2(OUTPUT, INPUT, INDEX, STRIDE) \
EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], INPUT[2 * INDEX + STRIDE]);
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
  __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);
  kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
  kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
  T0 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0x44);
  T1 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0xee);
  T2 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0x44);
  T3 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0xee);
  T4 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0x44);
  T5 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0xee);
  T6 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0x44);
  T7 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0xee);
  kernel.packet[0] = _mm512_shuffle_f32x4(T0, T2, 0x88);
  kernel.packet[2] = _mm512_shuffle_f32x4(T0, T2, 0xdd);
  kernel.packet[1] = _mm512_shuffle_f32x4(T4, T6, 0x88);
  kernel.packet[3] = _mm512_shuffle_f32x4(T4, T6, 0xdd);
  kernel.packet[4] = _mm512_shuffle_f32x4(T1, T3, 0x88);
  kernel.packet[6] = _mm512_shuffle_f32x4(T1, T3, 0xdd);
  kernel.packet[5] = _mm512_shuffle_f32x4(T5, T7, 0x88);
  kernel.packet[7] = _mm512_shuffle_f32x4(T5, T7, 0xdd);
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
  __m512 S0 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S1 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
  __m512 S2 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
  __m512 S3 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));
  EIGEN_EXTRACT_8f_FROM_16f(S0, S0);
  EIGEN_EXTRACT_8f_FROM_16f(S1, S1);
  EIGEN_EXTRACT_8f_FROM_16f(S2, S2);
  EIGEN_EXTRACT_8f_FROM_16f(S3, S3);
  PacketBlock<Packet8f, 8> tmp;
  tmp.packet[0] = _mm256_permute2f128_ps(S0_0, S1_0, 0x20);
  tmp.packet[1] = _mm256_permute2f128_ps(S2_0, S3_0, 0x20);
  tmp.packet[2] = _mm256_permute2f128_ps(S0_0, S1_0, 0x31);
  tmp.packet[3] = _mm256_permute2f128_ps(S2_0, S3_0, 0x31);
  tmp.packet[4] = _mm256_permute2f128_ps(S0_1, S1_1, 0x20);
  tmp.packet[5] = _mm256_permute2f128_ps(S2_1, S3_1, 0x20);
  tmp.packet[6] = _mm256_permute2f128_ps(S0_1, S1_1, 0x31);
  tmp.packet[7] = _mm256_permute2f128_ps(S2_1, S3_1, 0x31);
  PACK_OUTPUT_2(kernel.packet, tmp.packet, 0, 1);
  PACK_OUTPUT_2(kernel.packet, tmp.packet, 1, 1);
  PACK_OUTPUT_2(kernel.packet, tmp.packet, 2, 1);
  PACK_OUTPUT_2(kernel.packet, tmp.packet, 3, 1);
}
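// The double- and int64-precision kernels below follow the same pattern with 64-bit
// unpacks; the 8x8 variants avoid cross-lane extracts by combining _mm512_permutex_* with
// masked blends, which keeps everything in 512-bit registers.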
#define PACK_OUTPUT_SQ_D(OUTPUT, INPUT, INDEX, STRIDE) \
OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[INDEX], 0); \
OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[INDEX + STRIDE], 1);
#define PACK_OUTPUT_D(OUTPUT, INPUT, INDEX, STRIDE) \
OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX)], 0); \
OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX) + STRIDE], 1);
#define PACK_OUTPUT_L(OUTPUT, INPUT, INDEX, STRIDE) \
OUTPUT[INDEX] = _mm512_inserti64x4(OUTPUT[INDEX], INPUT[(2 * INDEX)], 0); \
OUTPUT[INDEX] = _mm512_inserti64x4(OUTPUT[INDEX], INPUT[(2 * INDEX) + STRIDE], 1);
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 4>& kernel) {
  __m512d T0 = _mm512_shuffle_pd(kernel.packet[0], kernel.packet[1], 0);
  __m512d T1 = _mm512_shuffle_pd(kernel.packet[0], kernel.packet[1], 0xff);
  __m512d T2 = _mm512_shuffle_pd(kernel.packet[2], kernel.packet[3], 0);
  __m512d T3 = _mm512_shuffle_pd(kernel.packet[2], kernel.packet[3], 0xff);
  PacketBlock<Packet4d, 8> tmp;
  tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0), _mm512_extractf64x4_pd(T2, 0), 0x20);
  tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), _mm512_extractf64x4_pd(T3, 0), 0x20);
  tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0), _mm512_extractf64x4_pd(T2, 0), 0x31);
  tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), _mm512_extractf64x4_pd(T3, 0), 0x31);
  tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), _mm512_extractf64x4_pd(T2, 1), 0x20);
  tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), _mm512_extractf64x4_pd(T3, 1), 0x20);
  tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), _mm512_extractf64x4_pd(T2, 1), 0x31);
  tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), _mm512_extractf64x4_pd(T3, 1), 0x31);
  PACK_OUTPUT_D(kernel.packet, tmp.packet, 0, 1);
  PACK_OUTPUT_D(kernel.packet, tmp.packet, 1, 1);
  PACK_OUTPUT_D(kernel.packet, tmp.packet, 2, 1);
  PACK_OUTPUT_D(kernel.packet, tmp.packet, 3, 1);
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 8>& kernel) {
  __m512d T0 = _mm512_unpacklo_pd(kernel.packet[0], kernel.packet[1]);
  __m512d T1 = _mm512_unpackhi_pd(kernel.packet[0], kernel.packet[1]);
  __m512d T2 = _mm512_unpacklo_pd(kernel.packet[2], kernel.packet[3]);
  __m512d T3 = _mm512_unpackhi_pd(kernel.packet[2], kernel.packet[3]);
  __m512d T4 = _mm512_unpacklo_pd(kernel.packet[4], kernel.packet[5]);
  __m512d T5 = _mm512_unpackhi_pd(kernel.packet[4], kernel.packet[5]);
  __m512d T6 = _mm512_unpacklo_pd(kernel.packet[6], kernel.packet[7]);
  __m512d T7 = _mm512_unpackhi_pd(kernel.packet[6], kernel.packet[7]);
  kernel.packet[0] = _mm512_permutex_pd(T2, 0x4E);
  kernel.packet[0] = _mm512_mask_blend_pd(0xCC, T0, kernel.packet[0]);
  kernel.packet[2] = _mm512_permutex_pd(T0, 0x4E);
  kernel.packet[2] = _mm512_mask_blend_pd(0xCC, kernel.packet[2], T2);
  kernel.packet[1] = _mm512_permutex_pd(T3, 0x4E);
  kernel.packet[1] = _mm512_mask_blend_pd(0xCC, T1, kernel.packet[1]);
  kernel.packet[3] = _mm512_permutex_pd(T1, 0x4E);
  kernel.packet[3] = _mm512_mask_blend_pd(0xCC, kernel.packet[3], T3);
  kernel.packet[4] = _mm512_permutex_pd(T6, 0x4E);
  kernel.packet[4] = _mm512_mask_blend_pd(0xCC, T4, kernel.packet[4]);
  kernel.packet[6] = _mm512_permutex_pd(T4, 0x4E);
  kernel.packet[6] = _mm512_mask_blend_pd(0xCC, kernel.packet[6], T6);
  kernel.packet[5] = _mm512_permutex_pd(T7, 0x4E);
  kernel.packet[5] = _mm512_mask_blend_pd(0xCC, T5, kernel.packet[5]);
  kernel.packet[7] = _mm512_permutex_pd(T5, 0x4E);
  kernel.packet[7] = _mm512_mask_blend_pd(0xCC, kernel.packet[7], T7);
  T0 = _mm512_shuffle_f64x2(kernel.packet[4], kernel.packet[4], 0x4E);
  T0 = _mm512_mask_blend_pd(0xF0, kernel.packet[0], T0);
  T4 = _mm512_shuffle_f64x2(kernel.packet[0], kernel.packet[0], 0x4E);
  T4 = _mm512_mask_blend_pd(0xF0, T4, kernel.packet[4]);
  T1 = _mm512_shuffle_f64x2(kernel.packet[5], kernel.packet[5], 0x4E);
  T1 = _mm512_mask_blend_pd(0xF0, kernel.packet[1], T1);
  T5 = _mm512_shuffle_f64x2(kernel.packet[1], kernel.packet[1], 0x4E);
  T5 = _mm512_mask_blend_pd(0xF0, T5, kernel.packet[5]);
  T2 = _mm512_shuffle_f64x2(kernel.packet[6], kernel.packet[6], 0x4E);
  T2 = _mm512_mask_blend_pd(0xF0, kernel.packet[2], T2);
  T6 = _mm512_shuffle_f64x2(kernel.packet[2], kernel.packet[2], 0x4E);
  T6 = _mm512_mask_blend_pd(0xF0, T6, kernel.packet[6]);
  T3 = _mm512_shuffle_f64x2(kernel.packet[7], kernel.packet[7], 0x4E);
  T3 = _mm512_mask_blend_pd(0xF0, kernel.packet[3], T3);
  T7 = _mm512_shuffle_f64x2(kernel.packet[3], kernel.packet[3], 0x4E);
  T7 = _mm512_mask_blend_pd(0xF0, T7, kernel.packet[7]);
  kernel.packet[0] = T0;
  kernel.packet[1] = T1;
  kernel.packet[2] = T2;
  kernel.packet[3] = T3;
  kernel.packet[4] = T4;
  kernel.packet[5] = T5;
  kernel.packet[6] = T6;
  kernel.packet[7] = T7;
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8l, 4>& kernel) {
  __m512i T0 = _mm512_castpd_si512(
      _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[0]), _mm512_castsi512_pd(kernel.packet[1]), 0));
  __m512i T1 = _mm512_castpd_si512(
      _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[0]), _mm512_castsi512_pd(kernel.packet[1]), 0xff));
  __m512i T2 = _mm512_castpd_si512(
      _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[2]), _mm512_castsi512_pd(kernel.packet[3]), 0));
  __m512i T3 = _mm512_castpd_si512(
      _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[2]), _mm512_castsi512_pd(kernel.packet[3]), 0xff));
  PacketBlock<Packet4l, 8> tmp;
  tmp.packet[0] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 0), _mm512_extracti64x4_epi64(T2, 0), 0x20);
  tmp.packet[1] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 0), _mm512_extracti64x4_epi64(T3, 0), 0x20);
  tmp.packet[2] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 0), _mm512_extracti64x4_epi64(T2, 0), 0x31);
  tmp.packet[3] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 0), _mm512_extracti64x4_epi64(T3, 0), 0x31);
  tmp.packet[4] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 1), _mm512_extracti64x4_epi64(T2, 1), 0x20);
  tmp.packet[5] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 1), _mm512_extracti64x4_epi64(T3, 1), 0x20);
  tmp.packet[6] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 1), _mm512_extracti64x4_epi64(T2, 1), 0x31);
  tmp.packet[7] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 1), _mm512_extracti64x4_epi64(T3, 1), 0x31);
  PACK_OUTPUT_L(kernel.packet, tmp.packet, 0, 1);
  PACK_OUTPUT_L(kernel.packet, tmp.packet, 1, 1);
  PACK_OUTPUT_L(kernel.packet, tmp.packet, 2, 1);
  PACK_OUTPUT_L(kernel.packet, tmp.packet, 3, 1);
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8l, 8>& kernel) {
  __m512i T0 = _mm512_unpacklo_epi64(kernel.packet[0], kernel.packet[1]);
  __m512i T1 = _mm512_unpackhi_epi64(kernel.packet[0], kernel.packet[1]);
  __m512i T2 = _mm512_unpacklo_epi64(kernel.packet[2], kernel.packet[3]);
  __m512i T3 = _mm512_unpackhi_epi64(kernel.packet[2], kernel.packet[3]);
  __m512i T4 = _mm512_unpacklo_epi64(kernel.packet[4], kernel.packet[5]);
  __m512i T5 = _mm512_unpackhi_epi64(kernel.packet[4], kernel.packet[5]);
  __m512i T6 = _mm512_unpacklo_epi64(kernel.packet[6], kernel.packet[7]);
  __m512i T7 = _mm512_unpackhi_epi64(kernel.packet[6], kernel.packet[7]);
  kernel.packet[0] = _mm512_permutex_epi64(T2, 0x4E);
  kernel.packet[0] = _mm512_mask_blend_epi64(0xCC, T0, kernel.packet[0]);
  kernel.packet[2] = _mm512_permutex_epi64(T0, 0x4E);
  kernel.packet[2] = _mm512_mask_blend_epi64(0xCC, kernel.packet[2], T2);
  kernel.packet[1] = _mm512_permutex_epi64(T3, 0x4E);
  kernel.packet[1] = _mm512_mask_blend_epi64(0xCC, T1, kernel.packet[1]);
  kernel.packet[3] = _mm512_permutex_epi64(T1, 0x4E);
  kernel.packet[3] = _mm512_mask_blend_epi64(0xCC, kernel.packet[3], T3);
  kernel.packet[4] = _mm512_permutex_epi64(T6, 0x4E);
  kernel.packet[4] = _mm512_mask_blend_epi64(0xCC, T4, kernel.packet[4]);
  kernel.packet[6] = _mm512_permutex_epi64(T4, 0x4E);
  kernel.packet[6] = _mm512_mask_blend_epi64(0xCC, kernel.packet[6], T6);
  kernel.packet[5] = _mm512_permutex_epi64(T7, 0x4E);
  kernel.packet[5] = _mm512_mask_blend_epi64(0xCC, T5, kernel.packet[5]);
  kernel.packet[7] = _mm512_permutex_epi64(T5, 0x4E);
  kernel.packet[7] = _mm512_mask_blend_epi64(0xCC, kernel.packet[7], T7);
  T0 = _mm512_shuffle_i64x2(kernel.packet[4], kernel.packet[4], 0x4E);
  T0 = _mm512_mask_blend_epi64(0xF0, kernel.packet[0], T0);
  T4 = _mm512_shuffle_i64x2(kernel.packet[0], kernel.packet[0], 0x4E);
  T4 = _mm512_mask_blend_epi64(0xF0, T4, kernel.packet[4]);
  T1 = _mm512_shuffle_i64x2(kernel.packet[5], kernel.packet[5], 0x4E);
  T1 = _mm512_mask_blend_epi64(0xF0, kernel.packet[1], T1);
  T5 = _mm512_shuffle_i64x2(kernel.packet[1], kernel.packet[1], 0x4E);
  T5 = _mm512_mask_blend_epi64(0xF0, T5, kernel.packet[5]);
  T2 = _mm512_shuffle_i64x2(kernel.packet[6], kernel.packet[6], 0x4E);
  T2 = _mm512_mask_blend_epi64(0xF0, kernel.packet[2], T2);
  T6 = _mm512_shuffle_i64x2(kernel.packet[2], kernel.packet[2], 0x4E);
  T6 = _mm512_mask_blend_epi64(0xF0, T6, kernel.packet[6]);
  T3 = _mm512_shuffle_i64x2(kernel.packet[7], kernel.packet[7], 0x4E);
  T3 = _mm512_mask_blend_epi64(0xF0, kernel.packet[3], T3);
  T7 = _mm512_shuffle_i64x2(kernel.packet[3], kernel.packet[3], 0x4E);
  T7 = _mm512_mask_blend_epi64(0xF0, T7, kernel.packet[7]);
  kernel.packet[0] = T0;
  kernel.packet[1] = T1;
  kernel.packet[2] = T2;
  kernel.packet[3] = T3;
  kernel.packet[4] = T4;
  kernel.packet[5] = T5;
  kernel.packet[6] = T6;
  kernel.packet[7] = T7;
}
#define PACK_OUTPUT_I32(OUTPUT, INPUT, INDEX, STRIDE) \
EIGEN_INSERT_8i_INTO_16i(OUTPUT[INDEX], INPUT[INDEX], INPUT[INDEX + STRIDE]);
#define PACK_OUTPUT_I32_2(OUTPUT, INPUT, INDEX, STRIDE) \
EIGEN_INSERT_8i_INTO_16i(OUTPUT[INDEX], INPUT[2 * INDEX], INPUT[2 * INDEX + STRIDE]);
#define SHUFFLE_EPI32(A, B, M) _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(A), _mm512_castsi512_ps(B), M))
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16i, 16>& kernel) {
  __m512i T0 = _mm512_unpacklo_epi32(kernel.packet[0], kernel.packet[1]);
  __m512i T1 = _mm512_unpackhi_epi32(kernel.packet[0], kernel.packet[1]);
  __m512i T2 = _mm512_unpacklo_epi32(kernel.packet[2], kernel.packet[3]);
  __m512i T3 = _mm512_unpackhi_epi32(kernel.packet[2], kernel.packet[3]);
  __m512i T4 = _mm512_unpacklo_epi32(kernel.packet[4], kernel.packet[5]);
  __m512i T5 = _mm512_unpackhi_epi32(kernel.packet[4], kernel.packet[5]);
  __m512i T6 = _mm512_unpacklo_epi32(kernel.packet[6], kernel.packet[7]);
  __m512i T7 = _mm512_unpackhi_epi32(kernel.packet[6], kernel.packet[7]);
  __m512i T8 = _mm512_unpacklo_epi32(kernel.packet[8], kernel.packet[9]);
  __m512i T9 = _mm512_unpackhi_epi32(kernel.packet[8], kernel.packet[9]);
  __m512i T10 = _mm512_unpacklo_epi32(kernel.packet[10], kernel.packet[11]);
  __m512i T11 = _mm512_unpackhi_epi32(kernel.packet[10], kernel.packet[11]);
  __m512i T12 = _mm512_unpacklo_epi32(kernel.packet[12], kernel.packet[13]);
  __m512i T13 = _mm512_unpackhi_epi32(kernel.packet[12], kernel.packet[13]);
  __m512i T14 = _mm512_unpacklo_epi32(kernel.packet[14], kernel.packet[15]);
  __m512i T15 = _mm512_unpackhi_epi32(kernel.packet[14], kernel.packet[15]);
  __m512i S0 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S1 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S2 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S3 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S4 = SHUFFLE_EPI32(T4, T6, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S5 = SHUFFLE_EPI32(T4, T6, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S6 = SHUFFLE_EPI32(T5, T7, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S7 = SHUFFLE_EPI32(T5, T7, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S8 = SHUFFLE_EPI32(T8, T10, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S9 = SHUFFLE_EPI32(T8, T10, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S10 = SHUFFLE_EPI32(T9, T11, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S11 = SHUFFLE_EPI32(T9, T11, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S12 = SHUFFLE_EPI32(T12, T14, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S13 = SHUFFLE_EPI32(T12, T14, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S14 = SHUFFLE_EPI32(T13, T15, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S15 = SHUFFLE_EPI32(T13, T15, _MM_SHUFFLE(3, 2, 3, 2));
  EIGEN_EXTRACT_8i_FROM_16i(S0, S0);
  EIGEN_EXTRACT_8i_FROM_16i(S1, S1);
  EIGEN_EXTRACT_8i_FROM_16i(S2, S2);
  EIGEN_EXTRACT_8i_FROM_16i(S3, S3);
  EIGEN_EXTRACT_8i_FROM_16i(S4, S4);
  EIGEN_EXTRACT_8i_FROM_16i(S5, S5);
  EIGEN_EXTRACT_8i_FROM_16i(S6, S6);
  EIGEN_EXTRACT_8i_FROM_16i(S7, S7);
  EIGEN_EXTRACT_8i_FROM_16i(S8, S8);
  EIGEN_EXTRACT_8i_FROM_16i(S9, S9);
  EIGEN_EXTRACT_8i_FROM_16i(S10, S10);
  EIGEN_EXTRACT_8i_FROM_16i(S11, S11);
  EIGEN_EXTRACT_8i_FROM_16i(S12, S12);
  EIGEN_EXTRACT_8i_FROM_16i(S13, S13);
  EIGEN_EXTRACT_8i_FROM_16i(S14, S14);
  EIGEN_EXTRACT_8i_FROM_16i(S15, S15);
  PacketBlock<Packet8i, 32> tmp;
  tmp.packet[0] = _mm256_permute2f128_si256(S0_0, S4_0, 0x20);
  tmp.packet[1] = _mm256_permute2f128_si256(S1_0, S5_0, 0x20);
  tmp.packet[2] = _mm256_permute2f128_si256(S2_0, S6_0, 0x20);
  tmp.packet[3] = _mm256_permute2f128_si256(S3_0, S7_0, 0x20);
  tmp.packet[4] = _mm256_permute2f128_si256(S0_0, S4_0, 0x31);
  tmp.packet[5] = _mm256_permute2f128_si256(S1_0, S5_0, 0x31);
  tmp.packet[6] = _mm256_permute2f128_si256(S2_0, S6_0, 0x31);
  tmp.packet[7] = _mm256_permute2f128_si256(S3_0, S7_0, 0x31);
  tmp.packet[8] = _mm256_permute2f128_si256(S0_1, S4_1, 0x20);
  tmp.packet[9] = _mm256_permute2f128_si256(S1_1, S5_1, 0x20);
  tmp.packet[10] = _mm256_permute2f128_si256(S2_1, S6_1, 0x20);
  tmp.packet[11] = _mm256_permute2f128_si256(S3_1, S7_1, 0x20);
  tmp.packet[12] = _mm256_permute2f128_si256(S0_1, S4_1, 0x31);
  tmp.packet[13] = _mm256_permute2f128_si256(S1_1, S5_1, 0x31);
  tmp.packet[14] = _mm256_permute2f128_si256(S2_1, S6_1, 0x31);
  tmp.packet[15] = _mm256_permute2f128_si256(S3_1, S7_1, 0x31);
  // Second set of _m256 outputs
  tmp.packet[16] = _mm256_permute2f128_si256(S8_0, S12_0, 0x20);
  tmp.packet[17] = _mm256_permute2f128_si256(S9_0, S13_0, 0x20);
  tmp.packet[18] = _mm256_permute2f128_si256(S10_0, S14_0, 0x20);
  tmp.packet[19] = _mm256_permute2f128_si256(S11_0, S15_0, 0x20);
  tmp.packet[20] = _mm256_permute2f128_si256(S8_0, S12_0, 0x31);
  tmp.packet[21] = _mm256_permute2f128_si256(S9_0, S13_0, 0x31);
  tmp.packet[22] = _mm256_permute2f128_si256(S10_0, S14_0, 0x31);
  tmp.packet[23] = _mm256_permute2f128_si256(S11_0, S15_0, 0x31);
  tmp.packet[24] = _mm256_permute2f128_si256(S8_1, S12_1, 0x20);
  tmp.packet[25] = _mm256_permute2f128_si256(S9_1, S13_1, 0x20);
  tmp.packet[26] = _mm256_permute2f128_si256(S10_1, S14_1, 0x20);
  tmp.packet[27] = _mm256_permute2f128_si256(S11_1, S15_1, 0x20);
  tmp.packet[28] = _mm256_permute2f128_si256(S8_1, S12_1, 0x31);
  tmp.packet[29] = _mm256_permute2f128_si256(S9_1, S13_1, 0x31);
  tmp.packet[30] = _mm256_permute2f128_si256(S10_1, S14_1, 0x31);
  tmp.packet[31] = _mm256_permute2f128_si256(S11_1, S15_1, 0x31);
  // Pack them into the output
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 0, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 1, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 2, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 3, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 4, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 5, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 6, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 7, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 8, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 9, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 10, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 11, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 12, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 13, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 14, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 15, 16);
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16i, 4>& kernel) {
  __m512i T0 = _mm512_unpacklo_epi32(kernel.packet[0], kernel.packet[1]);
  __m512i T1 = _mm512_unpackhi_epi32(kernel.packet[0], kernel.packet[1]);
  __m512i T2 = _mm512_unpacklo_epi32(kernel.packet[2], kernel.packet[3]);
  __m512i T3 = _mm512_unpackhi_epi32(kernel.packet[2], kernel.packet[3]);
  __m512i S0 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S1 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S2 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S3 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));
  EIGEN_EXTRACT_8i_FROM_16i(S0, S0);
  EIGEN_EXTRACT_8i_FROM_16i(S1, S1);
  EIGEN_EXTRACT_8i_FROM_16i(S2, S2);
  EIGEN_EXTRACT_8i_FROM_16i(S3, S3);
  PacketBlock<Packet8i, 8> tmp;
  tmp.packet[0] = _mm256_permute2f128_si256(S0_0, S1_0, 0x20);
  tmp.packet[1] = _mm256_permute2f128_si256(S2_0, S3_0, 0x20);
  tmp.packet[2] = _mm256_permute2f128_si256(S0_0, S1_0, 0x31);
  tmp.packet[3] = _mm256_permute2f128_si256(S2_0, S3_0, 0x31);
  tmp.packet[4] = _mm256_permute2f128_si256(S0_1, S1_1, 0x20);
  tmp.packet[5] = _mm256_permute2f128_si256(S2_1, S3_1, 0x20);
  tmp.packet[6] = _mm256_permute2f128_si256(S0_1, S1_1, 0x31);
  tmp.packet[7] = _mm256_permute2f128_si256(S2_1, S3_1, 0x31);
  PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 0, 1);
  PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 1, 1);
  PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 2, 1);
  PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 3, 1);
}
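// pblend selects lane-wise between two packets. avx512_blend_mask converts a Selector (an
// array of 0/1 flags) into the bitmask consumed by the masked blend intrinsics below.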
template <size_t N>
EIGEN_STRONG_INLINE int avx512_blend_mask(const Selector<N>& ifPacket) {
  alignas(__m128i) uint8_t aux[sizeof(__m128i)];
  for (size_t i = 0; i < N; i++) aux[i] = static_cast<uint8_t>(ifPacket.select[i]);
  __m128i paux = _mm_sub_epi8(_mm_setzero_si128(), _mm_load_si128(reinterpret_cast<const __m128i*>(aux)));
  return _mm_movemask_epi8(paux);
}
template <>
EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& ifPacket, const Packet16f& thenPacket,
                                     const Packet16f& elsePacket) {
  __mmask16 m = avx512_blend_mask(ifPacket);
  return _mm512_mask_blend_ps(m, elsePacket, thenPacket);
}
template <>
EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, const Packet8d& thenPacket,
                                    const Packet8d& elsePacket) {
  __mmask8 m = avx512_blend_mask(ifPacket);
  return _mm512_mask_blend_pd(m, elsePacket, thenPacket);
}
// Packet math for Eigen::half
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
  return _mm256_set1_epi16(from.x);
}
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet16h>(const Packet16h& from) {
  return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm256_extract_epi16(from, 0)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pload<Packet16h>(const Eigen::half* from) {
  return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
}
template <>
EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from) {
  return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
}
template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
  // (void*) -> workaround clang warning:
  // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
  EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256((__m256i*)(void*)to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
  // (void*) -> workaround clang warning:
  // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
  EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256((__m256i*)(void*)to, from);
}
template <>
EIGEN_STRONG_INLINE Packet16h ploaddup<Packet16h>(const Eigen::half* from) {
  unsigned short a = from[0].x;
  unsigned short b = from[1].x;
  unsigned short c = from[2].x;
  unsigned short d = from[3].x;
  unsigned short e = from[4].x;
  unsigned short f = from[5].x;
  unsigned short g = from[6].x;
  unsigned short h = from[7].x;
  return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
}
template <>
EIGEN_STRONG_INLINE Packet16h ploadquad(const Eigen::half* from) {
  unsigned short a = from[0].x;
  unsigned short b = from[1].x;
  unsigned short c = from[2].x;
  unsigned short d = from[3].x;
  return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
}
EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) { return _mm512_cvtph_ps(a); }
EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) {
  return _mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
}
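// Without native FP16 arithmetic, the Packet16h operations that follow widen to Packet16f
// with half2float (vcvtph2ps), compute in single precision, and narrow back with
// float2half (vcvtps2ph, round-to-nearest); bitwise ops and comparisons reuse the integer
// and float packet implementations directly.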
template <>
EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) {
  return Packet16h(ptrue(Packet8i(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) {
  const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
  return _mm256_andnot_si256(sign_mask, a);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return float2half(pmin<Packet16f>(half2float(a), half2float(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return float2half(pmax<Packet16f>(half2float(a), half2float(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) {
  return float2half(plset<Packet16f>(static_cast<float>(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a, const Packet16h& b) {
  // in some cases Packet8i is a wrapper around __m256i, so we need to
  // cast to Packet8i to call the correct overload.
  return Packet16h(por(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a, const Packet16h& b) {
  return Packet16h(pxor(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a, const Packet16h& b) {
  return Packet16h(pand(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a, const Packet16h& b) {
  return Packet16h(pandnot(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) {
  return _mm256_blendv_epi8(b, a, mask);
}
template <>
EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) {
  return float2half(pround<Packet16f>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) {
  return float2half(print<Packet16f>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) {
  return float2half(pceil<Packet16f>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) {
  return float2half(pfloor<Packet16f>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h ptrunc<Packet16h>(const Packet16h& a) {
  return float2half(ptrunc<Packet16f>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a, const Packet16h& b) {
  Packet16f af = half2float(a);
  Packet16f bf = half2float(b);
  return Pack32To16(pcmp_eq(af, bf));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a, const Packet16h& b) {
  return Pack32To16(pcmp_le(half2float(a), half2float(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a, const Packet16h& b) {
  return Pack32To16(pcmp_lt(half2float(a), half2float(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a, const Packet16h& b) {
  return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) {
  return a;
}
template <>
EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
  Packet16h sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
  return _mm256_xor_si256(a, sign_mask);
}
template <>
EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
  Packet16f af = half2float(a);
  Packet16f bf = half2float(b);
  Packet16f rf = padd(af, bf);
  return float2half(rf);
}
template <>
EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
  Packet16f af = half2float(a);
  Packet16f bf = half2float(b);
  Packet16f rf = psub(af, bf);
  return float2half(rf);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
  Packet16f af = half2float(a);
  Packet16f bf = half2float(b);
  Packet16f rf = pmul(af, bf);
  return float2half(rf);
}
template <>
EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
  Packet16f af = half2float(a);
  Packet16f bf = half2float(b);
  Packet16f rf = pdiv(af, bf);
  return float2half(rf);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmadd<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return float2half(pmadd(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pmsub<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return float2half(pmsub(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pnmadd<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return float2half(pnmadd(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pnmsub<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return float2half(pnmsub(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
  Packet16f from_float = half2float(from);
  return half(predux(from_float));
}
template <>
EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
  Packet8h lane0 = _mm256_extractf128_si256(a, 0);
  Packet8h lane1 = _mm256_extractf128_si256(a, 1);
  return padd<Packet8h>(lane0, lane1);
}
template <>
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet16h>(const Packet16h& a) {
  Packet16f af = half2float(a);
  float reduced = predux_max<Packet16f>(af);
  return Eigen::half(reduced);
}
template <>
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet16h>(const Packet16h& a) {
  Packet16f af = half2float(a);
  float reduced = predux_min<Packet16f>(af);
  return Eigen::half(reduced);
}
template <>
EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& from) {
  Packet16f from_float = half2float(from);
  return half(predux_mul(from_float));
}
template <>
EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a) {
  __m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
  return _mm256_insertf128_si256(_mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(a, 1), m)),
                                 _mm_shuffle_epi8(_mm256_extractf128_si256(a, 0), m), 1);
}
template <>
EIGEN_STRONG_INLINE Packet16h pgather<Eigen::half, Packet16h>(const Eigen::half* from, Index stride) {
  return _mm256_set_epi16(from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
                          from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x,
                          from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
                          from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
}
template <>
EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Packet16h& from, Index stride) {
  EIGEN_ALIGN64 half aux[16];
  pstore(aux, from);
  to[stride * 0] = aux[0];
  to[stride * 1] = aux[1];
  to[stride * 2] = aux[2];
  to[stride * 3] = aux[3];
  to[stride * 4] = aux[4];
  to[stride * 5] = aux[5];
  to[stride * 6] = aux[6];
  to[stride * 7] = aux[7];
  to[stride * 8] = aux[8];
  to[stride * 9] = aux[9];
  to[stride * 10] = aux[10];
  to[stride * 11] = aux[11];
  to[stride * 12] = aux[12];
  to[stride * 13] = aux[13];
  to[stride * 14] = aux[14];
  to[stride * 15] = aux[15];
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 16>& kernel) {
  __m256i a = kernel.packet[0];
  __m256i b = kernel.packet[1];
  __m256i c = kernel.packet[2];
  __m256i d = kernel.packet[3];
  __m256i e = kernel.packet[4];
  __m256i f = kernel.packet[5];
  __m256i g = kernel.packet[6];
  __m256i h = kernel.packet[7];
  __m256i i = kernel.packet[8];
  __m256i j = kernel.packet[9];
  __m256i k = kernel.packet[10];
  __m256i l = kernel.packet[11];
  __m256i m = kernel.packet[12];
  __m256i n = kernel.packet[13];
  __m256i o = kernel.packet[14];
  __m256i p = kernel.packet[15];

  __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
  __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
  __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
  __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
  __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
  __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
  __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
  __m256i op_07 = _mm256_unpacklo_epi16(o, p);

  __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
  __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
  __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
  __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
  __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
  __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
  __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
  __m256i op_8f = _mm256_unpackhi_epi16(o, p);

  __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
  __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
  __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
  __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
  __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
  __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
  __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
  __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);

  __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
  __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
  __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
  __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
  __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
  __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
  __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
  __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);

  __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
  __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
  __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
  __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
  __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
  __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
  __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
  __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
  __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
  __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
  __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
  __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
  __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
  __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
  __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
  __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);

  // NOTE: no unpacklo/hi instr in this case, so using permute instr.
  __m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
  __m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
  __m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
  __m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
  __m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
  __m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
  __m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
  __m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
  __m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
  __m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
  __m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
  __m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
  __m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
  __m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
  __m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
  __m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);

  kernel.packet[0] = a_p_0;
  kernel.packet[1] = a_p_1;
  kernel.packet[2] = a_p_2;
  kernel.packet[3] = a_p_3;
  kernel.packet[4] = a_p_4;
  kernel.packet[5] = a_p_5;
  kernel.packet[6] = a_p_6;
  kernel.packet[7] = a_p_7;
  kernel.packet[8] = a_p_8;
  kernel.packet[9] = a_p_9;
  kernel.packet[10] = a_p_a;
  kernel.packet[11] = a_p_b;
  kernel.packet[12] = a_p_c;
  kernel.packet[13] = a_p_d;
  kernel.packet[14] = a_p_e;
  kernel.packet[15] = a_p_f;
}

EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 8>& kernel) {
  EIGEN_ALIGN64 half in[8][16];
  pstore<half>(in[0], kernel.packet[0]);
  pstore<half>(in[1], kernel.packet[1]);
  pstore<half>(in[2], kernel.packet[2]);
  pstore<half>(in[3], kernel.packet[3]);
  pstore<half>(in[4], kernel.packet[4]);
  pstore<half>(in[5], kernel.packet[5]);
  pstore<half>(in[6], kernel.packet[6]);
  pstore<half>(in[7], kernel.packet[7]);

  EIGEN_ALIGN64 half out[8][16];

  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      out[i][j] = in[j][2 * i];
    }
    for (int j = 0; j < 8; ++j) {
      out[i][j + 8] = in[j][2 * i + 1];
    }
  }

  kernel.packet[0] = pload<Packet16h>(out[0]);
  kernel.packet[1] = pload<Packet16h>(out[1]);
  kernel.packet[2] = pload<Packet16h>(out[2]);
  kernel.packet[3] = pload<Packet16h>(out[3]);
  kernel.packet[4] = pload<Packet16h>(out[4]);
  kernel.packet[5] = pload<Packet16h>(out[5]);
  kernel.packet[6] = pload<Packet16h>(out[6]);
  kernel.packet[7] = pload<Packet16h>(out[7]);
}

EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 4>& kernel) {
  EIGEN_ALIGN64 half in[4][16];
  pstore<half>(in[0], kernel.packet[0]);
  pstore<half>(in[1], kernel.packet[1]);
  pstore<half>(in[2], kernel.packet[2]);
  pstore<half>(in[3], kernel.packet[3]);

  EIGEN_ALIGN64 half out[4][16];

  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) {
      out[i][j] = in[j][4 * i];
    }
    for (int j = 0; j < 4; ++j) {
      out[i][j + 4] = in[j][4 * i + 1];
    }
    for (int j = 0; j < 4; ++j) {
      out[i][j + 8] = in[j][4 * i + 2];
    }
    for (int j = 0; j < 4; ++j) {
      out[i][j + 12] = in[j][4 * i + 3];
    }
  }

  kernel.packet[0] = pload<Packet16h>(out[0]);
  kernel.packet[1] = pload<Packet16h>(out[1]);
  kernel.packet[2] = pload<Packet16h>(out[2]);
  kernel.packet[3] = pload<Packet16h>(out[3]);
}

#endif  // EIGEN_VECTORIZE_AVX512FP16
template <>
struct is_arithmetic<Packet16bf> {
  enum { value = true };
};

template <>
struct packet_traits<bfloat16> : default_packet_traits {
  typedef Packet16bf type;
  typedef Packet8bf half;
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = 16,
    HasBlend = 0,
    HasInsert = 1,
    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasSqrt = 1,
    HasRsqrt = 1,
#ifdef EIGEN_VECTORIZE_AVX512DQ
    HasLog = 1,  // Currently fails test with bad accuracy.
    HasLog1p = 1,
    HasExpm1 = 1,
    HasNdtri = 1,
    HasBessel = 1,
#endif
    HasExp = 1,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = EIGEN_FAST_MATH,
    HasCmp = 1,
    HasDiv = 1
  };
};

template <>
struct unpacket_traits<Packet16bf> {
  typedef bfloat16 type;
  enum {
    size = 16,
    alignment = Aligned32,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
  typedef Packet8bf half;
};

template <>
EIGEN_STRONG_INLINE Packet16bf pset1<Packet16bf>(const bfloat16& from) {
  return _mm256_set1_epi16(from.value);
}

template <>
EIGEN_STRONG_INLINE bfloat16 pfirst<Packet16bf>(const Packet16bf& from) {
  bfloat16 t;
  t.value = static_cast<unsigned short>(_mm256_extract_epi16(from, 0));
  return t;
}

template <>
EIGEN_STRONG_INLINE Packet16bf pload<Packet16bf>(const bfloat16* from) {
  return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
}

template <>
EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {
  return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
}

template <>
EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet16bf& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to), from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet16bf& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
}

template <>
EIGEN_STRONG_INLINE Packet16bf ploaddup<Packet16bf>(const bfloat16* from) {
  unsigned short a = from[0].value;
  unsigned short b = from[1].value;
  unsigned short c = from[2].value;
  unsigned short d = from[3].value;
  unsigned short e = from[4].value;
  unsigned short f = from[5].value;
  unsigned short g = from[6].value;
  unsigned short h = from[7].value;
  return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
}

template <>
EIGEN_STRONG_INLINE Packet16bf ploadquad(const bfloat16* from) {
  unsigned short a = from[0].value;
  unsigned short b = from[1].value;
  unsigned short c = from[2].value;
  unsigned short d = from[3].value;
  return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
}
EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
  return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
}

// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
  Packet16bf r;

#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_STRICT_AT_LEAST(10, 1, 0)
  // Since GCC 10.1 supports avx512bf16 and C style explicit cast
  // (C++ static_cast is not supported yet), do conversion via intrinsic
  // and register path for performance.
  r = (__m256i)(_mm512_cvtneps_pbh(a));
#else
  __m512i t;
  __m512i input = _mm512_castps_si512(a);
  __m512i nan = _mm512_set1_epi32(0x7fc0);

  // uint32_t lsb = (input >> 16) & 1;
  t = _mm512_and_si512(_mm512_srli_epi32(input, 16), _mm512_set1_epi32(1));
  // uint32_t rounding_bias = 0x7fff + lsb;
  t = _mm512_add_epi32(t, _mm512_set1_epi32(0x7fff));
  // input += rounding_bias;
  t = _mm512_add_epi32(t, input);
  // input = input >> 16;
  t = _mm512_srli_epi32(t, 16);

  // Check NaN before converting back to bf16
  __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
  t = _mm512_mask_blend_epi32(mask, nan, t);

  // output.value = static_cast<uint16_t>(input);
  r = _mm512_cvtepi32_epi16(t);
#endif  // EIGEN_VECTORIZE_AVX512BF16

  return r;
}
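// For reference, the scalar equivalent of the round-to-nearest-even path above is roughly
// the following (illustrative sketch only; `float_to_bits` stands for whatever bit-cast
// helper is available and is not part of this file):
//
//   uint32_t input = float_to_bits(f);
//   uint32_t lsb = (input >> 16) & 1;        // lowest bit that survives the truncation
//   uint32_t rounding_bias = 0x7fff + lsb;   // biases ties toward the even mantissa
//   input += rounding_bias;
//   uint16_t out = (f != f) ? 0x7fc0 : static_cast<uint16_t>(input >> 16);  // quiet-NaN fallback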
template <>
EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) {
  return Packet16bf(ptrue<Packet8i>(Packet8i(a)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) {
  return Packet16bf(por<Packet8i>(Packet8i(a), Packet8i(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
  return Packet16bf(pxor<Packet8i>(Packet8i(a), Packet8i(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
  return Packet16bf(pand<Packet8i>(Packet8i(a), Packet8i(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a, const Packet16bf& b) {
  return Packet16bf(pandnot<Packet8i>(Packet8i(a), Packet8i(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask, const Packet16bf& a, const Packet16bf& b) {
  // Input mask is expected to be all 0/1, handle it with 8-bit
  // intrinsic for performance.
  return _mm256_blendv_epi8(b, a, mask);
}
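// Usage note: as with the other pselect specializations, each 16-bit lane of the result is
// taken from `a` where the corresponding lane of `mask` is all-ones and from `b` where it is
// all-zeros, i.e. conceptually result[i] = mask[i] ? a[i] : b[i]. _mm256_blendv_epi8 selects
// per byte on the byte's sign bit, which is sufficient here precisely because the mask lanes
// are required to be 0x0000 or 0xffff.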
template <>
EIGEN_STRONG_INLINE Packet16bf pround<Packet16bf>(const Packet16bf& a) {
  return F32ToBf16(pround<Packet16f>(Bf16ToF32(a)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf print<Packet16bf>(const Packet16bf& a) {
  return F32ToBf16(print<Packet16f>(Bf16ToF32(a)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pceil<Packet16bf>(const Packet16bf& a) {
  return F32ToBf16(pceil<Packet16f>(Bf16ToF32(a)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pfloor<Packet16bf>(const Packet16bf& a) {
  return F32ToBf16(pfloor<Packet16f>(Bf16ToF32(a)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf ptrunc<Packet16bf>(const Packet16bf& a) {
  return F32ToBf16(ptrunc<Packet16f>(Bf16ToF32(a)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a, const Packet16bf& b) {
  return Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a, const Packet16bf& b) {
  return Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a, const Packet16bf& b) {
  return Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a, const Packet16bf& b) {
  return Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
  Packet16bf sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
  return _mm256_xor_si256(a, sign_mask);
}

template <>
EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) {
  return a;
}

template <>
EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) {
  const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
  return _mm256_andnot_si256(sign_mask, a);
}

template <>
EIGEN_STRONG_INLINE Packet16bf padd<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
  return F32ToBf16(padd<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf psub<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
  return F32ToBf16(psub<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pmul<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
  return F32ToBf16(pmul(Bf16ToF32(a), Bf16ToF32(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pmadd<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
  return F32ToBf16(pmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pmsub<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
  return F32ToBf16(pmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pnmadd<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
  return F32ToBf16(pnmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pnmsub<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
  return F32ToBf16(pnmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pdiv<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
  return F32ToBf16(pdiv<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pmin<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
  return F32ToBf16(pmin<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
  return F32ToBf16(pmax<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf plset<Packet16bf>(const bfloat16& a) {
  return F32ToBf16(plset<Packet16f>(static_cast<float>(a)));
}

template <>
EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
  Packet8bf lane0 = _mm256_extractf128_si256(a, 0);
  Packet8bf lane1 = _mm256_extractf128_si256(a, 1);
  return padd<Packet8bf>(lane0, lane1);
}

template <>
EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
  return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
}

template <>
EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet16bf>(const Packet16bf& from) {
  return static_cast<bfloat16>(predux_mul<Packet16f>(Bf16ToF32(from)));
}

template <>
EIGEN_STRONG_INLINE bfloat16 predux_min<Packet16bf>(const Packet16bf& from) {
  return static_cast<bfloat16>(predux_min<Packet16f>(Bf16ToF32(from)));
}

template <>
EIGEN_STRONG_INLINE bfloat16 predux_max<Packet16bf>(const Packet16bf& from) {
  return static_cast<bfloat16>(predux_max<Packet16f>(Bf16ToF32(from)));
}

template <>
EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
  __m256i m = _mm256_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1, 14, 15, 12, 13, 10, 11, 8, 9, 6,
                               7, 4, 5, 2, 3, 0, 1);

  Packet16bf res;
  // Swap hi and lo first because shuffle is in 128-bit lanes.
  res = _mm256_permute2x128_si256(a, a, 1);
  // Shuffle 8-bit values in src within 2*128-bit lanes.
  return _mm256_shuffle_epi8(res, m);
}
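// In other words, element k of the result is element 15 - k of the input: the lane swap
// reverses the order of the two 8-element halves, and the byte shuffle (which keeps each
// 16-bit value's byte pair together) reverses the element order within each 128-bit lane.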
template <>
EIGEN_STRONG_INLINE Packet16bf pgather<bfloat16, Packet16bf>(const bfloat16* from, Index stride) {
  return _mm256_set_epi16(from[15 * stride].value, from[14 * stride].value, from[13 * stride].value,
                          from[12 * stride].value, from[11 * stride].value, from[10 * stride].value,
                          from[9 * stride].value, from[8 * stride].value, from[7 * stride].value,
                          from[6 * stride].value, from[5 * stride].value, from[4 * stride].value,
                          from[3 * stride].value, from[2 * stride].value, from[1 * stride].value,
                          from[0 * stride].value);
}

template <>
EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to, const Packet16bf& from, Index stride) {
  EIGEN_ALIGN64 bfloat16 aux[16];
  pstore(aux, from);
  to[stride * 0] = aux[0];
  to[stride * 1] = aux[1];
  to[stride * 2] = aux[2];
  to[stride * 3] = aux[3];
  to[stride * 4] = aux[4];
  to[stride * 5] = aux[5];
  to[stride * 6] = aux[6];
  to[stride * 7] = aux[7];
  to[stride * 8] = aux[8];
  to[stride * 9] = aux[9];
  to[stride * 10] = aux[10];
  to[stride * 11] = aux[11];
  to[stride * 12] = aux[12];
  to[stride * 13] = aux[13];
  to[stride * 14] = aux[14];
  to[stride * 15] = aux[15];
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf, 16>& kernel) {
  __m256i a = kernel.packet[0];
  __m256i b = kernel.packet[1];
  __m256i c = kernel.packet[2];
  __m256i d = kernel.packet[3];
  __m256i e = kernel.packet[4];
  __m256i f = kernel.packet[5];
  __m256i g = kernel.packet[6];
  __m256i h = kernel.packet[7];
  __m256i i = kernel.packet[8];
  __m256i j = kernel.packet[9];
  __m256i k = kernel.packet[10];
  __m256i l = kernel.packet[11];
  __m256i m = kernel.packet[12];
  __m256i n = kernel.packet[13];
  __m256i o = kernel.packet[14];
  __m256i p = kernel.packet[15];

  __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
  __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
  __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
  __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
  __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
  __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
  __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
  __m256i op_07 = _mm256_unpacklo_epi16(o, p);

  __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
  __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
  __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
  __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
  __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
  __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
  __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
  __m256i op_8f = _mm256_unpackhi_epi16(o, p);

  __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
  __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
  __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
  __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
  __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
  __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
  __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
  __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);

  __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
  __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
  __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
  __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
  __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
  __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
  __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
  __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);

  __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
  __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
  __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
  __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
  __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
  __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
  __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
  __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
  __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
  __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
  __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
  __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
  __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
  __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
  __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
  __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);

  // NOTE: no unpacklo/hi instr in this case, so using permute instr.
  kernel.packet[0] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
  kernel.packet[1] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
  kernel.packet[2] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
  kernel.packet[3] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
  kernel.packet[4] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
  kernel.packet[5] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
  kernel.packet[6] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
  kernel.packet[7] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
  kernel.packet[8] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
  kernel.packet[9] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
  kernel.packet[10] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
  kernel.packet[11] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
  kernel.packet[12] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
  kernel.packet[13] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
  kernel.packet[14] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
  kernel.packet[15] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
}

EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf, 4>& kernel) {
  __m256i a = kernel.packet[0];
  __m256i b = kernel.packet[1];
  __m256i c = kernel.packet[2];
  __m256i d = kernel.packet[3];

  __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
  __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
  __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
  __m256i cd_8f = _mm256_unpackhi_epi16(c, d);

  __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
  __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
  __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
  __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);

  // NOTE: no unpacklo/hi instr in this case, so using permute instr.
  kernel.packet[0] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
  kernel.packet[1] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
  kernel.packet[2] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
  kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
}
// Minimal implementation of 16-bit int packets for use in pfrexp, pldexp.

template <>
EIGEN_STRONG_INLINE Packet32s pset1<Packet32s>(const numext::int16_t& x) {
  return _mm512_set1_epi16(x);
}
template <>
EIGEN_STRONG_INLINE Packet16s pset1<Packet16s>(const numext::int16_t& x) {
  return _mm256_set1_epi16(x);
}
template <>
EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const numext::int16_t& x) {
  return _mm_set1_epi16(x);
}

template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_epi32(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
  EIGEN_DEBUG_ALIGNED_STORE
#if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
  _mm256_store_epi32(out, x);
#else
  _mm256_store_si256(reinterpret_cast<__m256i*>(out), x);
#endif
}
template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
  EIGEN_DEBUG_ALIGNED_STORE
#if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
  _mm256_store_epi32(out, x);
#else
  _mm_store_si128(reinterpret_cast<__m128i*>(out), x);
#endif
}

template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_epi32(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_epi32(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_epi32(out, x);
}

template <>
EIGEN_STRONG_INLINE Packet32s padd(const Packet32s& a, const Packet32s& b) {
  return _mm512_add_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16s padd(const Packet16s& a, const Packet16s& b) {
  return _mm256_add_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8s padd(const Packet8s& a, const Packet8s& b) {
  return _mm_add_epi16(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet32s psub(const Packet32s& a, const Packet32s& b) {
  return _mm512_sub_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16s psub(const Packet16s& a, const Packet16s& b) {
  return _mm256_sub_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8s psub(const Packet8s& a, const Packet8s& b) {
  return _mm_sub_epi16(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet32s pmul(const Packet32s& a, const Packet32s& b) {
  return _mm512_mullo_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16s pmul(const Packet16s& a, const Packet16s& b) {
  return _mm256_mullo_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8s pmul(const Packet8s& a, const Packet8s& b) {
  return _mm_mullo_epi16(a, b);
}

template <>
EIGEN_STRONG_INLINE Packet32s pnegate(const Packet32s& a) {
  return _mm512_sub_epi16(_mm512_setzero_si512(), a);
}
template <>
EIGEN_STRONG_INLINE Packet16s pnegate(const Packet16s& a) {
  return _mm256_sub_epi16(_mm256_setzero_si256(), a);
}
template <>
EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a) {
  return _mm_sub_epi16(_mm_setzero_si128(), a);
}

template <int N>
EIGEN_STRONG_INLINE Packet32s parithmetic_shift_right(Packet32s a) {
  return _mm512_srai_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16s parithmetic_shift_right(Packet16s a) {
  return _mm256_srai_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8s parithmetic_shift_right(Packet8s a) {
  return _mm_srai_epi16(a, N);
}

template <int N>
EIGEN_STRONG_INLINE Packet32s plogical_shift_left(Packet32s a) {
  return _mm512_slli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16s plogical_shift_left(Packet16s a) {
  return _mm256_slli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8s plogical_shift_left(Packet8s a) {
  return _mm_slli_epi16(a, N);
}

template <int N>
EIGEN_STRONG_INLINE Packet32s plogical_shift_right(Packet32s a) {
  return _mm512_srli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16s plogical_shift_right(Packet16s a) {
  return _mm256_srli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8s plogical_shift_right(Packet8s a) {
  return _mm_srli_epi16(a, N);
}

}  // end namespace internal
}  // end namespace Eigen

#endif  // EIGEN_PACKET_MATH_AVX512_H
eigen-master/Eigen/src/Core/arch/AVX512/PacketMathFP16.h 0 → 100644
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 The Eigen Authors.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_PACKET_MATH_FP16_AVX512_H
#define EIGEN_PACKET_MATH_FP16_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {

typedef __m512h Packet32h;
typedef __m256h Packet16h;
typedef __m128h Packet8h;

template <>
struct is_arithmetic<Packet8h> {
  enum { value = true };
};

template <>
struct packet_traits<half> : default_packet_traits {
  typedef Packet32h type;
  typedef Packet16h half;
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = 32,

    HasCmp = 1,
    HasAdd = 1,
    HasSub = 1,
    HasMul = 1,
    HasDiv = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasAbs2 = 0,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasLog = 1,
    HasLog1p = 1,
    HasExp = 1,
    HasExpm1 = 1,
    HasSqrt = 1,
    HasRsqrt = 1,
    // These ones should be implemented in future
    HasBessel = 0,
    HasNdtri = 0,
    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = 0,  // EIGEN_FAST_MATH,
    HasBlend = 0
  };
};

template <>
struct unpacket_traits<Packet32h> {
  typedef Eigen::half type;
  typedef Packet16h half;
  typedef Packet32s integer_packet;
  enum {
    size = 32,
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

template <>
struct unpacket_traits<Packet16h> {
  typedef Eigen::half type;
  typedef Packet8h half;
  typedef Packet16s integer_packet;
  enum {
    size = 16,
    alignment = Aligned32,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

template <>
struct unpacket_traits<Packet8h> {
  typedef Eigen::half type;
  typedef Packet8h half;
  typedef Packet8s integer_packet;
  enum {
    size = 8,
    alignment = Aligned16,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

// Conversions

EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) { return _mm512_cvtxph_ps(a); }

EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) { return _mm256_cvtxph_ps(a); }

EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) { return _mm512_cvtxps_ph(a); }

EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) { return _mm256_cvtxps_ph(a); }
// Memory functions

// pset1
template <>
EIGEN_STRONG_INLINE Packet32h pset1<Packet32h>(const Eigen::half& from) {
  return _mm512_set1_ph(from.x);
}
template <>
EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
  return _mm256_set1_ph(from.x);
}
template <>
EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
  return _mm_set1_ph(from.x);
}

template <>
EIGEN_STRONG_INLINE Packet32h pzero(const Packet32h& /*a*/) {
  return _mm512_setzero_ph();
}
template <>
EIGEN_STRONG_INLINE Packet16h pzero(const Packet16h& /*a*/) {
  return _mm256_setzero_ph();
}
template <>
EIGEN_STRONG_INLINE Packet8h pzero(const Packet8h& /*a*/) {
  return _mm_setzero_ph();
}

// pset1frombits
template <>
EIGEN_STRONG_INLINE Packet32h pset1frombits<Packet32h>(unsigned short from) {
  return _mm512_castsi512_ph(_mm512_set1_epi16(from));
}
template <>
EIGEN_STRONG_INLINE Packet16h pset1frombits<Packet16h>(unsigned short from) {
  return _mm256_castsi256_ph(_mm256_set1_epi16(from));
}
template <>
EIGEN_STRONG_INLINE Packet8h pset1frombits<Packet8h>(unsigned short from) {
  return _mm_castsi128_ph(_mm_set1_epi16(from));
}

// pfirst
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet32h>(const Packet32h& from) {
  return Eigen::half(_mm512_cvtsh_h(from));
}
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet16h>(const Packet16h& from) {
  return Eigen::half(_mm256_cvtsh_h(from));
}
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) {
  return Eigen::half(_mm_cvtsh_h(from));
}

// pload
template <>
EIGEN_STRONG_INLINE Packet32h pload<Packet32h>(const Eigen::half* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ph(from);
}
template <>
EIGEN_STRONG_INLINE Packet16h pload<Packet16h>(const Eigen::half* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_ph(from);
}
template <>
EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) {
  EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_ph(from);
}

// ploadu
template <>
EIGEN_STRONG_INLINE Packet32h ploadu<Packet32h>(const Eigen::half* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ph(from);
}
template <>
EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_ph(from);
}
template <>
EIGEN_STRONG_INLINE Packet8h ploadu<Packet8h>(const Eigen::half* from) {
  EIGEN_DEBUG_UNALIGNED_LOAD return _mm_loadu_ph(from);
}

// pstore
template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet32h& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ph(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm256_store_ph(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet8h& from) {
  EIGEN_DEBUG_ALIGNED_STORE _mm_store_ph(to, from);
}

// pstoreu
template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet32h& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ph(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_ph(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet8h& from) {
  EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_ph(to, from);
}

// ploaddup
template <>
EIGEN_STRONG_INLINE Packet32h ploaddup<Packet32h>(const Eigen::half* from) {
  __m512h a = _mm512_castph256_ph512(_mm256_loadu_ph(from));
  return _mm512_permutexvar_ph(_mm512_set_epi16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6,
                                                6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0),
                               a);
}
template <>
EIGEN_STRONG_INLINE Packet16h ploaddup<Packet16h>(const Eigen::half* from) {
  __m256h a = _mm256_castph128_ph256(_mm_loadu_ph(from));
  return _mm256_permutexvar_ph(_mm256_set_epi16(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0), a);
}
template <>
EIGEN_STRONG_INLINE Packet8h ploaddup<Packet8h>(const Eigen::half* from) {
  return _mm_set_ph(from[3].x, from[3].x, from[2].x, from[2].x, from[1].x, from[1].x, from[0].x, from[0].x);
}

// ploadquad
template <>
EIGEN_STRONG_INLINE Packet32h ploadquad<Packet32h>(const Eigen::half* from) {
  __m512h a = _mm512_castph128_ph512(_mm_loadu_ph(from));
  return _mm512_permutexvar_ph(_mm512_set_epi16(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2,
                                                2, 1, 1, 1, 1, 0, 0, 0, 0),
                               a);
}
template <>
EIGEN_STRONG_INLINE Packet16h ploadquad<Packet16h>(const Eigen::half* from) {
  return _mm256_set_ph(from[3].x, from[3].x, from[3].x, from[3].x, from[2].x, from[2].x, from[2].x, from[2].x,
                       from[1].x, from[1].x, from[1].x, from[1].x, from[0].x, from[0].x, from[0].x, from[0].x);
}
template <>
EIGEN_STRONG_INLINE Packet8h ploadquad<Packet8h>(const Eigen::half* from) {
  return _mm_set_ph(from[1].x, from[1].x, from[1].x, from[1].x, from[0].x, from[0].x, from[0].x, from[0].x);
}
// pabs
template <>
EIGEN_STRONG_INLINE Packet32h pabs<Packet32h>(const Packet32h& a) {
  return _mm512_abs_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet16h pabs<Packet16h>(const Packet16h& a) {
  return _mm256_abs_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet8h pabs<Packet8h>(const Packet8h& a) {
  return _mm_abs_ph(a);
}

// psignbit
template <>
EIGEN_STRONG_INLINE Packet32h psignbit<Packet32h>(const Packet32h& a) {
  return _mm512_castsi512_ph(_mm512_srai_epi16(_mm512_castph_si512(a), 15));
}
template <>
EIGEN_STRONG_INLINE Packet16h psignbit<Packet16h>(const Packet16h& a) {
  return _mm256_castsi256_ph(_mm256_srai_epi16(_mm256_castph_si256(a), 15));
}
template <>
EIGEN_STRONG_INLINE Packet8h psignbit<Packet8h>(const Packet8h& a) {
  return _mm_castsi128_ph(_mm_srai_epi16(_mm_castph_si128(a), 15));
}

// pmin
template <>
EIGEN_STRONG_INLINE Packet32h pmin<Packet32h>(const Packet32h& a, const Packet32h& b) {
  return _mm512_min_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return _mm256_min_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmin<Packet8h>(const Packet8h& a, const Packet8h& b) {
  return _mm_min_ph(a, b);
}

// pmax
template <>
EIGEN_STRONG_INLINE Packet32h pmax<Packet32h>(const Packet32h& a, const Packet32h& b) {
  return _mm512_max_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return _mm256_max_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmax<Packet8h>(const Packet8h& a, const Packet8h& b) {
  return _mm_max_ph(a, b);
}

// plset
template <>
EIGEN_STRONG_INLINE Packet32h plset<Packet32h>(const half& a) {
  return _mm512_add_ph(pset1<Packet32h>(a),
                       _mm512_set_ph(31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12,
                                     11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
}
template <>
EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) {
  return _mm256_add_ph(pset1<Packet16h>(a), _mm256_set_ph(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
}
template <>
EIGEN_STRONG_INLINE Packet8h plset<Packet8h>(const half& a) {
  return _mm_add_ph(pset1<Packet8h>(a), _mm_set_ph(7, 6, 5, 4, 3, 2, 1, 0));
}

// por
template <>
EIGEN_STRONG_INLINE Packet32h por(const Packet32h& a, const Packet32h& b) {
  return _mm512_castsi512_ph(_mm512_or_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a, const Packet16h& b) {
  return _mm256_castsi256_ph(_mm256_or_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
}
template <>
EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a, const Packet8h& b) {
  return _mm_castsi128_ph(_mm_or_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
}

// pxor
template <>
EIGEN_STRONG_INLINE Packet32h pxor(const Packet32h& a, const Packet32h& b) {
  return _mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a, const Packet16h& b) {
  return _mm256_castsi256_ph(_mm256_xor_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pxor(const Packet8h& a, const Packet8h& b) {
  return _mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
}

// pand
template <>
EIGEN_STRONG_INLINE Packet32h pand(const Packet32h& a, const Packet32h& b) {
  return _mm512_castsi512_ph(_mm512_and_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a, const Packet16h& b) {
  return _mm256_castsi256_ph(_mm256_and_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pand(const Packet8h& a, const Packet8h& b) {
  return _mm_castsi128_ph(_mm_and_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
}

// pandnot
template <>
EIGEN_STRONG_INLINE Packet32h pandnot(const Packet32h& a, const Packet32h& b) {
  return _mm512_castsi512_ph(_mm512_andnot_si512(_mm512_castph_si512(b), _mm512_castph_si512(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a, const Packet16h& b) {
  return _mm256_castsi256_ph(_mm256_andnot_si256(_mm256_castph_si256(b), _mm256_castph_si256(a)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a, const Packet8h& b) {
  return _mm_castsi128_ph(_mm_andnot_si128(_mm_castph_si128(b), _mm_castph_si128(a)));
}

// pselect
template <>
EIGEN_DEVICE_FUNC inline Packet32h pselect(const Packet32h& mask, const Packet32h& a, const Packet32h& b) {
  __mmask32 mask32 = _mm512_cmp_epi16_mask(_mm512_castph_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
  return _mm512_mask_blend_ph(mask32, a, b);
}
template <>
EIGEN_DEVICE_FUNC inline Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) {
  __mmask16 mask16 = _mm256_cmp_epi16_mask(_mm256_castph_si256(mask), _mm256_setzero_si256(), _MM_CMPINT_EQ);
  return _mm256_mask_blend_ph(mask16, a, b);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) {
  __mmask8 mask8 = _mm_cmp_epi16_mask(_mm_castph_si128(mask), _mm_setzero_si128(), _MM_CMPINT_EQ);
  return _mm_mask_blend_ph(mask8, a, b);
}
// pcmp_eq
template <>
EIGEN_STRONG_INLINE Packet32h pcmp_eq(const Packet32h& a, const Packet32h& b) {
  __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_EQ_OQ);
  return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a, const Packet16h& b) {
  __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_EQ_OQ);
  return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a, const Packet8h& b) {
  __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_EQ_OQ);
  return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}

// pcmp_le
template <>
EIGEN_STRONG_INLINE Packet32h pcmp_le(const Packet32h& a, const Packet32h& b) {
  __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LE_OQ);
  return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a, const Packet16h& b) {
  __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_LE_OQ);
  return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a, const Packet8h& b) {
  __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_LE_OQ);
  return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}

// pcmp_lt
template <>
EIGEN_STRONG_INLINE Packet32h pcmp_lt(const Packet32h& a, const Packet32h& b) {
  __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LT_OQ);
  return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a, const Packet16h& b) {
  __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_LT_OQ);
  return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a, const Packet8h& b) {
  __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_LT_OQ);
  return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}

// pcmp_lt_or_nan
template <>
EIGEN_STRONG_INLINE Packet32h pcmp_lt_or_nan(const Packet32h& a, const Packet32h& b) {
  __mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_NGE_UQ);
  return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi16(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a, const Packet16h& b) {
  __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_NGE_UQ);
  return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a, const Packet8h& b) {
  __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_NGE_UQ);
  return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
// padd
template <>
EIGEN_STRONG_INLINE Packet32h padd<Packet32h>(const Packet32h& a, const Packet32h& b) {
  return _mm512_add_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return _mm256_add_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
  return _mm_add_ph(a, b);
}

// psub
template <>
EIGEN_STRONG_INLINE Packet32h psub<Packet32h>(const Packet32h& a, const Packet32h& b) {
  return _mm512_sub_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return _mm256_sub_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h psub<Packet8h>(const Packet8h& a, const Packet8h& b) {
  return _mm_sub_ph(a, b);
}

// pmul
template <>
EIGEN_STRONG_INLINE Packet32h pmul<Packet32h>(const Packet32h& a, const Packet32h& b) {
  return _mm512_mul_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return _mm256_mul_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) {
  return _mm_mul_ph(a, b);
}

// pdiv
template <>
EIGEN_STRONG_INLINE Packet32h pdiv<Packet32h>(const Packet32h& a, const Packet32h& b) {
  return _mm512_div_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
  return _mm256_div_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
  return _mm_div_ph(a, b);
}

// pround
template <>
EIGEN_STRONG_INLINE Packet32h pround<Packet32h>(const Packet32h& a) {
  // Work-around for default std::round rounding mode.

  // Mask for the sign bit.
  const Packet32h signMask =
      pset1frombits<Packet32h>(static_cast<numext::uint16_t>(static_cast<std::uint16_t>(0x8000u)));
  // The largest half-precision float less than 0.5.
  const Packet32h prev0dot5 = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x37FFu));

  return _mm512_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) {
  // Work-around for default std::round rounding mode.

  // Mask for the sign bit.
  const Packet16h signMask =
      pset1frombits<Packet16h>(static_cast<numext::uint16_t>(static_cast<std::uint16_t>(0x8000u)));
  // The largest half-precision float less than 0.5.
  const Packet16h prev0dot5 = pset1frombits<Packet16h>(static_cast<numext::uint16_t>(0x37FFu));

  return _mm256_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet8h pround<Packet8h>(const Packet8h& a) {
  // Work-around for default std::round rounding mode.

  // Mask for the sign bit.
  const Packet8h signMask = pset1frombits<Packet8h>(static_cast<numext::uint16_t>(static_cast<std::uint16_t>(0x8000u)));
  // The largest half-precision float less than 0.5.
  const Packet8h prev0dot5 = pset1frombits<Packet8h>(static_cast<numext::uint16_t>(0x37FFu));

  return _mm_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
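// The three pround overloads above emulate std::round's round-half-away-from-zero behaviour:
// they add, with the sign of `a`, the largest half-precision value strictly below 0.5
// (0x37FF = 0.499755859375) and then truncate toward zero. Because the padd is itself
// performed in half precision, the halfway case still rounds away from zero, e.g.
// 0.5 + 0.499755859375 is rounded (ties-to-even) to the representable value 1.0, and the
// subsequent truncation then yields pround(0.5) == 1.0.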
// print
template <>
EIGEN_STRONG_INLINE Packet32h print<Packet32h>(const Packet32h& a) {
  return _mm512_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
}
template <>
EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) {
  return _mm256_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
}
template <>
EIGEN_STRONG_INLINE Packet8h print<Packet8h>(const Packet8h& a) {
  return _mm_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
}

// pceil
template <>
EIGEN_STRONG_INLINE Packet32h pceil<Packet32h>(const Packet32h& a) {
  return _mm512_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
}
template <>
EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) {
  return _mm256_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
}
template <>
EIGEN_STRONG_INLINE Packet8h pceil<Packet8h>(const Packet8h& a) {
  return _mm_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
}

// pfloor
template <>
EIGEN_STRONG_INLINE Packet32h pfloor<Packet32h>(const Packet32h& a) {
  return _mm512_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
}
template <>
EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) {
  return _mm256_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
}
template <>
EIGEN_STRONG_INLINE Packet8h pfloor<Packet8h>(const Packet8h& a) {
  return _mm_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
}

// ptrunc
template <>
EIGEN_STRONG_INLINE Packet32h ptrunc<Packet32h>(const Packet32h& a) {
  return _mm512_roundscale_ph(a, _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet16h ptrunc<Packet16h>(const Packet16h& a) {
  return _mm256_roundscale_ph(a, _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet8h ptrunc<Packet8h>(const Packet8h& a) {
  return _mm_roundscale_ph(a, _MM_FROUND_TO_ZERO);
}

// predux
template <>
EIGEN_STRONG_INLINE half predux<Packet32h>(const Packet32h& a) {
  return half(_mm512_reduce_add_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& a) {
  return half(_mm256_reduce_add_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux<Packet8h>(const Packet8h& a) {
  return half(_mm_reduce_add_ph(a));
}

// predux_half_dowto4
template <>
EIGEN_STRONG_INLINE Packet16h predux_half_dowto4<Packet32h>(const Packet32h& a) {
  const __m512i bits = _mm512_castph_si512(a);
  Packet16h lo = _mm256_castsi256_ph(_mm512_castsi512_si256(bits));
  Packet16h hi = _mm256_castsi256_ph(_mm512_extracti64x4_epi64(bits, 1));
  return padd(lo, hi);
}
template <>
EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
  Packet8h lo = _mm_castsi128_ph(_mm256_castsi256_si128(_mm256_castph_si256(a)));
  Packet8h hi = _mm_castps_ph(_mm256_extractf128_ps(_mm256_castph_ps(a), 1));
  return padd(lo, hi);
}

// predux_max
template <>
EIGEN_STRONG_INLINE half predux_max<Packet32h>(const Packet32h& a) {
  return half(_mm512_reduce_max_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_max<Packet16h>(const Packet16h& a) {
  return half(_mm256_reduce_max_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_max<Packet8h>(const Packet8h& a) {
  return half(_mm_reduce_max_ph(a));
}

// predux_min
template <>
EIGEN_STRONG_INLINE half predux_min<Packet32h>(const Packet32h& a) {
  return half(_mm512_reduce_min_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_min<Packet16h>(const Packet16h& a) {
  return half(_mm256_reduce_min_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_min<Packet8h>(const Packet8h& a) {
  return half(_mm_reduce_min_ph(a));
}

// predux_mul
template <>
EIGEN_STRONG_INLINE half predux_mul<Packet32h>(const Packet32h& a) {
  return half(_mm512_reduce_mul_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& a) {
  return half(_mm256_reduce_mul_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_mul<Packet8h>(const Packet8h& a) {
  return half(_mm_reduce_mul_ph(a));
}
#ifdef EIGEN_VECTORIZE_FMA
// pmadd
template <>
EIGEN_STRONG_INLINE Packet32h pmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
  return _mm512_fmadd_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return _mm256_fmadd_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
  return _mm_fmadd_ph(a, b, c);
}
// pmsub
template <>
EIGEN_STRONG_INLINE Packet32h pmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
  return _mm512_fmsub_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return _mm256_fmsub_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
  return _mm_fmsub_ph(a, b, c);
}
// pnmadd
template <>
EIGEN_STRONG_INLINE Packet32h pnmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
  return _mm512_fnmadd_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16h pnmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return _mm256_fnmadd_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pnmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
  return _mm_fnmadd_ph(a, b, c);
}
// pnmsub
template <>
EIGEN_STRONG_INLINE Packet32h pnmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
  return _mm512_fnmsub_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16h pnmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
  return _mm256_fnmsub_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pnmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
  return _mm_fnmsub_ph(a, b, c);
}
#endif
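// Illustrative sketch (not part of Eigen): per-lane meaning of the fused kernels above as mapped onto the
// Intel FMA intrinsics: pmadd(a,b,c) = a*b + c, pmsub(a,b,c) = a*b - c, pnmadd(a,b,c) = -(a*b) + c,
// pnmsub(a,b,c) = -(a*b) - c.
#if 0
static inline float fma_semantics_example(float a, float b, float c) {
  float r0 = a * b + c;      // pmadd
  float r1 = a * b - c;      // pmsub
  float r2 = -(a * b) + c;   // pnmadd
  float r3 = -(a * b) - c;   // pnmsub
  return r0 + r1 + r2 + r3;  // sums to 0 for finite inputs; kept only so the sketch compiles
}
#endif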
// pnegate
template <>
EIGEN_STRONG_INLINE Packet32h pnegate<Packet32h>(const Packet32h& a) {
  return _mm512_castsi512_ph(
      _mm512_xor_si512(_mm512_castph_si512(a), _mm512_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
}
template <>
EIGEN_STRONG_INLINE Packet16h pnegate<Packet16h>(const Packet16h& a) {
  return _mm256_castsi256_ph(
      _mm256_xor_si256(_mm256_castph_si256(a), _mm256_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
}
template <>
EIGEN_STRONG_INLINE Packet8h pnegate<Packet8h>(const Packet8h& a) {
  return _mm_castsi128_ph(
      _mm_xor_si128(_mm_castph_si128(a), _mm_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
}
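// Illustrative sketch (not part of Eigen): pnegate above flips only the IEEE sign bit (bit 15 of each
// half), which negates the value without touching exponent or mantissa bits.
#if 0
static inline std::uint16_t negate_half_bits_example(std::uint16_t h_bits) {
  return static_cast<std::uint16_t>(h_bits ^ 0x8000u);  // same XOR mask as the packet versions above
}
#endif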
// pconj
// Nothing, packets are real.
// psqrt
template <>
EIGEN_STRONG_INLINE Packet32h psqrt<Packet32h>(const Packet32h& a) {
  return generic_sqrt_newton_step<Packet32h>::run(a, _mm512_rsqrt_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet16h psqrt<Packet16h>(const Packet16h& a) {
  return generic_sqrt_newton_step<Packet16h>::run(a, _mm256_rsqrt_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet8h psqrt<Packet8h>(const Packet8h& a) {
  return generic_sqrt_newton_step<Packet8h>::run(a, _mm_rsqrt_ph(a));
}
// prsqrt
template <>
EIGEN_STRONG_INLINE Packet32h prsqrt<Packet32h>(const Packet32h& a) {
  return generic_rsqrt_newton_step<Packet32h, /*Steps=*/1>::run(a, _mm512_rsqrt_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet16h prsqrt<Packet16h>(const Packet16h& a) {
  return generic_rsqrt_newton_step<Packet16h, /*Steps=*/1>::run(a, _mm256_rsqrt_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet8h prsqrt<Packet8h>(const Packet8h& a) {
  return generic_rsqrt_newton_step<Packet8h, /*Steps=*/1>::run(a, _mm_rsqrt_ph(a));
}
// preciprocal
template <>
EIGEN_STRONG_INLINE Packet32h preciprocal<Packet32h>(const Packet32h& a) {
  return generic_reciprocal_newton_step<Packet32h, /*Steps=*/1>::run(a, _mm512_rcp_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet16h preciprocal<Packet16h>(const Packet16h& a) {
  return generic_reciprocal_newton_step<Packet16h, /*Steps=*/1>::run(a, _mm256_rcp_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet8h preciprocal<Packet8h>(const Packet8h& a) {
  return generic_reciprocal_newton_step<Packet8h, /*Steps=*/1>::run(a, _mm_rcp_ph(a));
}
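// Illustrative sketch (not part of Eigen): psqrt/prsqrt/preciprocal above start from the hardware
// approximation (rsqrt_ph/rcp_ph) and refine it with one Newton-Raphson step, e.g. for y ~ 1/x:
//   y_next = y * (2 - x * y).
// The disabled scalar helper mirrors that single refinement; the name is invented for this note.
#if 0
static inline float reciprocal_newton_step_example(float x, float y0 /* rough 1/x estimate */) {
  return y0 * (2.0f - x * y0);  // one Newton iteration roughly doubles the number of correct bits
}
#endif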
// ptranspose
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 32>& a) {
  __m512i t[32];
  EIGEN_UNROLL_LOOP
  for (int i = 0; i < 16; i++) {
    t[2 * i] =
        _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
    t[2 * i + 1] =
        _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
  }
  __m512i p[32];
  EIGEN_UNROLL_LOOP
  for (int i = 0; i < 8; i++) {
    p[4 * i] = _mm512_unpacklo_epi32(t[4 * i], t[4 * i + 2]);
    p[4 * i + 1] = _mm512_unpackhi_epi32(t[4 * i], t[4 * i + 2]);
    p[4 * i + 2] = _mm512_unpacklo_epi32(t[4 * i + 1], t[4 * i + 3]);
    p[4 * i + 3] = _mm512_unpackhi_epi32(t[4 * i + 1], t[4 * i + 3]);
  }
  __m512i q[32];
  EIGEN_UNROLL_LOOP
  for (int i = 0; i < 4; i++) {
    q[8 * i] = _mm512_unpacklo_epi64(p[8 * i], p[8 * i + 4]);
    q[8 * i + 1] = _mm512_unpackhi_epi64(p[8 * i], p[8 * i + 4]);
    q[8 * i + 2] = _mm512_unpacklo_epi64(p[8 * i + 1], p[8 * i + 5]);
    q[8 * i + 3] = _mm512_unpackhi_epi64(p[8 * i + 1], p[8 * i + 5]);
    q[8 * i + 4] = _mm512_unpacklo_epi64(p[8 * i + 2], p[8 * i + 6]);
    q[8 * i + 5] = _mm512_unpackhi_epi64(p[8 * i + 2], p[8 * i + 6]);
    q[8 * i + 6] = _mm512_unpacklo_epi64(p[8 * i + 3], p[8 * i + 7]);
    q[8 * i + 7] = _mm512_unpackhi_epi64(p[8 * i + 3], p[8 * i + 7]);
  }
  __m512i f[32];
#define PACKET32H_TRANSPOSE_HELPER(X, Y)                                                            \
  do {                                                                                              \
    f[Y * 8] = _mm512_inserti32x4(f[Y * 8], _mm512_extracti32x4_epi32(q[X * 8], Y), X);             \
    f[Y * 8 + 1] = _mm512_inserti32x4(f[Y * 8 + 1], _mm512_extracti32x4_epi32(q[X * 8 + 1], Y), X); \
    f[Y * 8 + 2] = _mm512_inserti32x4(f[Y * 8 + 2], _mm512_extracti32x4_epi32(q[X * 8 + 2], Y), X); \
    f[Y * 8 + 3] = _mm512_inserti32x4(f[Y * 8 + 3], _mm512_extracti32x4_epi32(q[X * 8 + 3], Y), X); \
    f[Y * 8 + 4] = _mm512_inserti32x4(f[Y * 8 + 4], _mm512_extracti32x4_epi32(q[X * 8 + 4], Y), X); \
    f[Y * 8 + 5] = _mm512_inserti32x4(f[Y * 8 + 5], _mm512_extracti32x4_epi32(q[X * 8 + 5], Y), X); \
    f[Y * 8 + 6] = _mm512_inserti32x4(f[Y * 8 + 6], _mm512_extracti32x4_epi32(q[X * 8 + 6], Y), X); \
    f[Y * 8 + 7] = _mm512_inserti32x4(f[Y * 8 + 7], _mm512_extracti32x4_epi32(q[X * 8 + 7], Y), X); \
  } while (false);
  PACKET32H_TRANSPOSE_HELPER(0, 0);
  PACKET32H_TRANSPOSE_HELPER(1, 1);
  PACKET32H_TRANSPOSE_HELPER(2, 2);
  PACKET32H_TRANSPOSE_HELPER(3, 3);
  PACKET32H_TRANSPOSE_HELPER(1, 0);
  PACKET32H_TRANSPOSE_HELPER(2, 0);
  PACKET32H_TRANSPOSE_HELPER(3, 0);
  PACKET32H_TRANSPOSE_HELPER(2, 1);
  PACKET32H_TRANSPOSE_HELPER(3, 1);
  PACKET32H_TRANSPOSE_HELPER(3, 2);
  PACKET32H_TRANSPOSE_HELPER(0, 1);
  PACKET32H_TRANSPOSE_HELPER(0, 2);
  PACKET32H_TRANSPOSE_HELPER(0, 3);
  PACKET32H_TRANSPOSE_HELPER(1, 2);
  PACKET32H_TRANSPOSE_HELPER(1, 3);
  PACKET32H_TRANSPOSE_HELPER(2, 3);
#undef PACKET32H_TRANSPOSE_HELPER
  EIGEN_UNROLL_LOOP
  for (int i = 0; i < 32; i++) {
    a.packet[i] = _mm512_castsi512_ph(f[i]);
  }
}
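// Illustrative sketch (not part of Eigen): the transpose above works in interleave stages, first pairing
// 16-bit lanes, then 32-bit, 64-bit and finally 128-bit groups. The disabled scalar analogue shows the
// idea of one unpacklo/unpackhi stage on two two-element "packets"; the helper name is invented here.
#if 0
static inline void unpack_stage_example(int a0, int a1, int b0, int b1, int out[4]) {
  // unpacklo(a, b) -> {a0, b0}, unpackhi(a, b) -> {a1, b1}: together they transpose a 2x2 block.
  out[0] = a0; out[1] = b0;  // "unpacklo"
  out[2] = a1; out[3] = b1;  // "unpackhi"
}
#endif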
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 4>& a) {
  __m512i p0, p1, p2, p3, t0, t1, t2, t3, a0, a1, a2, a3;
  t0 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
  t1 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
  t2 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));
  t3 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));
  p0 = _mm512_unpacklo_epi32(t0, t2);
  p1 = _mm512_unpackhi_epi32(t0, t2);
  p2 = _mm512_unpacklo_epi32(t1, t3);
  p3 = _mm512_unpackhi_epi32(t1, t3);
  a0 = p0;
  a1 = p1;
  a2 = p2;
  a3 = p3;
  a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p1, 0), 1);
  a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p0, 1), 0);
  a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p2, 0), 2);
  a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p0, 2), 0);
  a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p3, 0), 3);
  a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p0, 3), 0);
  a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p2, 1), 2);
  a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p1, 2), 1);
  a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p3, 2), 3);
  a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p2, 3), 2);
  a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p3, 1), 3);
  a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p1, 3), 1);
  a.packet[0] = _mm512_castsi512_ph(a0);
  a.packet[1] = _mm512_castsi512_ph(a1);
  a.packet[2] = _mm512_castsi512_ph(a2);
  a.packet[3] = _mm512_castsi512_ph(a3);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 16>& kernel) {
  __m256i a = _mm256_castph_si256(kernel.packet[0]);
  __m256i b = _mm256_castph_si256(kernel.packet[1]);
  __m256i c = _mm256_castph_si256(kernel.packet[2]);
  __m256i d = _mm256_castph_si256(kernel.packet[3]);
  __m256i e = _mm256_castph_si256(kernel.packet[4]);
  __m256i f = _mm256_castph_si256(kernel.packet[5]);
  __m256i g = _mm256_castph_si256(kernel.packet[6]);
  __m256i h = _mm256_castph_si256(kernel.packet[7]);
  __m256i i = _mm256_castph_si256(kernel.packet[8]);
  __m256i j = _mm256_castph_si256(kernel.packet[9]);
  __m256i k = _mm256_castph_si256(kernel.packet[10]);
  __m256i l = _mm256_castph_si256(kernel.packet[11]);
  __m256i m = _mm256_castph_si256(kernel.packet[12]);
  __m256i n = _mm256_castph_si256(kernel.packet[13]);
  __m256i o = _mm256_castph_si256(kernel.packet[14]);
  __m256i p = _mm256_castph_si256(kernel.packet[15]);
  __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
  __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
  __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
  __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
  __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
  __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
  __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
  __m256i op_07 = _mm256_unpacklo_epi16(o, p);
  __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
  __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
  __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
  __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
  __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
  __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
  __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
  __m256i op_8f = _mm256_unpackhi_epi16(o, p);
  __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
  __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
  __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
  __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
  __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
  __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
  __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
  __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
  __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
  __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
  __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
  __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
  __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
  __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
  __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
  __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
  __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
  __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
  __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
  __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
  __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
  __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
  __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
  __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
  __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
  __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
  __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
  __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
  __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
  __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
  __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
  __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
  // NOTE: no unpacklo/hi instr in this case, so using permute instr.
  __m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
  __m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
  __m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
  __m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
  __m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
  __m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
  __m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
  __m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
  __m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
  __m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
  __m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
  __m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
  __m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
  __m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
  __m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
  __m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
  kernel.packet[0] = _mm256_castsi256_ph(a_p_0);
  kernel.packet[1] = _mm256_castsi256_ph(a_p_1);
  kernel.packet[2] = _mm256_castsi256_ph(a_p_2);
  kernel.packet[3] = _mm256_castsi256_ph(a_p_3);
  kernel.packet[4] = _mm256_castsi256_ph(a_p_4);
  kernel.packet[5] = _mm256_castsi256_ph(a_p_5);
  kernel.packet[6] = _mm256_castsi256_ph(a_p_6);
  kernel.packet[7] = _mm256_castsi256_ph(a_p_7);
  kernel.packet[8] = _mm256_castsi256_ph(a_p_8);
  kernel.packet[9] = _mm256_castsi256_ph(a_p_9);
  kernel.packet[10] = _mm256_castsi256_ph(a_p_a);
  kernel.packet[11] = _mm256_castsi256_ph(a_p_b);
  kernel.packet[12] = _mm256_castsi256_ph(a_p_c);
  kernel.packet[13] = _mm256_castsi256_ph(a_p_d);
  kernel.packet[14] = _mm256_castsi256_ph(a_p_e);
  kernel.packet[15] = _mm256_castsi256_ph(a_p_f);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 8>& kernel) {
  EIGEN_ALIGN64 half in[8][16];
  pstore<half>(in[0], kernel.packet[0]);
  pstore<half>(in[1], kernel.packet[1]);
  pstore<half>(in[2], kernel.packet[2]);
  pstore<half>(in[3], kernel.packet[3]);
  pstore<half>(in[4], kernel.packet[4]);
  pstore<half>(in[5], kernel.packet[5]);
  pstore<half>(in[6], kernel.packet[6]);
  pstore<half>(in[7], kernel.packet[7]);
  EIGEN_ALIGN64 half out[8][16];
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      out[i][j] = in[j][2 * i];
    }
    for (int j = 0; j < 8; ++j) {
      out[i][j + 8] = in[j][2 * i + 1];
    }
  }
  kernel.packet[0] = pload<Packet16h>(out[0]);
  kernel.packet[1] = pload<Packet16h>(out[1]);
  kernel.packet[2] = pload<Packet16h>(out[2]);
  kernel.packet[3] = pload<Packet16h>(out[3]);
  kernel.packet[4] = pload<Packet16h>(out[4]);
  kernel.packet[5] = pload<Packet16h>(out[5]);
  kernel.packet[6] = pload<Packet16h>(out[6]);
  kernel.packet[7] = pload<Packet16h>(out[7]);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 4>& kernel) {
  EIGEN_ALIGN64 half in[4][16];
  pstore<half>(in[0], kernel.packet[0]);
  pstore<half>(in[1], kernel.packet[1]);
  pstore<half>(in[2], kernel.packet[2]);
  pstore<half>(in[3], kernel.packet[3]);
  EIGEN_ALIGN64 half out[4][16];
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) {
      out[i][j] = in[j][4 * i];
    }
    for (int j = 0; j < 4; ++j) {
      out[i][j + 4] = in[j][4 * i + 1];
    }
    for (int j = 0; j < 4; ++j) {
      out[i][j + 8] = in[j][4 * i + 2];
    }
    for (int j = 0; j < 4; ++j) {
      out[i][j + 12] = in[j][4 * i + 3];
    }
  }
  kernel.packet[0] = pload<Packet16h>(out[0]);
  kernel.packet[1] = pload<Packet16h>(out[1]);
  kernel.packet[2] = pload<Packet16h>(out[2]);
  kernel.packet[3] = pload<Packet16h>(out[3]);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8h, 8>& kernel) {
  __m128i a = _mm_castph_si128(kernel.packet[0]);
  __m128i b = _mm_castph_si128(kernel.packet[1]);
  __m128i c = _mm_castph_si128(kernel.packet[2]);
  __m128i d = _mm_castph_si128(kernel.packet[3]);
  __m128i e = _mm_castph_si128(kernel.packet[4]);
  __m128i f = _mm_castph_si128(kernel.packet[5]);
  __m128i g = _mm_castph_si128(kernel.packet[6]);
  __m128i h = _mm_castph_si128(kernel.packet[7]);
  __m128i a03b03 = _mm_unpacklo_epi16(a, b);
  __m128i c03d03 = _mm_unpacklo_epi16(c, d);
  __m128i e03f03 = _mm_unpacklo_epi16(e, f);
  __m128i g03h03 = _mm_unpacklo_epi16(g, h);
  __m128i a47b47 = _mm_unpackhi_epi16(a, b);
  __m128i c47d47 = _mm_unpackhi_epi16(c, d);
  __m128i e47f47 = _mm_unpackhi_epi16(e, f);
  __m128i g47h47 = _mm_unpackhi_epi16(g, h);
  __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03);
  __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03);
  __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03);
  __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03);
  __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47);
  __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47);
  __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47);
  __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47);
  __m128i a0b0c0d0e0f0g0h0 = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01);
  __m128i a1b1c1d1e1f1g1h1 = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01);
  __m128i a2b2c2d2e2f2g2h2 = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23);
  __m128i a3b3c3d3e3f3g3h3 = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23);
  __m128i a4b4c4d4e4f4g4h4 = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45);
  __m128i a5b5c5d5e5f5g5h5 = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45);
  __m128i a6b6c6d6e6f6g6h6 = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67);
  __m128i a7b7c7d7e7f7g7h7 = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67);
  kernel.packet[0] = _mm_castsi128_ph(a0b0c0d0e0f0g0h0);
  kernel.packet[1] = _mm_castsi128_ph(a1b1c1d1e1f1g1h1);
  kernel.packet[2] = _mm_castsi128_ph(a2b2c2d2e2f2g2h2);
  kernel.packet[3] = _mm_castsi128_ph(a3b3c3d3e3f3g3h3);
  kernel.packet[4] = _mm_castsi128_ph(a4b4c4d4e4f4g4h4);
  kernel.packet[5] = _mm_castsi128_ph(a5b5c5d5e5f5g5h5);
  kernel.packet[6] = _mm_castsi128_ph(a6b6c6d6e6f6g6h6);
  kernel.packet[7] = _mm_castsi128_ph(a7b7c7d7e7f7g7h7);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8h, 4>& kernel) {
  EIGEN_ALIGN32 Eigen::half in[4][8];
  pstore<Eigen::half>(in[0], kernel.packet[0]);
  pstore<Eigen::half>(in[1], kernel.packet[1]);
  pstore<Eigen::half>(in[2], kernel.packet[2]);
  pstore<Eigen::half>(in[3], kernel.packet[3]);
  EIGEN_ALIGN32 Eigen::half out[4][8];
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) {
      out[i][j] = in[j][2 * i];
    }
    for (int j = 0; j < 4; ++j) {
      out[i][j + 4] = in[j][2 * i + 1];
    }
  }
  kernel.packet[0] = pload<Packet8h>(out[0]);
  kernel.packet[1] = pload<Packet8h>(out[1]);
  kernel.packet[2] = pload<Packet8h>(out[2]);
  kernel.packet[3] = pload<Packet8h>(out[3]);
}
// preverse
template <>
EIGEN_STRONG_INLINE Packet32h preverse(const Packet32h& a) {
  return _mm512_permutexvar_ph(_mm512_set_epi16(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
                                                18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31),
                               a);
}
template <>
EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a) {
  __m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
  return _mm256_castsi256_ph(_mm256_insertf128_si256(
      _mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castph_si256(a), 1), m)),
      _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castph_si256(a), 0), m), 1));
}
template <>
EIGEN_STRONG_INLINE Packet8h preverse(const Packet8h& a) {
  __m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
  return _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a), m));
}
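// Illustrative sketch (not part of Eigen): the _mm_setr_epi8 mask above moves 16-bit lanes as byte pairs
// (bytes 14,15 first, then 12,13, ...), so a single byte shuffle reverses eight half values per 128-bit
// lane. The disabled scalar equivalent is only for reference; its name is invented for this note.
#if 0
static inline void reverse_halves_example(const std::uint16_t in[8], std::uint16_t out[8]) {
  for (int i = 0; i < 8; ++i) out[i] = in[7 - i];  // scalar equivalent of the byte-pair shuffle
}
#endif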
// pscatter
template <>
EIGEN_STRONG_INLINE void pscatter<half, Packet32h>(half* to, const Packet32h& from, Index stride) {
  EIGEN_ALIGN64 half aux[32];
  pstore(aux, from);
  EIGEN_UNROLL_LOOP
  for (int i = 0; i < 32; i++) {
    to[stride * i] = aux[i];
  }
}
template <>
EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Packet16h& from, Index stride) {
  EIGEN_ALIGN64 half aux[16];
  pstore(aux, from);
  to[stride * 0] = aux[0];
  to[stride * 1] = aux[1];
  to[stride * 2] = aux[2];
  to[stride * 3] = aux[3];
  to[stride * 4] = aux[4];
  to[stride * 5] = aux[5];
  to[stride * 6] = aux[6];
  to[stride * 7] = aux[7];
  to[stride * 8] = aux[8];
  to[stride * 9] = aux[9];
  to[stride * 10] = aux[10];
  to[stride * 11] = aux[11];
  to[stride * 12] = aux[12];
  to[stride * 13] = aux[13];
  to[stride * 14] = aux[14];
  to[stride * 15] = aux[15];
}
template <>
EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride) {
  EIGEN_ALIGN32 Eigen::half aux[8];
  pstore(aux, from);
  to[stride * 0] = aux[0];
  to[stride * 1] = aux[1];
  to[stride * 2] = aux[2];
  to[stride * 3] = aux[3];
  to[stride * 4] = aux[4];
  to[stride * 5] = aux[5];
  to[stride * 6] = aux[6];
  to[stride * 7] = aux[7];
}
// pgather
template <>
EIGEN_STRONG_INLINE Packet32h pgather<Eigen::half, Packet32h>(const Eigen::half* from, Index stride) {
  return _mm512_set_ph(from[31 * stride].x, from[30 * stride].x, from[29 * stride].x, from[28 * stride].x,
                       from[27 * stride].x, from[26 * stride].x, from[25 * stride].x, from[24 * stride].x,
                       from[23 * stride].x, from[22 * stride].x, from[21 * stride].x, from[20 * stride].x,
                       from[19 * stride].x, from[18 * stride].x, from[17 * stride].x, from[16 * stride].x,
                       from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
                       from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x,
                       from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
                       from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
}
template <>
EIGEN_STRONG_INLINE Packet16h pgather<Eigen::half, Packet16h>(const Eigen::half* from, Index stride) {
  return _mm256_set_ph(from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
                       from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x,
                       from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
                       from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
}
template <>
EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride) {
  return _mm_set_ph(from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
                    from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
}
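// Illustrative sketch (not part of Eigen): pgather/pscatter above read and write a packet with a runtime
// stride, which is how strided accesses (e.g. walking a column of a row-major matrix) are vectorized.
// The disabled snippet shows the intended access pattern on a hypothetical buffer.
#if 0
static inline void strided_roundtrip_example(Eigen::half* data, Index stride) {
  Packet8h p = pgather<Eigen::half, Packet8h>(data, stride);  // loads data[0], data[stride], ...
  pscatter<Eigen::half, Packet8h>(data, p, stride);           // writes them back to the same slots
}
#endif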
}  // end namespace internal
}  // end namespace Eigen
#endif  // EIGEN_PACKET_MATH_FP16_AVX512_H
eigen-master/Eigen/src/Core/arch/AVX512/TrsmKernel.h
0 → 100644
View file @
266d4fd9
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2022 Intel Corporation
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
#define EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
#if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
#define EIGEN_USE_AVX512_TRSM_KERNELS 1
#endif
// TRSM kernels currently unconditionally rely on malloc with AVX512.
// Disable them if malloc is explicitly disabled at compile-time.
#ifdef EIGEN_NO_MALLOC
#undef EIGEN_USE_AVX512_TRSM_KERNELS
#define EIGEN_USE_AVX512_TRSM_KERNELS 0
#endif
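// Illustrative note (not part of Eigen): per the macros above, a translation unit built with
//   #define EIGEN_NO_MALLOC
//   #include <Eigen/Dense>
// gets EIGEN_USE_AVX512_TRSM_KERNELS forced to 0, so triangular solves fall back to Eigen's generic
// TriangularSolverMatrix path; defining EIGEN_USE_AVX512_TRSM_KERNELS 0 before including Eigen opts out
// explicitly even when malloc is allowed.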
#if EIGEN_USE_AVX512_TRSM_KERNELS
#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
#endif
#if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
#endif
#else // EIGEN_USE_AVX512_TRSM_KERNELS == 0
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
#endif
// Need this for some std::min calls.
#ifdef min
#undef min
#endif
namespace Eigen {
namespace internal {
#define EIGEN_AVX_MAX_NUM_ACC (int64_t(24))
#define EIGEN_AVX_MAX_NUM_ROW (int64_t(8)) // Denoted L in code.
#define EIGEN_AVX_MAX_K_UNROL (int64_t(4))
#define EIGEN_AVX_B_LOAD_SETS (int64_t(2))
#define EIGEN_AVX_MAX_A_BCAST (int64_t(2))
typedef Packet16f vecFullFloat;
typedef Packet8d vecFullDouble;
typedef Packet8f vecHalfFloat;
typedef Packet4d vecHalfDouble;
// Compile-time unrolls are implemented here.
// Note: this depends on macros and typedefs above.
#include "TrsmUnrolls.inc"
#if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
/**
* For smaller problem sizes, and certain compilers, using the optimized kernels trsmKernelL/R directly
* is faster than the packed versions in TriangularSolverMatrix.h.
*
 * The current heuristic is based on having all arrays used in the largest gemm-update
 * in triSolve fit in roughly L2Cap (percentage) of the L2 cache. These cutoffs are a bit conservative and could be
* larger for some trsm cases.
* The formula:
*
* (L*M + M*N + L*N)*sizeof(Scalar) < L2Cache*L2Cap
*
* L = number of rows to solve at a time
* N = number of rhs
* M = Dimension of triangular matrix
*
*/
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
#endif
#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
#if EIGEN_USE_AVX512_TRSM_R_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
#endif // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
#endif
#if EIGEN_USE_AVX512_TRSM_L_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
#endif
#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
#else // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
#endif // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
template <typename Scalar>
int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) {
  const int64_t U3 = 3 * packet_traits<Scalar>::size;
  const int64_t MaxNb = 5 * U3;
  int64_t Nb = std::min(MaxNb, N);
  double cutoff_d =
      (((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb);
  int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
  return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
}
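// Illustrative sketch (not part of Eigen): plugging example numbers into the heuristic above.
// For float with L2Size = 1048576 (1 MiB), L2Cap = 0.5 and N = 16: U3 = 48, MaxNb = 240, Nb = 16, so
// cutoff_d = (1048576 * 0.5 / 4 - 8 * 16) / (8 + 16) = 130944 / 24 = 5456, already a multiple of 8.
#if 0
static inline int64_t trsm_cutoff_example() {
  return avx512_trsm_cutoff<float>(int64_t(1048576), int64_t(16), 0.5);  // 5456 with the numbers above
}
#endif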
#else // !(EIGEN_USE_AVX512_TRSM_KERNELS) || !(EIGEN_COMP_CLANG != 0)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
#endif
/**
* Used by gemmKernel for the case A/B row-major and C col-major.
*/
template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS>& zmm,
                                     Scalar* C_arr, int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
  EIGEN_UNUSED_VARIABLE(remN_);
  EIGEN_UNUSED_VARIABLE(remM_);
  using urolls = unrolls::trans<Scalar>;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;
  static_assert(unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize");
  static_assert(unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
  urolls::template transpose<unrollN, 0>(zmm);
  EIGEN_IF_CONSTEXPR(unrollN > U2) urolls::template transpose<unrollN, 2>(zmm);
  EIGEN_IF_CONSTEXPR(unrollN > U1) urolls::template transpose<unrollN, 1>(zmm);
  static_assert((remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1");
  EIGEN_IF_CONSTEXPR(!remN) {
    urolls::template storeC<std::min(unrollN, U1), unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
    EIGEN_IF_CONSTEXPR(unrollN > U1) {
      constexpr int64_t unrollN_ = std::min(unrollN - U1, U1);
      urolls::template storeC<unrollN_, unrollN, 1, remM>(C_arr + U1 * LDC, LDC, zmm, remM_);
    }
    EIGEN_IF_CONSTEXPR(unrollN > U2) {
      constexpr int64_t unrollN_ = std::min(unrollN - U2, U1);
      urolls::template storeC<unrollN_, unrollN, 2, remM>(C_arr + U2 * LDC, LDC, zmm, remM_);
    }
  }
  else {
    EIGEN_IF_CONSTEXPR((std::is_same<Scalar, float>::value)) {
      // Note: without "if constexpr" this section of code will also be
      // parsed by the compiler so each of the storeC will still be instantiated.
      // We use enable_if in aux_storeC to set it to an empty function for
      // these cases.
      if (remN_ == 15) urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 14) urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 13) urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 12) urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 11) urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 10) urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 9) urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 8) urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 7) urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 6) urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 5) urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 4) urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 3) urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 2) urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 1) urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
    }
    else {
      if (remN_ == 7) urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 6) urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 5) urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 4) urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 3) urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 2) urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 1) urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
    }
  }
}
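// Illustrative sketch (not part of Eigen): the note inside transStoreC refers to the usual trick of
// disabling an instantiation through enable_if rather than a runtime branch, because both sides of a
// plain if are still instantiated. A generic version of that pattern (names invented for this note):
#if 0
template <int remN, typename std::enable_if<(remN > 0), int>::type = 0>
void store_example(float* dst) { /* real store path for a positive remainder */ }
template <int remN, typename std::enable_if<(remN <= 0), int>::type = 0>
void store_example(float*) {}  // compiles to an empty body for impossible remainders
#endif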
/**
* GEMM like operation for trsm panel updates.
* Computes: C -= A*B
* K must be multiple of 4.
*
* Unrolls used are {1,2,4,8}x{U1,U2,U3};
* For good performance we want K to be large with M/N relatively small, but also large enough
* to use the {8,U3} unroll block.
*
* isARowMajor: is A_arr row-major?
* isCRowMajor: is C_arr row-major? (B_arr is assumed to be row-major).
* isAdd: C += A*B or C -= A*B (used by trsm)
* handleKRem: Handle arbitrary K? This is not needed for trsm.
*/
template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
void gemmKernel(Scalar* A_arr, Scalar* B_arr, Scalar* C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA,
                int64_t LDB, int64_t LDC) {
  using urolls = unrolls::gemm<Scalar, isAdd>;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  int64_t N_ = (N / U3) * U3;
  int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
  int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
  int64_t j = 0;
  for (; j < N_; j += U3) {
    constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 3;
    int64_t i = 0;
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL,
                                     EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 3,
                                       EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, false, false>(zmm, &C_arr[i + j * LDC], LDC);
      }
    }
    if (M - i >= 4) {
      // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 3, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
      }
      i += 4;
    }
    if (M - i >= 2) {
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 3, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
      }
      i += 2;
    }
    if (M - i > 0) {
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, 1>(zmm);
      {
        for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
          urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
          else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
        }
        EIGEN_IF_CONSTEXPR(handleKRem) {
          for (int64_t k = K_; k < K; k++) {
            urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
            B_t += LDB;
            EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
            else A_t += LDA;
          }
        }
        EIGEN_IF_CONSTEXPR(isCRowMajor) {
          urolls::template updateC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
          urolls::template storeC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
        }
        else {
          transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
        }
      }
    }
  }
  if (N - j >= U2) {
    constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 2;
    int64_t i = 0;
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
      EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL,
                                     EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
                                       EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, false, false>(zmm, &C_arr[i + j * LDC], LDC);
      }
    }
    if (M - i >= 4) {
      // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
      }
      i += 4;
    }
    if (M - i >= 2) {
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
      }
      i += 2;
    }
    if (M - i > 0) {
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 1>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(
            B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
      }
    }
    j += U2;
  }
  if (N - j >= U1) {
    constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
    int64_t i = 0;
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL,
                                     EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 1,
                                       EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, false>(zmm, &C_arr[i + j * LDC], LDC);
      }
    }
    if (M - i >= 4) {
      // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
      }
      i += 4;
    }
    if (M - i >= 2) {
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
      }
      i += 2;
    }
    if (M - i > 0) {
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 1>(zmm);
      {
        for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
          urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
          else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
        }
        EIGEN_IF_CONSTEXPR(handleKRem) {
          for (int64_t k = K_; k < K; k++) {
            urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
            B_t += LDB;
            EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
            else A_t += LDA;
          }
        }
        EIGEN_IF_CONSTEXPR(isCRowMajor) {
          urolls::template updateC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
          urolls::template storeC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
        }
        else {
          transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
        }
      }
    }
    j += U1;
  }
  if (N - j > 0) {
    constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
    int64_t i = 0;
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL,
                                     EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm,
                                                                                        N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
                                       EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, true>(zmm, &C_arr[i + j * LDC], LDC, 0, N - j);
      }
    }
    if (M - i >= 4) {
      // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
              B_t, A_t, LDB, LDA, zmm, N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 4, N - j);
      }
      i += 4;
    }
    if (M - i >= 2) {
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
              B_t, A_t, LDB, LDA, zmm, N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 2, N - j);
      }
      i += 2;
    }
    if (M - i > 0) {
      Scalar* A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar* B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 1>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
            B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
                                                                                            N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 1, N - j);
      }
    }
  }
}
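// Illustrative sketch (not part of Eigen's public API): how the panel update above might be invoked for a
// column-major C with row-major A/B and K a multiple of 4; isAdd selects between C += A*B and C -= A*B as
// described in the comment above. All buffers, sizes and the wrapper name are hypothetical.
#if 0
static inline void gemm_panel_update_example(double* A, double* B, double* C, int64_t M, int64_t N, int64_t K,
                                             int64_t LDA, int64_t LDB, int64_t LDC) {
  gemmKernel<double, /*isARowMajor=*/true, /*isCRowMajor=*/false, /*isAdd=*/false, /*handleKRem=*/false>(
      A, B, C, M, N, K, LDA, LDB, LDC);
}
#endif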
/**
* Triangular solve kernel with A on left with K number of rhs. dim(A) = unrollM
*
* unrollM: dimension of A matrix (triangular matrix). unrollM should be <= EIGEN_AVX_MAX_NUM_ROW
* isFWDSolve: is forward solve?
* isUnitDiag: is the diagonal of A all ones?
* The B matrix (RHS) is assumed to be row-major
*/
template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
  static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
  using urolls = unrolls::trsm<Scalar>;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;

  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;

  int64_t k = 0;
  while (K - k >= U3) {
    urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
    k += U3;
  }
  if (K - k >= U2) {
    urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
    k += U2;
  }
  if (K - k >= U1) {
    urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
    k += U1;
  }
  if (K - k > 0) {
    // Handle remaining number of RHS
    urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
  }
}
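// Worked example (illustrative only, not part of the kernel): for fp32, PacketSize = 16, so
// U3 = 48, U2 = 32 and U1 = 16. A hypothetical call such as
//
//   triSolveKernel<float, vecFullFloat, /*unrollM=*/8, /*isARowMajor=*/true,
//                  /*isFWDSolve=*/true, /*isUnitDiag=*/false>(A, B, /*K=*/100, LDA, LDB);
//
// (A, B, LDA, LDB being placeholder pointers/strides) peels two U3 passes (k = 0 and k = 48,
// covering 96 right-hand sides), skips the U2 and U1 branches since only 4 columns remain, and
// finishes in the masked remainder branch with K - k = 4, where loadRHS/storeRHS use
// remMask<PacketSize>(4). The numbers only trace the control flow above.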
/**
 * Triangular solve routine with A on the left, dim(A) at most L, and K right-hand sides. This is essentially
 * a wrapper for triSolveKernel for M = {1,2,3,4,5,6,7,8}.
*
* isFWDSolve: is forward solve?
* isUnitDiag: is the diagonal of A all ones?
* The B matrix (RHS) is assumed to be row-major
*/
template <typename Scalar, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) {
  // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
  // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  if (M == 8)
    triSolveKernel<Scalar, vec, 8, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 7)
    triSolveKernel<Scalar, vec, 7, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 6)
    triSolveKernel<Scalar, vec, 6, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 5)
    triSolveKernel<Scalar, vec, 5, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 4)
    triSolveKernel<Scalar, vec, 4, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 3)
    triSolveKernel<Scalar, vec, 3, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 2)
    triSolveKernel<Scalar, vec, 2, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 1)
    triSolveKernel<Scalar, vec, 1, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  return;
}
/**
* This routine is used to copy B to/from a temporary array (row-major) for cases where B is column-major.
*
* toTemp: true => copy to temporary array, false => copy from temporary array
* remM: true = need to handle remainder values for M (M < EIGEN_AVX_MAX_NUM_ROW)
*
*/
template <typename Scalar, bool toTemp = true, bool remM = false>
EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
                                         int64_t remM_ = 0) {
  EIGEN_UNUSED_VARIABLE(remM_);
  using urolls = unrolls::transB<Scalar>;
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> ymm;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;
  int64_t K_ = K / U3 * U3;
  int64_t k = 0;

  for (; k < K_; k += U3) {
    urolls::template transB_kernel<U3, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += U3;
  }
  if (K - k >= U2) {
    urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += U2;
    k += U2;
  }
  if (K - k >= U1) {
    urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += U1;
    k += U1;
  }
  EIGEN_IF_CONSTEXPR(U1 > 8) {
    // Note: without "if constexpr" this section of code will also be
    // parsed by the compiler so there is an additional check in {load/store}BBlock
    // to make sure the counter is not non-negative.
    if (K - k >= 8) {
      urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
      B_temp += 8;
      k += 8;
    }
  }
  EIGEN_IF_CONSTEXPR(U1 > 4) {
    // Note: without "if constexpr" this section of code will also be
    // parsed by the compiler so there is an additional check in {load/store}BBlock
    // to make sure the counter is not non-negative.
    if (K - k >= 4) {
      urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
      B_temp += 4;
      k += 4;
    }
  }
  if (K - k >= 2) {
    urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += 2;
    k += 2;
  }
  if (K - k >= 1) {
    urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += 1;
    k += 1;
  }
}
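// Usage sketch (illustrative only): triSolve below uses this helper in pairs. A column-major
// panel of B is first packed row-major into a scratch buffer, solved/updated in place, and then
// unpacked back:
//
//   copyBToRowMajor<Scalar, /*toTemp=*/true,  /*remM=*/false>(B_panel, LDB, bK, B_temp, LDT);
//   // ... triangular solve / gemm update on B_temp (row-major, leading dimension LDT) ...
//   copyBToRowMajor<Scalar, /*toTemp=*/false, /*remM=*/false>(B_panel, LDB, bK, B_temp, LDT);
//
// B_panel, B_temp, bK and LDT stand for the pointers and sizes computed inside triSolve.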
/**
* Main triangular solve driver
*
* Triangular solve with A on the left.
* Scalar: Scalar precision, only float/double is supported.
* isARowMajor: is A row-major?
* isBRowMajor: is B row-major?
* isFWDSolve: is this forward solve or backward (true => forward)?
* isUnitDiag: is diagonal of A unit or nonunit (true => A has unit diagonal)?
*
* M: dimension of A
* numRHS: number of right hand sides (coincides with K dimension for gemm updates)
*
 * Here is the mapping between the different TRSM cases (col-major) and triSolve:
*
* LLN (left , lower, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=true
* LUT (left , upper, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=true
* LUN (left , upper, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=false
* LLT (left , lower, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=false
* RUN (right, upper, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=true
* RLT (right, lower, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=true
* RUT (right, upper, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=false
* RLN (right, lower, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=false
*
* Note: For RXX cases M,numRHS should be swapped.
*
*/
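// Example (illustrative sketch, not part of the kernel itself): the LLN case from the table above,
// i.e. solving L * X = B in place with a column-major lower-triangular L and a column-major B,
// maps to
//
//   triSolve<double, /*isARowMajor=*/false, /*isBRowMajor=*/false,
//            /*isFWDSolve=*/true, /*isUnitDiag=*/false>(L_data, B_data, M, numRHS, LDA, LDB);
//
// where L_data/B_data are placeholder raw pointers and LDA/LDB their leading dimensions.
// For the right-sided (RXX) cases, M and numRHS are swapped as noted above.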
template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
          bool isUnitDiag = false>
void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
  constexpr int64_t psize = packet_traits<Scalar>::size;
  /**
   * The values for kB, numM were determined experimentally.
   * kB: Number of RHS we process at a time.
   * numM: number of rows of B we will store in a temporary array (see below.) This should be a multiple of L.
   *
   * kB was determined by initially setting kB = numRHS and benchmarking triSolve (TRSM-RUN case)
   * performance with M=numRHS.
   * It was observed that performance started to drop around M=numRHS=240. This is likely machine dependent.
   *
   * numM was chosen "arbitrarily". It should be relatively small so B_temp is not too large, but it should be
   * large enough to allow GEMM updates to have larger "K"s (see below.) No benchmarking has been done so far to
   * determine optimal values for numM.
   */
  constexpr int64_t kB = (3 * psize) * 5;  // 5*U3
  constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;

  int64_t sizeBTemp = 0;
  Scalar *B_temp = NULL;
  EIGEN_IF_CONSTEXPR(!isBRowMajor) {
    /**
     * If B is col-major, we copy it to a fixed-size temporary array of size at most ~numM*kB and
     * transpose it to row-major. Call the solve routine, and copy+transpose it back to the original array.
     * The updated row-major copy of B is reused in the GEMM updates.
     */
    sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
  }

  EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);

  for (int64_t k = 0; k < numRHS; k += kB) {
    int64_t bK = numRHS - k > kB ? kB : numRHS - k;
    int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;

    // bK rounded up to next multiple of L=EIGEN_AVX_MAX_NUM_ROW. When B_temp is used, we solve for bkL RHS
    // instead of bK RHS in triSolveKernelLxK.
    int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW - 1)) / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
    const int64_t numScalarPerCache = 64 / sizeof(Scalar);
    // Leading dimension of B_temp, will be a multiple of the cache line size.
    int64_t LDT = ((bkL + (numScalarPerCache - 1)) / numScalarPerCache) * numScalarPerCache;
    int64_t offsetBTemp = 0;
    for (int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      EIGEN_IF_CONSTEXPR(!isBRowMajor) {
        int64_t indA_i = isFWDSolve ? i : M - 1 - i;
        int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW);
        int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW * LDT - offsetBTemp;
        int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp;
        // Copy values from B to B_temp.
        copyBToRowMajor<Scalar, true, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
        // Triangular solve with a small block of A and long horizontal blocks of B (or B_temp if B col-major)
        triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
            &A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT);
        // Copy values from B_temp back to B. B_temp will be reused in gemm call below.
        copyBToRowMajor<Scalar, false, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);

        offsetBTemp += EIGEN_AVX_MAX_NUM_ROW * LDT;
      }
      else {
        int64_t ind = isFWDSolve ? i : M - 1 - i;
        triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
            &A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind * LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB);
      }
      if (i + EIGEN_AVX_MAX_NUM_ROW < M_) {
        /**
         * For the GEMM updates, we want "K" (K=i+8 in this case) to be large as soon as possible
         * to reuse the accumulators in GEMM as much as possible. So we only update 8xbK blocks of
         * B as follows:
         *
         *  A                B
         *  __
         * |__|__            |__|
         * |__|__|__         |__|
         * |__|__|__|__      |__|
         * |********|__|     |**|
         */
        EIGEN_IF_CONSTEXPR(isBRowMajor) {
          int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
          int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
          int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
          int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
          gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
              &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
              EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB);
        }
        else {
          if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) {
            /**
             * Similar idea as mentioned above, but here we are limited by the number of updated values of B
             * that can be stored (row-major) in B_temp.
             *
             * If there is not enough space to store the next batch of 8xbK of B in B_temp, we call GEMM
             * update and partially update the remaining old values of B which depends on the new values
             * of B stored in B_temp. These values are then no longer needed and can be overwritten.
             */
            int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
            int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
            int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
            int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
            gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
                &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k) * LDB,
                M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
            offsetBTemp = 0;
            gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
          } else {
            /**
             * If there is enough space in B_temp, we only update the next 8xbK values of B.
             */
            int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
            int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
            int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
            int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
            gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
                &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k) * LDB,
                EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
          }
        }
      }
    }
    // Handle M remainder..
    int64_t bM = M - M_;
    if (bM > 0) {
      if (M_ > 0) {
        EIGEN_IF_CONSTEXPR(isBRowMajor) {
          int64_t indA_i = isFWDSolve ? M_ : 0;
          int64_t indA_j = isFWDSolve ? 0 : bM;
          int64_t indB_i = isFWDSolve ? 0 : bM;
          int64_t indB_i2 = isFWDSolve ? M_ : 0;
          gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
              &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
              bM, bK, M_, LDA, LDB, LDB);
        }
        else {
          int64_t indA_i = isFWDSolve ? M_ : 0;
          int64_t indA_j = isFWDSolve ? gemmOff : bM;
          int64_t indB_i = isFWDSolve ? M_ : 0;
          int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
          gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
              &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k) * LDB, bM, bK,
              M_ - gemmOff, LDA, LDT, LDB);
        }
      }
      EIGEN_IF_CONSTEXPR(!isBRowMajor) {
        int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_;
        int64_t indB_i = isFWDSolve ? M_ : 0;
        int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
        copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
        triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
            &A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_1, bM, bkL, LDA, bkL);
        copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
      }
      else {
        int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
        triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
            &A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind * LDB, bM, bK, LDA, LDB);
      }
    }
  }
  EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
}
// Template specializations of trsmKernelL/R for float/double and inner strides of 1.
#if (EIGEN_USE_AVX512_TRSM_KERNELS)
#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized>
struct trsmKernelR;

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other,
                     Index otherIncr, Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other,
                     Index otherIncr, Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
      const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
      const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif  // (EIGEN_USE_AVX512_TRSM_R_KERNELS)

// These trsm kernels require temporary memory allocation
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized = true>
struct trsmKernelL;

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other,
                     Index otherIncr, Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other,
                     Index otherIncr, Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
      const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
      const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif  // EIGEN_USE_AVX512_TRSM_L_KERNELS
#endif  // EIGEN_USE_AVX512_TRSM_KERNELS

}  // namespace internal
}  // namespace Eigen

#endif  // EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
eigen-master/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc
0 → 100644
View file @
266d4fd9
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2022 Intel Corporation
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
template <bool isARowMajor = true>
EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
  EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
  else return i + j * LDA;
}
/**
* This namespace contains various classes used to generate compile-time unrolls which are
* used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested
* for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion
*
* Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop.
*
* for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++)
* for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ)
* func(startI,startJ) startJ = (startC)%(endJ)
* func(...)
*
* The 1-D loop can be unrolled recursively by using enable_if and defining an auxiliary function
* with a template parameter used as a counter.
*
* template <endI, endJ, counter>
* std::enable_if_t<(counter <= 0)> <---- tail case.
* aux_func {}
*
* template <endI, endJ, counter>
* std::enable_if_t<(counter > 0)> <---- actual for-loop
* aux_func {
* startC = endI*endJ - counter
* startI = (startC)/(endJ)
* startJ = (startC)%(endJ)
* func(startI, startJ)
* aux_func<endI, endJ, counter-1>()
* }
*
* Note: Additional wrapper functions are provided for aux_func which hides the counter template
* parameter since counter usually depends on endI, endJ, etc...
*
* Conventions:
* 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++))
*
* 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for
* handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW)
*/
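// A minimal standalone sketch of the pattern described above (for illustration only; it is not
// used by the unroll classes below, where the equivalents are static member functions):
//
//   template <int64_t endI, int64_t counter>
//   std::enable_if_t<(counter <= 0)> aux_func() {}          // tail case: recursion stops here
//
//   template <int64_t endI, int64_t counter>
//   std::enable_if_t<(counter > 0)> aux_func() {
//     constexpr int64_t startI = endI - counter;            // current "loop" index
//     func(startI);                                         // placeholder per-iteration body
//     aux_func<endI, counter - 1>();                        // next iteration
//   }
//
//   template <int64_t endI>
//   void run() { aux_func<endI, endI>(); }                  // wrapper hiding the counter
//
// Here func() stands in for the real per-iteration work (loadB, storeC, ...), and endI for the
// compile-time trip count.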
namespace unrolls {

template <int64_t N>
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
  EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
  else EIGEN_IF_CONSTEXPR(N == 8) { return 0xFF >> (8 - m); }
  else EIGEN_IF_CONSTEXPR(N == 4) { return 0x0F >> (4 - m); }
  return 0;
}
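// For example (fp32, N == 16): remMask<16>(3) == 0xFFFF >> 13 == 0x0007, i.e. a mask with the
// low 3 lanes set. This is what the masked ploadu/pstoreu calls below use to touch only the
// remaining m elements of a partial row or column.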
template <typename Packet>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8> &kernel);

template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8> &kernel) {
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
  __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);

  kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
  kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));

  T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
  T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
  T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
  T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
  T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
  T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
  T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
  T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
  T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
  T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
  T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
  T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
  T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
  T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
  T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
  T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);

  kernel.packet[0] = T0;
  kernel.packet[1] = T1;
  kernel.packet[2] = T2;
  kernel.packet[3] = T3;
  kernel.packet[4] = T4;
  kernel.packet[5] = T5;
  kernel.packet[6] = T6;
  kernel.packet[7] = T7;
}

template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8> &kernel) {
  ptranspose(kernel);
}
/***
* Unrolls for transposed C stores
*/
template
<
typename
Scalar
>
class
trans
{
public
:
using
vec
=
typename
std
::
conditional
<
std
::
is_same
<
Scalar
,
float
>::
value
,
vecFullFloat
,
vecFullDouble
>::
type
;
using
vecHalf
=
typename
std
::
conditional
<
std
::
is_same
<
Scalar
,
float
>::
value
,
vecHalfFloat
,
vecFullDouble
>::
type
;
static
constexpr
int64_t
PacketSize
=
packet_traits
<
Scalar
>::
size
;
/***********************************
* Auxiliary Functions for:
* - storeC
***********************************
*/
/**
* aux_storeC
*
* 1-D unroll
* for(startN = 0; startN < endN; startN++)
*
* (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC
*
**/
template
<
int64_t
endN
,
int64_t
counter
,
int64_t
unrollN
,
int64_t
packetIndexOffset
,
bool
remM
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
>
0
&&
endN
<=
PacketSize
)
>
aux_storeC
(
Scalar
*
C_arr
,
int64_t
LDC
,
PacketBlock
<
vec
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
zmm
,
int64_t
remM_
=
0
)
{
constexpr
int64_t
counterReverse
=
endN
-
counter
;
constexpr
int64_t
startN
=
counterReverse
;
EIGEN_IF_CONSTEXPR
(
startN
<
EIGEN_AVX_MAX_NUM_ROW
)
{
EIGEN_IF_CONSTEXPR
(
remM
)
{
pstoreu
<
Scalar
>
(
C_arr
+
LDC
*
startN
,
padd
(
ploadu
<
vecHalf
>
((
const
Scalar
*
)
C_arr
+
LDC
*
startN
,
remMask
<
EIGEN_AVX_MAX_NUM_ROW
>
(
remM_
)),
preinterpret
<
vecHalf
>
(
zmm
.
packet
[
packetIndexOffset
+
(
unrollN
/
PacketSize
)
*
startN
]),
remMask
<
EIGEN_AVX_MAX_NUM_ROW
>
(
remM_
)),
remMask
<
EIGEN_AVX_MAX_NUM_ROW
>
(
remM_
));
}
else
{
pstoreu
<
Scalar
>
(
C_arr
+
LDC
*
startN
,
padd
(
ploadu
<
vecHalf
>
((
const
Scalar
*
)
C_arr
+
LDC
*
startN
),
preinterpret
<
vecHalf
>
(
zmm
.
packet
[
packetIndexOffset
+
(
unrollN
/
PacketSize
)
*
startN
])));
}
}
else
{
// This block is only needed for fp32 case
// Reinterpret as __m512 for _mm512_shuffle_f32x4
vecFullFloat
zmm2vecFullFloat
=
preinterpret
<
vecFullFloat
>
(
zmm
.
packet
[
packetIndexOffset
+
(
unrollN
/
PacketSize
)
*
(
startN
-
EIGEN_AVX_MAX_NUM_ROW
)]);
// Swap lower and upper half of avx register.
zmm
.
packet
[
packetIndexOffset
+
(
unrollN
/
PacketSize
)
*
(
startN
-
EIGEN_AVX_MAX_NUM_ROW
)]
=
preinterpret
<
vec
>
(
_mm512_shuffle_f32x4
(
zmm2vecFullFloat
,
zmm2vecFullFloat
,
0b01001110
));
EIGEN_IF_CONSTEXPR
(
remM
)
{
pstoreu
<
Scalar
>
(
C_arr
+
LDC
*
startN
,
padd
(
ploadu
<
vecHalf
>
((
const
Scalar
*
)
C_arr
+
LDC
*
startN
,
remMask
<
EIGEN_AVX_MAX_NUM_ROW
>
(
remM_
)),
preinterpret
<
vecHalf
>
(
zmm
.
packet
[
packetIndexOffset
+
(
unrollN
/
PacketSize
)
*
(
startN
-
EIGEN_AVX_MAX_NUM_ROW
)])),
remMask
<
EIGEN_AVX_MAX_NUM_ROW
>
(
remM_
));
}
else
{
pstoreu
<
Scalar
>
(
C_arr
+
LDC
*
startN
,
padd
(
ploadu
<
vecHalf
>
((
const
Scalar
*
)
C_arr
+
LDC
*
startN
),
preinterpret
<
vecHalf
>
(
zmm
.
packet
[
packetIndexOffset
+
(
unrollN
/
PacketSize
)
*
(
startN
-
EIGEN_AVX_MAX_NUM_ROW
)])));
}
}
aux_storeC
<
endN
,
counter
-
1
,
unrollN
,
packetIndexOffset
,
remM
>
(
C_arr
,
LDC
,
zmm
,
remM_
);
}
template
<
int64_t
endN
,
int64_t
counter
,
int64_t
unrollN
,
int64_t
packetIndexOffset
,
bool
remM
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<!
(
counter
>
0
&&
endN
<=
PacketSize
)
>
aux_storeC
(
Scalar
*
C_arr
,
int64_t
LDC
,
PacketBlock
<
vec
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
zmm
,
int64_t
remM_
=
0
)
{
EIGEN_UNUSED_VARIABLE
(
C_arr
);
EIGEN_UNUSED_VARIABLE
(
LDC
);
EIGEN_UNUSED_VARIABLE
(
zmm
);
EIGEN_UNUSED_VARIABLE
(
remM_
);
}
template
<
int64_t
endN
,
int64_t
unrollN
,
int64_t
packetIndexOffset
,
bool
remM
>
static
EIGEN_ALWAYS_INLINE
void
storeC
(
Scalar
*
C_arr
,
int64_t
LDC
,
PacketBlock
<
vec
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
zmm
,
int64_t
remM_
=
0
)
{
aux_storeC
<
endN
,
endN
,
unrollN
,
packetIndexOffset
,
remM
>
(
C_arr
,
LDC
,
zmm
,
remM_
);
}
/**
* Transposes LxunrollN row major block of matrices stored `EIGEN_AVX_MAX_NUM_ACC` zmm registers to
* "unrollN"xL ymm registers to be stored col-major into C.
*
* For 8x48, the 8x48 block (row-major) is stored in zmm as follows:
*
* ```
* row0: zmm0 zmm1 zmm2
* row1: zmm3 zmm4 zmm5
* .
* .
* row7: zmm21 zmm22 zmm23
*
* For 8x32, the 8x32 block (row-major) is stored in zmm as follows:
*
* row0: zmm0 zmm1
* row1: zmm2 zmm3
* .
* .
* row7: zmm14 zmm15
* ```
*
* In general we will have {1,2,3} groups of avx registers each of size
* `EIGEN_AVX_MAX_NUM_ROW`. packetIndexOffset is used to select which "block" of
* avx registers are being transposed.
*/
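// For instance (assuming the 8x48 fp32 layout shown above): packetIndexOffset = 1 selects the
// middle register of each row, i.e. {zmm1, zmm4, ..., zmm22}, which holds columns 16..31 of the
// 8x48 block, and zmmStride = unrollN / PacketSize = 3 is the register distance between
// consecutive rows of that group. This is only an illustration of how the offset is used below.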
template
<
int64_t
unrollN
,
int64_t
packetIndexOffset
>
static
EIGEN_ALWAYS_INLINE
void
transpose
(
PacketBlock
<
vec
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
zmm
)
{
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
constexpr
int64_t
zmmStride
=
unrollN
/
PacketSize
;
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ROW
>
r
;
r
.
packet
[
0
]
=
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
0
];
r
.
packet
[
1
]
=
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
1
];
r
.
packet
[
2
]
=
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
2
];
r
.
packet
[
3
]
=
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
3
];
r
.
packet
[
4
]
=
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
4
];
r
.
packet
[
5
]
=
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
5
];
r
.
packet
[
6
]
=
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
6
];
r
.
packet
[
7
]
=
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
7
];
trans8x8blocks
(
r
);
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
0
]
=
r
.
packet
[
0
];
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
1
]
=
r
.
packet
[
1
];
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
2
]
=
r
.
packet
[
2
];
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
3
]
=
r
.
packet
[
3
];
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
4
]
=
r
.
packet
[
4
];
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
5
]
=
r
.
packet
[
5
];
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
6
]
=
r
.
packet
[
6
];
zmm
.
packet
[
packetIndexOffset
+
zmmStride
*
7
]
=
r
.
packet
[
7
];
}
};
/**
* Unrolls for copyBToRowMajor
*
* Idea:
* 1) Load a block of right-hand sides to registers (using loadB).
* 2) Convert the block from column-major to row-major (transposeLxL)
* 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false).
*
* We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are
* used as temps for transposing.
*
* Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
* For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm).
*/
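// Concrete sizes (illustrative, assuming EIGEN_AVX_MAX_NUM_ROW = 8 as noted elsewhere in this
// file): for fp64 PacketSize == 8, so an 8x8 block of B occupies 8 full registers and
// transposeLxL reduces to a plain 8x8 ptranspose. For fp32 PacketSize == 16 == 2*L, so each zmm
// register is viewed as two ymm halves (vecHalf) and an 8x16 block is handled as two 8x8
// transposes side by side.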
template
<
typename
Scalar
>
class
transB
{
public
:
using
vec
=
typename
std
::
conditional
<
std
::
is_same
<
Scalar
,
float
>::
value
,
vecFullFloat
,
vecFullDouble
>::
type
;
using
vecHalf
=
typename
std
::
conditional
<
std
::
is_same
<
Scalar
,
float
>::
value
,
vecHalfFloat
,
vecFullDouble
>::
type
;
static
constexpr
int64_t
PacketSize
=
packet_traits
<
Scalar
>::
size
;
/***********************************
* Auxiliary Functions for:
* - loadB
* - storeB
* - loadBBlock
* - storeBBlock
***********************************
*/
/**
* aux_loadB
*
* 1-D unroll
* for(startN = 0; startN < endN; startN++)
**/
template
<
int64_t
endN
,
int64_t
counter
,
int64_t
packetIndexOffset
,
bool
remM
,
int64_t
remN_
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
>
0
)
>
aux_loadB
(
Scalar
*
B_arr
,
int64_t
LDB
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
remM_
=
0
)
{
constexpr
int64_t
counterReverse
=
endN
-
counter
;
constexpr
int64_t
startN
=
counterReverse
;
EIGEN_IF_CONSTEXPR
(
remM
)
{
ymm
.
packet
[
packetIndexOffset
+
startN
]
=
ploadu
<
vecHalf
>
((
const
Scalar
*
)
&
B_arr
[
startN
*
LDB
],
remMask
<
EIGEN_AVX_MAX_NUM_ROW
>
(
remM_
));
}
else
{
EIGEN_IF_CONSTEXPR
(
remN_
==
0
)
{
ymm
.
packet
[
packetIndexOffset
+
startN
]
=
ploadu
<
vecHalf
>
((
const
Scalar
*
)
&
B_arr
[
startN
*
LDB
]);
}
else
ymm
.
packet
[
packetIndexOffset
+
startN
]
=
ploadu
<
vecHalf
>
((
const
Scalar
*
)
&
B_arr
[
startN
*
LDB
],
remMask
<
EIGEN_AVX_MAX_NUM_ROW
>
(
remN_
));
}
aux_loadB
<
endN
,
counter
-
1
,
packetIndexOffset
,
remM
,
remN_
>
(
B_arr
,
LDB
,
ymm
,
remM_
);
}
template
<
int64_t
endN
,
int64_t
counter
,
int64_t
packetIndexOffset
,
bool
remM
,
int64_t
remN_
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
<=
0
)
>
aux_loadB
(
Scalar
*
B_arr
,
int64_t
LDB
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
remM_
=
0
)
{
EIGEN_UNUSED_VARIABLE
(
B_arr
);
EIGEN_UNUSED_VARIABLE
(
LDB
);
EIGEN_UNUSED_VARIABLE
(
ymm
);
EIGEN_UNUSED_VARIABLE
(
remM_
);
}
/**
* aux_storeB
*
* 1-D unroll
* for(startN = 0; startN < endN; startN++)
**/
template
<
int64_t
endN
,
int64_t
counter
,
int64_t
packetIndexOffset
,
bool
remK
,
bool
remM
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
>
0
)
>
aux_storeB
(
Scalar
*
B_arr
,
int64_t
LDB
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
rem_
=
0
)
{
constexpr
int64_t
counterReverse
=
endN
-
counter
;
constexpr
int64_t
startN
=
counterReverse
;
EIGEN_IF_CONSTEXPR
(
remK
||
remM
)
{
pstoreu
<
Scalar
>
(
&
B_arr
[
startN
*
LDB
],
ymm
.
packet
[
packetIndexOffset
+
startN
],
remMask
<
EIGEN_AVX_MAX_NUM_ROW
>
(
rem_
));
}
else
{
pstoreu
<
Scalar
>
(
&
B_arr
[
startN
*
LDB
],
ymm
.
packet
[
packetIndexOffset
+
startN
]);
}
aux_storeB
<
endN
,
counter
-
1
,
packetIndexOffset
,
remK
,
remM
>
(
B_arr
,
LDB
,
ymm
,
rem_
);
}
template
<
int64_t
endN
,
int64_t
counter
,
int64_t
packetIndexOffset
,
bool
remK
,
bool
remM
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
<=
0
)
>
aux_storeB
(
Scalar
*
B_arr
,
int64_t
LDB
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
rem_
=
0
)
{
EIGEN_UNUSED_VARIABLE
(
B_arr
);
EIGEN_UNUSED_VARIABLE
(
LDB
);
EIGEN_UNUSED_VARIABLE
(
ymm
);
EIGEN_UNUSED_VARIABLE
(
rem_
);
}
/**
* aux_loadBBlock
*
* 1-D unroll
* for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
**/
template
<
int64_t
endN
,
int64_t
counter
,
bool
toTemp
,
bool
remM
,
int64_t
remN_
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
>
0
)
>
aux_loadBBlock
(
Scalar
*
B_arr
,
int64_t
LDB
,
Scalar
*
B_temp
,
int64_t
LDB_
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
remM_
=
0
)
{
constexpr
int64_t
counterReverse
=
endN
-
counter
;
constexpr
int64_t
startN
=
counterReverse
;
transB
::
template
loadB
<
EIGEN_AVX_MAX_NUM_ROW
,
startN
,
false
,
(
toTemp
?
0
:
remN_
)
>
(
&
B_temp
[
startN
],
LDB_
,
ymm
);
aux_loadBBlock
<
endN
,
counter
-
EIGEN_AVX_MAX_NUM_ROW
,
toTemp
,
remM
,
remN_
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
}
template
<
int64_t
endN
,
int64_t
counter
,
bool
toTemp
,
bool
remM
,
int64_t
remN_
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
<=
0
)
>
aux_loadBBlock
(
Scalar
*
B_arr
,
int64_t
LDB
,
Scalar
*
B_temp
,
int64_t
LDB_
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
remM_
=
0
)
{
EIGEN_UNUSED_VARIABLE
(
B_arr
);
EIGEN_UNUSED_VARIABLE
(
LDB
);
EIGEN_UNUSED_VARIABLE
(
B_temp
);
EIGEN_UNUSED_VARIABLE
(
LDB_
);
EIGEN_UNUSED_VARIABLE
(
ymm
);
EIGEN_UNUSED_VARIABLE
(
remM_
);
}
/**
* aux_storeBBlock
*
* 1-D unroll
* for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
**/
template
<
int64_t
endN
,
int64_t
counter
,
bool
toTemp
,
bool
remM
,
int64_t
remK_
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
>
0
)
>
aux_storeBBlock
(
Scalar
*
B_arr
,
int64_t
LDB
,
Scalar
*
B_temp
,
int64_t
LDB_
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
remM_
=
0
)
{
constexpr
int64_t
counterReverse
=
endN
-
counter
;
constexpr
int64_t
startN
=
counterReverse
;
EIGEN_IF_CONSTEXPR
(
toTemp
)
{
transB
::
template
storeB
<
EIGEN_AVX_MAX_NUM_ROW
,
startN
,
remK_
!=
0
,
false
>
(
&
B_temp
[
startN
],
LDB_
,
ymm
,
remK_
);
}
else
{
transB
::
template
storeB
<
std
::
min
(
EIGEN_AVX_MAX_NUM_ROW
,
endN
),
startN
,
false
,
remM
>
(
&
B_arr
[
0
+
startN
*
LDB
],
LDB
,
ymm
,
remM_
);
}
aux_storeBBlock
<
endN
,
counter
-
EIGEN_AVX_MAX_NUM_ROW
,
toTemp
,
remM
,
remK_
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
}
template
<
int64_t
endN
,
int64_t
counter
,
bool
toTemp
,
bool
remM
,
int64_t
remK_
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
<=
0
)
>
aux_storeBBlock
(
Scalar
*
B_arr
,
int64_t
LDB
,
Scalar
*
B_temp
,
int64_t
LDB_
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
remM_
=
0
)
{
EIGEN_UNUSED_VARIABLE
(
B_arr
);
EIGEN_UNUSED_VARIABLE
(
LDB
);
EIGEN_UNUSED_VARIABLE
(
B_temp
);
EIGEN_UNUSED_VARIABLE
(
LDB_
);
EIGEN_UNUSED_VARIABLE
(
ymm
);
EIGEN_UNUSED_VARIABLE
(
remM_
);
}
/********************************************************
* Wrappers for aux_XXXX to hide counter parameter
********************************************************/
template
<
int64_t
endN
,
int64_t
packetIndexOffset
,
bool
remM
,
int64_t
remN_
>
static
EIGEN_ALWAYS_INLINE
void
loadB
(
Scalar
*
B_arr
,
int64_t
LDB
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
remM_
=
0
)
{
aux_loadB
<
endN
,
endN
,
packetIndexOffset
,
remM
,
remN_
>
(
B_arr
,
LDB
,
ymm
,
remM_
);
}
template
<
int64_t
endN
,
int64_t
packetIndexOffset
,
bool
remK
,
bool
remM
>
static
EIGEN_ALWAYS_INLINE
void
storeB
(
Scalar
*
B_arr
,
int64_t
LDB
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
rem_
=
0
)
{
aux_storeB
<
endN
,
endN
,
packetIndexOffset
,
remK
,
remM
>
(
B_arr
,
LDB
,
ymm
,
rem_
);
}
template
<
int64_t
unrollN
,
bool
toTemp
,
bool
remM
,
int64_t
remN_
=
0
>
static
EIGEN_ALWAYS_INLINE
void
loadBBlock
(
Scalar
*
B_arr
,
int64_t
LDB
,
Scalar
*
B_temp
,
int64_t
LDB_
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
remM_
=
0
)
{
EIGEN_IF_CONSTEXPR
(
toTemp
)
{
transB
::
template
loadB
<
unrollN
,
0
,
remM
,
0
>
(
&
B_arr
[
0
],
LDB
,
ymm
,
remM_
);
}
else
{
aux_loadBBlock
<
unrollN
,
unrollN
,
toTemp
,
remM
,
remN_
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
}
}
template
<
int64_t
unrollN
,
bool
toTemp
,
bool
remM
,
int64_t
remK_
>
static
EIGEN_ALWAYS_INLINE
void
storeBBlock
(
Scalar
*
B_arr
,
int64_t
LDB
,
Scalar
*
B_temp
,
int64_t
LDB_
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
remM_
=
0
)
{
aux_storeBBlock
<
unrollN
,
unrollN
,
toTemp
,
remM
,
remK_
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
}
template
<
int64_t
packetIndexOffset
>
static
EIGEN_ALWAYS_INLINE
void
transposeLxL
(
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
)
{
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
PacketBlock
<
vecHalf
,
EIGEN_AVX_MAX_NUM_ROW
>
r
;
r
.
packet
[
0
]
=
ymm
.
packet
[
packetIndexOffset
+
0
];
r
.
packet
[
1
]
=
ymm
.
packet
[
packetIndexOffset
+
1
];
r
.
packet
[
2
]
=
ymm
.
packet
[
packetIndexOffset
+
2
];
r
.
packet
[
3
]
=
ymm
.
packet
[
packetIndexOffset
+
3
];
r
.
packet
[
4
]
=
ymm
.
packet
[
packetIndexOffset
+
4
];
r
.
packet
[
5
]
=
ymm
.
packet
[
packetIndexOffset
+
5
];
r
.
packet
[
6
]
=
ymm
.
packet
[
packetIndexOffset
+
6
];
r
.
packet
[
7
]
=
ymm
.
packet
[
packetIndexOffset
+
7
];
ptranspose
(
r
);
ymm
.
packet
[
packetIndexOffset
+
0
]
=
r
.
packet
[
0
];
ymm
.
packet
[
packetIndexOffset
+
1
]
=
r
.
packet
[
1
];
ymm
.
packet
[
packetIndexOffset
+
2
]
=
r
.
packet
[
2
];
ymm
.
packet
[
packetIndexOffset
+
3
]
=
r
.
packet
[
3
];
ymm
.
packet
[
packetIndexOffset
+
4
]
=
r
.
packet
[
4
];
ymm
.
packet
[
packetIndexOffset
+
5
]
=
r
.
packet
[
5
];
ymm
.
packet
[
packetIndexOffset
+
6
]
=
r
.
packet
[
6
];
ymm
.
packet
[
packetIndexOffset
+
7
]
=
r
.
packet
[
7
];
}
template
<
int64_t
unrollN
,
bool
toTemp
,
bool
remM
>
static
EIGEN_ALWAYS_INLINE
void
transB_kernel
(
Scalar
*
B_arr
,
int64_t
LDB
,
Scalar
*
B_temp
,
int64_t
LDB_
,
PacketBlock
<
vecHalf
,
EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
>
&
ymm
,
int64_t
remM_
=
0
)
{
constexpr
int64_t
U3
=
PacketSize
*
3
;
constexpr
int64_t
U2
=
PacketSize
*
2
;
constexpr
int64_t
U1
=
PacketSize
*
1
;
/**
* Unrolls needed for each case:
* - AVX512 fp32 48 32 16 8 4 2 1
* - AVX512 fp64 24 16 8 4 2 1
*
* For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up.
*/
EIGEN_IF_CONSTEXPR
(
unrollN
==
U3
)
{
// load LxU3 B col major, transpose LxU3 row major
constexpr
int64_t
maxUBlock
=
std
::
min
(
3
*
EIGEN_AVX_MAX_NUM_ROW
,
U3
);
transB
::
template
loadBBlock
<
maxUBlock
,
toTemp
,
remM
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
transB
::
template
transposeLxL
<
0
*
EIGEN_AVX_MAX_NUM_ROW
>
(
ymm
);
transB
::
template
transposeLxL
<
1
*
EIGEN_AVX_MAX_NUM_ROW
>
(
ymm
);
transB
::
template
transposeLxL
<
2
*
EIGEN_AVX_MAX_NUM_ROW
>
(
ymm
);
transB
::
template
storeBBlock
<
maxUBlock
,
toTemp
,
remM
,
0
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
EIGEN_IF_CONSTEXPR
(
maxUBlock
<
U3
)
{
transB
::
template
loadBBlock
<
maxUBlock
,
toTemp
,
remM
>
(
&
B_arr
[
maxUBlock
*
LDB
],
LDB
,
&
B_temp
[
maxUBlock
],
LDB_
,
ymm
,
remM_
);
transB
::
template
transposeLxL
<
0
*
EIGEN_AVX_MAX_NUM_ROW
>
(
ymm
);
transB
::
template
transposeLxL
<
1
*
EIGEN_AVX_MAX_NUM_ROW
>
(
ymm
);
transB
::
template
transposeLxL
<
2
*
EIGEN_AVX_MAX_NUM_ROW
>
(
ymm
);
transB
::
template
storeBBlock
<
maxUBlock
,
toTemp
,
remM
,
0
>
(
&
B_arr
[
maxUBlock
*
LDB
],
LDB
,
&
B_temp
[
maxUBlock
],
LDB_
,
ymm
,
remM_
);
}
}
else
EIGEN_IF_CONSTEXPR
(
unrollN
==
U2
)
{
// load LxU2 B col major, transpose LxU2 row major
constexpr
int64_t
maxUBlock
=
std
::
min
(
3
*
EIGEN_AVX_MAX_NUM_ROW
,
U2
);
transB
::
template
loadBBlock
<
maxUBlock
,
toTemp
,
remM
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
transB
::
template
transposeLxL
<
0
*
EIGEN_AVX_MAX_NUM_ROW
>
(
ymm
);
transB
::
template
transposeLxL
<
1
*
EIGEN_AVX_MAX_NUM_ROW
>
(
ymm
);
EIGEN_IF_CONSTEXPR
(
maxUBlock
<
U2
)
transB
::
template
transposeLxL
<
2
*
EIGEN_AVX_MAX_NUM_ROW
>
(
ymm
);
transB
::
template
storeBBlock
<
maxUBlock
,
toTemp
,
remM
,
0
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
EIGEN_IF_CONSTEXPR
(
maxUBlock
<
U2
)
{
transB
::
template
loadBBlock
<
EIGEN_AVX_MAX_NUM_ROW
,
toTemp
,
remM
>
(
&
B_arr
[
maxUBlock
*
LDB
],
LDB
,
&
B_temp
[
maxUBlock
],
LDB_
,
ymm
,
remM_
);
transB
::
template
transposeLxL
<
0
>
(
ymm
);
transB
::
template
storeBBlock
<
EIGEN_AVX_MAX_NUM_ROW
,
toTemp
,
remM
,
0
>
(
&
B_arr
[
maxUBlock
*
LDB
],
LDB
,
&
B_temp
[
maxUBlock
],
LDB_
,
ymm
,
remM_
);
}
}
else
EIGEN_IF_CONSTEXPR
(
unrollN
==
U1
)
{
// load LxU1 B col major, transpose LxU1 row major
transB
::
template
loadBBlock
<
U1
,
toTemp
,
remM
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
transB
::
template
transposeLxL
<
0
>
(
ymm
);
EIGEN_IF_CONSTEXPR
(
EIGEN_AVX_MAX_NUM_ROW
<
U1
)
{
transB
::
template
transposeLxL
<
1
*
EIGEN_AVX_MAX_NUM_ROW
>
(
ymm
);
}
transB
::
template
storeBBlock
<
U1
,
toTemp
,
remM
,
0
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
}
else
EIGEN_IF_CONSTEXPR
(
unrollN
==
8
&&
U1
>
8
)
{
// load Lx8 B col major, transpose Lx8 row major
transB
::
template
loadBBlock
<
8
,
toTemp
,
remM
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
transB
::
template
transposeLxL
<
0
>
(
ymm
);
transB
::
template
storeBBlock
<
8
,
toTemp
,
remM
,
8
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
}
else
EIGEN_IF_CONSTEXPR
(
unrollN
==
4
&&
U1
>
4
)
{
// load Lx4 B col major, transpose Lx4 row major
transB
::
template
loadBBlock
<
4
,
toTemp
,
remM
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
transB
::
template
transposeLxL
<
0
>
(
ymm
);
transB
::
template
storeBBlock
<
4
,
toTemp
,
remM
,
4
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
}
else
EIGEN_IF_CONSTEXPR
(
unrollN
==
2
)
{
// load Lx2 B col major, transpose Lx2 row major
transB
::
template
loadBBlock
<
2
,
toTemp
,
remM
,
2
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
transB
::
template
transposeLxL
<
0
>
(
ymm
);
transB
::
template
storeBBlock
<
2
,
toTemp
,
remM
,
2
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
}
else
EIGEN_IF_CONSTEXPR
(
unrollN
==
1
)
{
// load Lx1 B col major, transpose Lx1 row major
transB
::
template
loadBBlock
<
1
,
toTemp
,
remM
,
1
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
transB
::
template
transposeLxL
<
0
>
(
ymm
);
transB
::
template
storeBBlock
<
1
,
toTemp
,
remM
,
1
>
(
B_arr
,
LDB
,
B_temp
,
LDB_
,
ymm
,
remM_
);
}
}
};
/**
* Unrolls for triSolveKernel
*
* Idea:
* 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS).
* 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix)
* stored in AInPacket (using triSolveMicroKernel).
* 3) Store final results (in avx registers) back into memory (using storeRHS).
*
* RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most
* EIGEN_AVX_MAX_NUM_ROW registers.
*/
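// The micro kernel below is a register-blocked substitution. For a forward solve with a
// non-unit diagonal it follows the usual recurrence (illustrative scalar form, one RHS column
// shown; names here are placeholders):
//
//   x[0] = b[0] / A(0,0);
//   for (i = 1; i < M; ++i) {
//     for (j = 0; j < i; ++j) b[i] -= A(i,j) * x[j];   // updateRHS: pnmadd with broadcast A(i,j)
//     x[i] = b[i] / A(i,i);                            // divRHSByDiag: pmul by broadcast 1/A(i,i)
//   }
//
// with each b[i]/x[i] actually being one or more full packets of right-hand sides held in
// RHSInPacket, and the broadcast A values held in AInPacket.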
template
<
typename
Scalar
>
class
trsm
{
public
:
using
vec
=
typename
std
::
conditional
<
std
::
is_same
<
Scalar
,
float
>::
value
,
vecFullFloat
,
vecFullDouble
>::
type
;
static
constexpr
int64_t
PacketSize
=
packet_traits
<
Scalar
>::
size
;
/***********************************
* Auxiliary Functions for:
* - loadRHS
* - storeRHS
* - divRHSByDiag
* - updateRHS
* - triSolveMicroKernel
************************************/
/**
* aux_loadRHS
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startK = 0; startK < endK; startK++)
**/
template
<
bool
isFWDSolve
,
int64_t
endM
,
int64_t
endK
,
int64_t
counter
,
bool
krem
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
>
0
)
>
aux_loadRHS
(
Scalar
*
B_arr
,
int64_t
LDB
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ACC
>
&
RHSInPacket
,
int64_t
rem
=
0
)
{
constexpr
int64_t
counterReverse
=
endM
*
endK
-
counter
;
constexpr
int64_t
startM
=
counterReverse
/
(
endK
);
constexpr
int64_t
startK
=
counterReverse
%
endK
;
constexpr
int64_t
packetIndex
=
startM
*
endK
+
startK
;
constexpr
int64_t
startM_
=
isFWDSolve
?
startM
:
-
startM
;
const
int64_t
rhsIndex
=
(
startK
*
PacketSize
)
+
startM_
*
LDB
;
EIGEN_IF_CONSTEXPR
(
krem
)
{
RHSInPacket
.
packet
[
packetIndex
]
=
ploadu
<
vec
>
(
&
B_arr
[
rhsIndex
],
remMask
<
PacketSize
>
(
rem
));
}
else
{
RHSInPacket
.
packet
[
packetIndex
]
=
ploadu
<
vec
>
(
&
B_arr
[
rhsIndex
]);
}
aux_loadRHS
<
isFWDSolve
,
endM
,
endK
,
counter
-
1
,
krem
>
(
B_arr
,
LDB
,
RHSInPacket
,
rem
);
}
template
<
bool
isFWDSolve
,
int64_t
endM
,
int64_t
endK
,
int64_t
counter
,
bool
krem
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
<=
0
)
>
aux_loadRHS
(
Scalar
*
B_arr
,
int64_t
LDB
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ACC
>
&
RHSInPacket
,
int64_t
rem
=
0
)
{
EIGEN_UNUSED_VARIABLE
(
B_arr
);
EIGEN_UNUSED_VARIABLE
(
LDB
);
EIGEN_UNUSED_VARIABLE
(
RHSInPacket
);
EIGEN_UNUSED_VARIABLE
(
rem
);
}
/**
* aux_storeRHS
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startK = 0; startK < endK; startK++)
**/
template
<
bool
isFWDSolve
,
int64_t
endM
,
int64_t
endK
,
int64_t
counter
,
bool
krem
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
>
0
)
>
aux_storeRHS
(
Scalar
*
B_arr
,
int64_t
LDB
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ACC
>
&
RHSInPacket
,
int64_t
rem
=
0
)
{
constexpr
int64_t
counterReverse
=
endM
*
endK
-
counter
;
constexpr
int64_t
startM
=
counterReverse
/
(
endK
);
constexpr
int64_t
startK
=
counterReverse
%
endK
;
constexpr
int64_t
packetIndex
=
startM
*
endK
+
startK
;
constexpr
int64_t
startM_
=
isFWDSolve
?
startM
:
-
startM
;
const
int64_t
rhsIndex
=
(
startK
*
PacketSize
)
+
startM_
*
LDB
;
EIGEN_IF_CONSTEXPR
(
krem
)
{
pstoreu
<
Scalar
>
(
&
B_arr
[
rhsIndex
],
RHSInPacket
.
packet
[
packetIndex
],
remMask
<
PacketSize
>
(
rem
));
}
else
{
pstoreu
<
Scalar
>
(
&
B_arr
[
rhsIndex
],
RHSInPacket
.
packet
[
packetIndex
]);
}
aux_storeRHS
<
isFWDSolve
,
endM
,
endK
,
counter
-
1
,
krem
>
(
B_arr
,
LDB
,
RHSInPacket
,
rem
);
}
template
<
bool
isFWDSolve
,
int64_t
endM
,
int64_t
endK
,
int64_t
counter
,
bool
krem
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
<=
0
)
>
aux_storeRHS
(
Scalar
*
B_arr
,
int64_t
LDB
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ACC
>
&
RHSInPacket
,
int64_t
rem
=
0
)
{
EIGEN_UNUSED_VARIABLE
(
B_arr
);
EIGEN_UNUSED_VARIABLE
(
LDB
);
EIGEN_UNUSED_VARIABLE
(
RHSInPacket
);
EIGEN_UNUSED_VARIABLE
(
rem
);
}
/**
* aux_divRHSByDiag
*
* currM may be -1, (currM >=0) in enable_if checks for this
*
* 1-D unroll
* for(startK = 0; startK < endK; startK++)
**/
template
<
int64_t
currM
,
int64_t
endK
,
int64_t
counter
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
>
0
&&
currM
>=
0
)
>
aux_divRHSByDiag
(
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ACC
>
&
RHSInPacket
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ROW
>
&
AInPacket
)
{
constexpr
int64_t
counterReverse
=
endK
-
counter
;
constexpr
int64_t
startK
=
counterReverse
;
constexpr
int64_t
packetIndex
=
currM
*
endK
+
startK
;
RHSInPacket
.
packet
[
packetIndex
]
=
pmul
(
AInPacket
.
packet
[
currM
],
RHSInPacket
.
packet
[
packetIndex
]);
aux_divRHSByDiag
<
currM
,
endK
,
counter
-
1
>
(
RHSInPacket
,
AInPacket
);
}
template
<
int64_t
currM
,
int64_t
endK
,
int64_t
counter
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<!
(
counter
>
0
&&
currM
>=
0
)
>
aux_divRHSByDiag
(
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ACC
>
&
RHSInPacket
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ROW
>
&
AInPacket
)
{
EIGEN_UNUSED_VARIABLE
(
RHSInPacket
);
EIGEN_UNUSED_VARIABLE
(
AInPacket
);
}
/**
* aux_updateRHS
*
* 2-D unroll
* for(startM = initM; startM < endM; startM++)
* for(startK = 0; startK < endK; startK++)
**/
template
<
bool
isARowMajor
,
bool
isFWDSolve
,
bool
isUnitDiag
,
int64_t
initM
,
int64_t
endM
,
int64_t
endK
,
int64_t
counter
,
int64_t
currentM
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
>
0
)
>
aux_updateRHS
(
Scalar
*
A_arr
,
int64_t
LDA
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ACC
>
&
RHSInPacket
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ROW
>
&
AInPacket
)
{
constexpr
int64_t
counterReverse
=
(
endM
-
initM
)
*
endK
-
counter
;
constexpr
int64_t
startM
=
initM
+
counterReverse
/
(
endK
);
constexpr
int64_t
startK
=
counterReverse
%
endK
;
// For each row of A, first update all corresponding RHS
constexpr
int64_t
packetIndex
=
startM
*
endK
+
startK
;
EIGEN_IF_CONSTEXPR
(
currentM
>
0
)
{
RHSInPacket
.
packet
[
packetIndex
]
=
pnmadd
(
AInPacket
.
packet
[
startM
],
RHSInPacket
.
packet
[(
currentM
-
1
)
*
endK
+
startK
],
RHSInPacket
.
packet
[
packetIndex
]);
}
EIGEN_IF_CONSTEXPR
(
startK
==
endK
-
1
)
{
// Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}.
EIGEN_IF_CONSTEXPR
(
startM
==
currentM
&&
!
isUnitDiag
)
{
// If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM].
// This will be used in divRHSByDiag
EIGEN_IF_CONSTEXPR
(
isFWDSolve
)
AInPacket
.
packet
[
currentM
]
=
pset1
<
vec
>
(
Scalar
(
1
)
/
A_arr
[
idA
<
isARowMajor
>
(
currentM
,
currentM
,
LDA
)]);
else
AInPacket
.
packet
[
currentM
]
=
pset1
<
vec
>
(
Scalar
(
1
)
/
A_arr
[
idA
<
isARowMajor
>
(
-
currentM
,
-
currentM
,
LDA
)]);
}
else
{
// Broadcast next off diagonal element of A
EIGEN_IF_CONSTEXPR
(
isFWDSolve
)
AInPacket
.
packet
[
startM
]
=
pset1
<
vec
>
(
A_arr
[
idA
<
isARowMajor
>
(
startM
,
currentM
,
LDA
)]);
else
AInPacket
.
packet
[
startM
]
=
pset1
<
vec
>
(
A_arr
[
idA
<
isARowMajor
>
(
-
startM
,
-
currentM
,
LDA
)]);
}
}
aux_updateRHS
<
isARowMajor
,
isFWDSolve
,
isUnitDiag
,
initM
,
endM
,
endK
,
counter
-
1
,
currentM
>
(
A_arr
,
LDA
,
RHSInPacket
,
AInPacket
);
}
template
<
bool
isARowMajor
,
bool
isFWDSolve
,
bool
isUnitDiag
,
int64_t
initM
,
int64_t
endM
,
int64_t
endK
,
int64_t
counter
,
int64_t
currentM
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
<=
0
)
>
aux_updateRHS
(
Scalar
*
A_arr
,
int64_t
LDA
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ACC
>
&
RHSInPacket
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ROW
>
&
AInPacket
)
{
EIGEN_UNUSED_VARIABLE
(
A_arr
);
EIGEN_UNUSED_VARIABLE
(
LDA
);
EIGEN_UNUSED_VARIABLE
(
RHSInPacket
);
EIGEN_UNUSED_VARIABLE
(
AInPacket
);
}
/**
* aux_triSolverMicroKernel
*
* 1-D unroll
* for(startM = 0; startM < endM; startM++)
**/
template
<
bool
isARowMajor
,
bool
isFWDSolve
,
bool
isUnitDiag
,
int64_t
endM
,
int64_t
counter
,
int64_t
numK
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
>
0
)
>
aux_triSolveMicroKernel
(
Scalar
*
A_arr
,
int64_t
LDA
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ACC
>
&
RHSInPacket
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ROW
>
&
AInPacket
)
{
constexpr
int64_t
counterReverse
=
endM
-
counter
;
constexpr
int64_t
startM
=
counterReverse
;
constexpr
int64_t
currentM
=
startM
;
// Divides the right-hand side in row startM by the diagonal value of A
// broadcasted to AInPacket.packet[startM-1] in the previous iteration.
//
// Without "if constexpr" the compiler instantiates the case <-1, numK>
// this is handled with enable_if to prevent out-of-bound warnings
// from the compiler
EIGEN_IF_CONSTEXPR
(
!
isUnitDiag
&&
startM
>
0
)
trsm
::
template
divRHSByDiag
<
startM
-
1
,
numK
>
(
RHSInPacket
,
AInPacket
);
// After division, the rhs corresponding to subsequent rows of A can be partially updated
// We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
// to be used in the next iteration.
trsm
::
template
updateRHS
<
isARowMajor
,
isFWDSolve
,
isUnitDiag
,
startM
,
endM
,
numK
,
currentM
>
(
A_arr
,
LDA
,
RHSInPacket
,
AInPacket
);
// Handle division for the RHS corresponding to the final row of A.
EIGEN_IF_CONSTEXPR
(
!
isUnitDiag
&&
startM
==
endM
-
1
)
trsm
::
template
divRHSByDiag
<
startM
,
numK
>
(
RHSInPacket
,
AInPacket
);
aux_triSolveMicroKernel
<
isARowMajor
,
isFWDSolve
,
isUnitDiag
,
endM
,
counter
-
1
,
numK
>
(
A_arr
,
LDA
,
RHSInPacket
,
AInPacket
);
}
template
<
bool
isARowMajor
,
bool
isFWDSolve
,
bool
isUnitDiag
,
int64_t
endM
,
int64_t
counter
,
int64_t
numK
>
static
EIGEN_ALWAYS_INLINE
std
::
enable_if_t
<
(
counter
<=
0
)
>
aux_triSolveMicroKernel
(
Scalar
*
A_arr
,
int64_t
LDA
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ACC
>
&
RHSInPacket
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ROW
>
&
AInPacket
)
{
EIGEN_UNUSED_VARIABLE
(
A_arr
);
EIGEN_UNUSED_VARIABLE
(
LDA
);
EIGEN_UNUSED_VARIABLE
(
RHSInPacket
);
EIGEN_UNUSED_VARIABLE
(
AInPacket
);
}
/********************************************************
* Wrappers for aux_XXXX to hide counter parameter
********************************************************/
/**
* Load endMxendK block of B to RHSInPacket
* Masked loads are used for cases where endK is not a multiple of PacketSize
*/
template
<
bool
isFWDSolve
,
int64_t
endM
,
int64_t
endK
,
bool
krem
=
false
>
static
EIGEN_ALWAYS_INLINE
void
loadRHS
(
Scalar
*
B_arr
,
int64_t
LDB
,
PacketBlock
<
vec
,
EIGEN_AVX_MAX_NUM_ACC
>
&
RHSInPacket
,
int64_t
rem
=
0
)
{
aux_loadRHS
<
isFWDSolve
,
endM
,
endK
,
endM
*
endK
,
krem
>
(
B_arr
,
LDB
,
RHSInPacket
,
rem
);
}
/**
* Load endMxendK block of B to RHSInPacket
* Masked loads are used for cases where endK is not a multiple of PacketSize
*/
  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
                                           PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }
/**
* Only used if Triangular matrix has non-unit diagonal values
*/
  template <int64_t currM, int64_t endK>
  static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                               PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
  }
/**
* Update right-hand sides (stored in avx registers)
* Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket.
**/
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
            int64_t currentM>
  static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }
/**
* endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW
 * numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= numK <= 3.
 * isFWDSolve: true => forward substitution, false => backwards substitution
 * isUnitDiag: true => triangular matrix has unit diagonal.
 * A usage sketch for these wrappers follows the class definition below.
*/
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
  static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    static_assert(numK >= 1 && numK <= 3, "numK out of range");
    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
  }
};
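
// Illustrative usage sketch: how the public wrappers above (loadRHS, triSolveMicroKernel,
// storeRHS) are typically chained for one fully populated fp32 block. This is a sketch, not
// upstream code; the pointers B_block/A_block and the leading dimensions LDB/LDA are
// hypothetical placeholders, and trsm<float> as the instantiation is assumed from context.
//
//   using kernel = trsm<float>;
//   PacketBlock<vecFullFloat, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
//   PacketBlock<vecFullFloat, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
//   // Solve an 8-row block with one register of right-hand sides per row (numK = 1),
//   // forward substitution, non-unit diagonal, row-major A:
//   kernel::loadRHS</*isFWDSolve=*/true, /*endM=*/8, /*endK=*/1>(B_block, LDB, RHSInPacket);
//   kernel::triSolveMicroKernel</*isARowMajor=*/true, /*isFWDSolve=*/true, /*isUnitDiag=*/false,
//                               /*endM=*/8, /*numK=*/1>(A_block, LDA, RHSInPacket, AInPacket);
//   kernel::storeRHS</*isFWDSolve=*/true, /*endM=*/8, /*endK=*/1>(B_block, LDB, RHSInPacket);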
/**
* Unrolls for gemm kernel
*
* isAdd: true => C += A*B, false => C -= A*B
*/
template <typename Scalar, bool isAdd>
class gemm {
 public:
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
/***********************************
* Auxiliary Functions for:
* - setzero
* - updateC
* - storeC
* - startLoadB
 * - startBCastA
 * - loadB
 * - microKernel
************************************/
/**
* aux_setzero
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
**/
  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]);
    aux_setzero<endM, endN, counter - 1>(zmm);
  }

  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(zmm);
  }
/**
* aux_updateC
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
**/
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
             zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
    else zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);

    aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
/**
* aux_storeC
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
**/
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
                    remMask<PacketSize>(rem_));
    else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);

    aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
/**
* aux_startLoadB
*
* 1-D unroll
* for(startL = 0; startL < endL; startL++)
**/
  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endL - counter;
    constexpr int64_t startL = counterReverse;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
    else zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);

    aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
  }

  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
/**
* aux_startBCastA
*
* 1-D unroll
* for(startB = 0; startB < endB; startB++)
**/
  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endB - counter;
    constexpr int64_t startB = counterReverse;

    zmm.packet[unrollM * unrollN + numLoad + startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0, LDA)]);

    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter - 1, numLoad>(A_t, LDA, zmm);
  }

  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
  }
/**
* aux_loadB
* currK: current K
*
* 1-D unroll
* for(startM = 0; startM < endM; startM++)
**/
  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    if ((numLoad / endM + currK < unrollK)) {
      constexpr int64_t counterReverse = endM - counter;
      constexpr int64_t startM = counterReverse;

      EIGEN_IF_CONSTEXPR(rem) {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask<PacketSize>(rem_));
      }
      else {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]);
      }

      aux_loadB<endM, counter - 1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
  }

  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
/**
* aux_microKernel
*
* 3-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
* for(startK = 0; startK < endK; startK++)
**/
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN * endK - counter;
    constexpr int startK = counterReverse / (endM * endN);
    constexpr int startN = (counterReverse / (endM)) % endN;
    constexpr int startM = counterReverse % endM;

    EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
      gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
      gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
    }

    {
      // Interleave FMA and Bcast
      EIGEN_IF_CONSTEXPR(isAdd) {
        zmm.packet[startN * endM + startM] =
            pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                  zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
      }
      else {
        zmm.packet[startN * endM + startM] =
            pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                   zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
      }
      // Bcast
      EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
        zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] =
            pload1<vec>(&A_t[idA<isARowMajor>((numBCast + startN + startK * endN) % endN,
                                              (numBCast + startN + startK * endN) / endN, LDA)]);
      }
    }

    // We have updated all accumulators, time to load next set of B's
    EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) {
      gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
    aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
  }

  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }
/********************************************************
* Wrappers for aux_XXXX to hide counter parameter
********************************************************/
  template <int64_t endM, int64_t endN>
  static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_setzero<endM, endN, endM * endN>(zmm);
  }
/**
* Ideally the compiler folds these into vaddp{s,d} with an embedded memory load.
*/
  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
                                          PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                          int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }

  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }
/**
* Use numLoad registers for loading B at start of microKernel
*/
  template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
  static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
                                             PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                             int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
  }
/**
* Use numBCast registers for broadcasting A at start of microKernel
*/
  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
  }
/**
* Loads next set of B into vector registers between each K unroll.
*/
  template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
                                        PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                        int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
  }
/**
* Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute C -= A*B.
* A matrix can be row/col-major. B matrix is assumed row-major.
*
* isARowMajor: is A row major
 * endM: Number of registers per row
* endN: Number of rows
* endK: Loop unroll for K.
* numLoad: Number of registers for loading B.
* numBCast: Number of registers for broadcasting A.
*
* Ex: microkernel<isARowMajor,0,3,0,4,0,4,6,2>: 8x48 unroll (24 accumulators), k unrolled 4 times,
 * 6 registers for loading B, 2 for broadcasting A.
*
* Note: Ideally the microkernel should not have any register spilling.
* The avx instruction counts should be:
* - endK*endN vbroadcasts{s,d}
* - endK*endM vmovup{s,d}
* - endK*endN*endM FMAs
*
* From testing, there are no register spills with clang. There are register spills with GNU, which
 * causes a performance hit.
 *
 * A usage sketch for this microkernel follows the class definition below.
*/
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
            bool rem = false>
  static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                              int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
                                                                                               rem_);
  }
};
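
// Illustrative usage sketch, mirroring the 8x48 fp32 example from the comment above: how the
// wrappers of this class are typically chained for one C tile. This is a sketch, not upstream
// code; the pointers C_tile/B_pack/A_pack and the leading dimensions are hypothetical
// placeholders, and the pointer stepping between K-slices is omitted since it depends on the
// packing used by the caller.
//
//   using kern = gemm<float, /*isAdd=*/true>;                 // accumulate C += A*B
//   PacketBlock<vecFullFloat, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
//   kern::setzero</*endM=*/3, /*endN=*/8>(zmm);               // 24 accumulators for an 8x48 tile
//   // for each K-slice of the packed operands:
//   kern::microKernel</*isARowMajor=*/true, /*endM=*/3, /*endN=*/8, /*endK=*/4,
//                     /*numLoad=*/6, /*numBCast=*/2>(B_pack, A_pack, LDB, LDA, zmm);
//   kern::updateC</*endM=*/3, /*endN=*/8>(C_tile, LDC, zmm);  // fold the existing C values in
//   kern::storeC</*endM=*/3, /*endN=*/8>(C_tile, LDC, zmm);   // write the tile back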
}  // namespace unrolls
#endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
eigen-master/Eigen/src/Core/arch/AVX512/TypeCasting.h
0 → 100644
View file @ 266d4fd9
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_TYPE_CASTING_AVX512_H
#define EIGEN_TYPE_CASTING_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
template <>
struct type_casting_traits<float, bool> : vectorized_type_casting_traits<float, bool> {};
template <>
struct type_casting_traits<bool, float> : vectorized_type_casting_traits<bool, float> {};

template <>
struct type_casting_traits<float, int> : vectorized_type_casting_traits<float, int> {};
template <>
struct type_casting_traits<int, float> : vectorized_type_casting_traits<int, float> {};

template <>
struct type_casting_traits<float, double> : vectorized_type_casting_traits<float, double> {};
template <>
struct type_casting_traits<double, float> : vectorized_type_casting_traits<double, float> {};

template <>
struct type_casting_traits<double, int> : vectorized_type_casting_traits<double, int> {};
template <>
struct type_casting_traits<int, double> : vectorized_type_casting_traits<int, double> {};

template <>
struct type_casting_traits<double, int64_t> : vectorized_type_casting_traits<double, int64_t> {};
template <>
struct type_casting_traits<int64_t, double> : vectorized_type_casting_traits<int64_t, double> {};

template <>
struct type_casting_traits<half, float> : vectorized_type_casting_traits<half, float> {};
template <>
struct type_casting_traits<float, half> : vectorized_type_casting_traits<float, half> {};

template <>
struct type_casting_traits<bfloat16, float> : vectorized_type_casting_traits<bfloat16, float> {};
template <>
struct type_casting_traits<float, bfloat16> : vectorized_type_casting_traits<float, bfloat16> {};
template <>
EIGEN_STRONG_INLINE Packet16b pcast<Packet16f, Packet16b>(const Packet16f& a) {
  __mmask16 mask = _mm512_cmpneq_ps_mask(a, pzero(a));
  return _mm512_maskz_cvtepi32_epi8(mask, _mm512_set1_epi32(1));
}

template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16b, Packet16f>(const Packet16b& a) {
  return _mm512_cvtepi32_ps(_mm512_and_si512(_mm512_cvtepi8_epi32(a), _mm512_set1_epi32(1)));
}

template <>
EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
  return _mm512_cvttps_epi32(a);
}

template <>
EIGEN_STRONG_INLINE Packet8d pcast<Packet16f, Packet8d>(const Packet16f& a) {
  return _mm512_cvtps_pd(_mm512_castps512_ps256(a));
}

template <>
EIGEN_STRONG_INLINE Packet8d pcast<Packet8f, Packet8d>(const Packet8f& a) {
  return _mm512_cvtps_pd(a);
}

template <>
EIGEN_STRONG_INLINE Packet8l pcast<Packet8d, Packet8l>(const Packet8d& a) {
#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
  return _mm512_cvttpd_epi64(a);
#else
  constexpr int kTotalBits = sizeof(double) * CHAR_BIT, kMantissaBits = std::numeric_limits<double>::digits - 1,
                kExponentBits = kTotalBits - kMantissaBits - 1, kBias = (1 << (kExponentBits - 1)) - 1;

  const __m512i cst_one = _mm512_set1_epi64(1);
  const __m512i cst_total_bits = _mm512_set1_epi64(kTotalBits);
  const __m512i cst_bias = _mm512_set1_epi64(kBias);

  __m512i a_bits = _mm512_castpd_si512(a);
  // shift left by 1 to clear the sign bit, and shift right by kMantissaBits + 1 to recover biased exponent
  __m512i biased_e = _mm512_srli_epi64(_mm512_slli_epi64(a_bits, 1), kMantissaBits + 1);
  __m512i e = _mm512_sub_epi64(biased_e, cst_bias);
  // shift to the left by kExponentBits + 1 to clear the sign and exponent bits
  __m512i shifted_mantissa = _mm512_slli_epi64(a_bits, kExponentBits + 1);
  // shift to the right by kTotalBits - e to convert the significand to an integer
  __m512i result_significand = _mm512_srlv_epi64(shifted_mantissa, _mm512_sub_epi64(cst_total_bits, e));
  // add the implied bit
  __m512i result_exponent = _mm512_sllv_epi64(cst_one, e);
  // e <= 0 is interpreted as a large positive shift (2's complement), which also conveniently results in zero
  __m512i result = _mm512_add_epi64(result_significand, result_exponent);
  // handle negative arguments
  __mmask8 sign_mask = _mm512_cmplt_epi64_mask(a_bits, _mm512_setzero_si512());
  result = _mm512_mask_sub_epi64(result, sign_mask, _mm512_setzero_si512(), result);
  return result;
#endif
}
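
// Illustrative scalar reference for the bit-twiddling fallback above. This helper is a sketch
// added for exposition only (Eigen does not define it); its name is hypothetical and it assumes
// numext::bit_cast and std::numeric_limits are visible here, as they are elsewhere in these
// headers. It mirrors the SIMD steps for one value; inputs whose truncation does not fit in
// int64_t are saturated in this sketch, while the SIMD path leaves them unspecified.
inline int64_t pcast_double_to_int64_scalar_reference(double x) {
  constexpr int kTotalBits = 64, kMantissaBits = 52, kExponentBits = 11, kBias = 1023;
  const uint64_t a_bits = numext::bit_cast<uint64_t>(x);
  // Clear the sign bit with a left shift, then shift down to isolate the biased exponent.
  const int64_t e = static_cast<int64_t>((a_bits << 1) >> (kMantissaBits + 1)) - kBias;
  if (e < 0) return 0;  // |x| < 1 truncates to zero (the SIMD variable shifts also yield zero here).
  if (e > 62) return (a_bits >> 63) ? std::numeric_limits<int64_t>::min() : std::numeric_limits<int64_t>::max();
  // Move the mantissa to the top of the word, then shift it down so only the integer part of the
  // significand remains; fractional bits fall off the bottom, which is truncation toward zero.
  const uint64_t shifted_mantissa = a_bits << (kExponentBits + 1);
  const uint64_t significand = (e == 0) ? 0 : (shifted_mantissa >> (kTotalBits - e));
  const uint64_t implied_bit = uint64_t(1) << e;  // the implicit leading 1 contributes 2^e
  const uint64_t magnitude = significand + implied_bit;
  // Negative input: negate the magnitude (truncation is toward zero in both directions).
  return (a_bits >> 63) ? -static_cast<int64_t>(magnitude) : static_cast<int64_t>(magnitude);
}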
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
  return _mm512_cvtepi32_ps(a);
}

template <>
EIGEN_STRONG_INLINE Packet8d pcast<Packet16i, Packet8d>(const Packet16i& a) {
  return _mm512_cvtepi32_pd(_mm512_castsi512_si256(a));
}

template <>
EIGEN_STRONG_INLINE Packet8d pcast<Packet8i, Packet8d>(const Packet8i& a) {
  return _mm512_cvtepi32_pd(a);
}

template <>
EIGEN_STRONG_INLINE Packet8d pcast<Packet8l, Packet8d>(const Packet8l& a) {
#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
  return _mm512_cvtepi64_pd(a);
#else
  EIGEN_ALIGN64 int64_t aux[8];
  pstore(aux, a);
  return _mm512_set_pd(static_cast<double>(aux[7]), static_cast<double>(aux[6]), static_cast<double>(aux[5]),
                       static_cast<double>(aux[4]), static_cast<double>(aux[3]), static_cast<double>(aux[2]),
                       static_cast<double>(aux[1]), static_cast<double>(aux[0]));
#endif
}

template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet8d, Packet16f>(const Packet8d& a, const Packet8d& b) {
  return cat256(_mm512_cvtpd_ps(a), _mm512_cvtpd_ps(b));
}

template <>
EIGEN_STRONG_INLINE Packet16i pcast<Packet8d, Packet16i>(const Packet8d& a, const Packet8d& b) {
  return cat256i(_mm512_cvttpd_epi32(a), _mm512_cvttpd_epi32(b));
}

template <>
EIGEN_STRONG_INLINE Packet8i pcast<Packet8d, Packet8i>(const Packet8d& a) {
  return _mm512_cvtpd_epi32(a);
}

template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet8d, Packet8f>(const Packet8d& a) {
  return _mm512_cvtpd_ps(a);
}

template <>
EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
  return _mm512_castps_si512(a);
}

template <>
EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(const Packet16i& a) {
  return _mm512_castsi512_ps(a);
}

template <>
EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet16f>(const Packet16f& a) {
  return _mm512_castps_pd(a);
}

template <>
EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8l>(const Packet8l& a) {
  return _mm512_castsi512_pd(a);
}

template <>
EIGEN_STRONG_INLINE Packet8l preinterpret<Packet8l, Packet8d>(const Packet8d& a) {
  return _mm512_castpd_si512(a);
}

template <>
EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8d>(const Packet8d& a) {
  return _mm512_castpd_ps(a);
}

template <>
EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const Packet16f& a) {
  return _mm512_castps512_ps256(a);
}

template <>
EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet16f>(const Packet16f& a) {
  return _mm512_castps512_ps128(a);
}

template <>
EIGEN_STRONG_INLINE Packet4d preinterpret<Packet4d, Packet8d>(const Packet8d& a) {
  return _mm512_castpd512_pd256(a);
}

template <>
EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet8d>(const Packet8d& a) {
  return _mm512_castpd512_pd128(a);
}

template <>
EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8f>(const Packet8f& a) {
  return _mm512_castps256_ps512(a);
}

template <>
EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet4f>(const Packet4f& a) {
  return _mm512_castps128_ps512(a);
}

template <>
EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet4d>(const Packet4d& a) {
  return _mm512_castpd256_pd512(a);
}

template <>
EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet2d>(const Packet2d& a) {
  return _mm512_castpd128_pd512(a);
}

template <>
EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i, Packet16i>(const Packet16i& a) {
  return _mm512_castsi512_si256(a);
}

template <>
EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet16i>(const Packet16i& a) {
  return _mm512_castsi512_si128(a);
}

#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet16h>(const Packet16h& a) {
  return _mm256_castsi256_si128(a);
}

template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
  return half2float(a);
}

template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
  return float2half(a);
}
#endif
template <>
EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
  return _mm256_castsi256_si128(a);
}

template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
  return Bf16ToF32(a);
}

template <>
EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Packet16f& a) {
  return F32ToBf16(a);
}
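
// Illustrative sketch of how the specializations above are reached through the generic packet
// API. The casts themselves are the ones defined in this file; the surrounding snippet (the
// float buffer `data`) is a hypothetical example, not Eigen code.
//
//   const float* data = ...;                                  // at least 16 floats
//   Packet16f pf = ploadu<Packet16f>(data);                   // 16 floats
//   Packet16i pi = pcast<Packet16f, Packet16i>(pf);           // value conversion: _mm512_cvttps_epi32 above
//   Packet16f pb = preinterpret<Packet16f, Packet16i>(pi);    // bit reinterpretation: no value conversion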
}  // end namespace internal
}  // end namespace Eigen
#endif // EIGEN_TYPE_CASTING_AVX512_H
eigen-master/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h
0 → 100644
View file @ 266d4fd9
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 The Eigen Authors.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_TYPE_CASTING_FP16_AVX512_H
#define EIGEN_TYPE_CASTING_FP16_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
template <>
EIGEN_STRONG_INLINE Packet32s preinterpret<Packet32s, Packet32h>(const Packet32h& a) {
  return _mm512_castph_si512(a);
}

template <>
EIGEN_STRONG_INLINE Packet16s preinterpret<Packet16s, Packet16h>(const Packet16h& a) {
  return _mm256_castph_si256(a);
}

template <>
EIGEN_STRONG_INLINE Packet8s preinterpret<Packet8s, Packet8h>(const Packet8h& a) {
  return _mm_castph_si128(a);
}

template <>
EIGEN_STRONG_INLINE Packet32h preinterpret<Packet32h, Packet32s>(const Packet32s& a) {
  return _mm512_castsi512_ph(a);
}

template <>
EIGEN_STRONG_INLINE Packet16h preinterpret<Packet16h, Packet16s>(const Packet16s& a) {
  return _mm256_castsi256_ph(a);
}

template <>
EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet8s>(const Packet8s& a) {
  return _mm_castsi128_ph(a);
}

template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
  return half2float(a);
}

template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
  return half2float(a);
}

template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
  return float2half(a);
}

template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
  return float2half(a);
}

template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
  // Discard second-half of input.
  Packet16h low = _mm256_castpd_ph(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
  return _mm512_cvtxph_ps(low);
}

template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet16h, Packet8f>(const Packet16h& a) {
  // Discard second-half of input.
  Packet8h low = _mm_castps_ph(_mm256_extractf32x4_ps(_mm256_castph_ps(a), 0));
  return _mm256_cvtxph_ps(low);
}

template <>
EIGEN_STRONG_INLINE Packet4f pcast<Packet8h, Packet4f>(const Packet8h& a) {
  Packet8f full = _mm256_cvtxph_ps(a);
  // Discard second-half of input.
  return _mm256_extractf32x4_ps(full, 0);
}

template <>
EIGEN_STRONG_INLINE Packet32h pcast<Packet16f, Packet32h>(const Packet16f& a, const Packet16f& b) {
  __m512 result = _mm512_castsi512_ps(_mm512_castsi256_si512(_mm256_castph_si256(_mm512_cvtxps_ph(a))));
  result = _mm512_insertf32x8(result, _mm256_castph_ps(_mm512_cvtxps_ph(b)), 1);
  return _mm512_castps_ph(result);
}

template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet8f, Packet16h>(const Packet8f& a, const Packet8f& b) {
  __m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castph_si128(_mm256_cvtxps_ph(a))));
  result = _mm256_insertf32x4(result, _mm_castph_ps(_mm256_cvtxps_ph(b)), 1);
  return _mm256_castps_ph(result);
}

template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet4f, Packet8h>(const Packet4f& a, const Packet4f& b) {
  __m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castps_si128(a)));
  result = _mm256_insertf128_ps(result, b, 1);
  return _mm256_cvtxps_ph(result);
}

template <>
EIGEN_STRONG_INLINE Packet32s pcast<Packet32h, Packet32s>(const Packet32h& a) {
  return _mm512_cvtph_epi16(a);
}

template <>
EIGEN_STRONG_INLINE Packet16s pcast<Packet16h, Packet16s>(const Packet16h& a) {
  return _mm256_cvtph_epi16(a);
}

template <>
EIGEN_STRONG_INLINE Packet8s pcast<Packet8h, Packet8s>(const Packet8h& a) {
  return _mm_cvtph_epi16(a);
}

template <>
EIGEN_STRONG_INLINE Packet32h pcast<Packet32s, Packet32h>(const Packet32s& a) {
  return _mm512_cvtepi16_ph(a);
}

template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet16s, Packet16h>(const Packet16s& a) {
  return _mm256_cvtepi16_ph(a);
}

template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet8s, Packet8h>(const Packet8s& a) {
  return _mm_cvtepi16_ph(a);
}
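
// Illustrative sketch of the narrowing behaviour of the Packet32h casts above: the one-argument
// Packet32h -> Packet16f cast keeps only the lower 16 half-precision lanes, while the
// two-argument Packet16f -> Packet32h cast concatenates its inputs (a in the low half, b in the
// high half). The variable names below are hypothetical; only the pcast calls come from this file.
//
//   Packet32h h  = ...;                                  // 32 half-precision lanes
//   Packet16f lo = pcast<Packet32h, Packet16f>(h);       // lanes 0..15 widened to float
//   Packet32h h2 = pcast<Packet16f, Packet32h>(lo, lo);  // lanes 0..15 duplicated into both halves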
}  // namespace internal
}  // namespace Eigen
#endif // EIGEN_TYPE_CASTING_FP16_AVX512_H