Commit 266d4fd9 authored by zhanggzh's avatar zhanggzh
Browse files

add lietorch src code and eigen src code, update readme

parent e7df8655
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Pedro Gonnet (pedro.gonnet@gmail.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_
#define THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_FLOAT(Packet16f)
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_DOUBLE(Packet8d)
template <>
EIGEN_STRONG_INLINE Packet16h pfrexp(const Packet16h& a, Packet16h& exponent) {
Packet16f fexponent;
const Packet16h out = float2half(pfrexp<Packet16f>(half2float(a), fexponent));
exponent = float2half(fexponent);
return out;
}
template <>
EIGEN_STRONG_INLINE Packet16h pldexp(const Packet16h& a, const Packet16h& exponent) {
return float2half(pldexp<Packet16f>(half2float(a), half2float(exponent)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pfrexp(const Packet16bf& a, Packet16bf& exponent) {
Packet16f fexponent;
const Packet16bf out = F32ToBf16(pfrexp<Packet16f>(Bf16ToF32(a), fexponent));
exponent = F32ToBf16(fexponent);
return out;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exponent) {
return F32ToBf16(pldexp<Packet16f>(Bf16ToF32(a), Bf16ToF32(exponent)));
}
#if EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f psqrt<Packet16f>(const Packet16f& x) {
return generic_sqrt_newton_step<Packet16f>::run(x, _mm512_rsqrt14_ps(x));
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d psqrt<Packet8d>(const Packet8d& x) {
#ifdef EIGEN_VECTORIZE_AVX512ER
return generic_sqrt_newton_step<Packet8d, /*Steps=*/1>::run(x, _mm512_rsqrt28_pd(x));
#else
return generic_sqrt_newton_step<Packet8d, /*Steps=*/2>::run(x, _mm512_rsqrt14_pd(x));
#endif
}
#else
template <>
EIGEN_STRONG_INLINE Packet16f psqrt<Packet16f>(const Packet16f& x) {
return _mm512_sqrt_ps(x);
}
template <>
EIGEN_STRONG_INLINE Packet8d psqrt<Packet8d>(const Packet8d& x) {
return _mm512_sqrt_pd(x);
}
#endif
// prsqrt for float.
#if defined(EIGEN_VECTORIZE_AVX512ER)
template <>
EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
return _mm512_rsqrt28_ps(x);
}
#elif EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f prsqrt<Packet16f>(const Packet16f& x) {
return generic_rsqrt_newton_step<Packet16f, /*Steps=*/1>::run(x, _mm512_rsqrt14_ps(x));
}
#endif
// prsqrt for double.
#if EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d prsqrt<Packet8d>(const Packet8d& x) {
#ifdef EIGEN_VECTORIZE_AVX512ER
return generic_rsqrt_newton_step<Packet8d, /*Steps=*/1>::run(x, _mm512_rsqrt28_pd(x));
#else
return generic_rsqrt_newton_step<Packet8d, /*Steps=*/2>::run(x, _mm512_rsqrt14_pd(x));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet16f preciprocal<Packet16f>(const Packet16f& a) {
#ifdef EIGEN_VECTORIZE_AVX512ER
return _mm512_rcp28_ps(a);
#else
return generic_reciprocal_newton_step<Packet16f, /*Steps=*/1>::run(a, _mm512_rcp14_ps(a));
#endif
}
#endif
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp2)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog2)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, preciprocal)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
#ifndef EIGEN_VECTORIZE_AVX512FP16
F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp2)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexpm1)
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog)
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog1p)
F16_PACKET_FUNCTION(Packet16f, Packet16h, plog2)
F16_PACKET_FUNCTION(Packet16f, Packet16h, preciprocal)
F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt)
F16_PACKET_FUNCTION(Packet16f, Packet16h, psin)
F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt)
F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh)
#endif // EIGEN_VECTORIZE_AVX512FP16
} // end namespace internal
} // end namespace Eigen
#endif // THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 The Eigen Authors.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
#define EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
EIGEN_STRONG_INLINE Packet32h combine2Packet16h(const Packet16h& a, const Packet16h& b) {
__m512i result = _mm512_castsi256_si512(_mm256_castph_si256(a));
result = _mm512_inserti64x4(result, _mm256_castph_si256(b), 1);
return _mm512_castsi512_ph(result);
}
EIGEN_STRONG_INLINE void extract2Packet16h(const Packet32h& x, Packet16h& a, Packet16h& b) {
a = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_castph_si512(x)));
b = _mm256_castsi256_ph(_mm512_extracti64x4_epi64(_mm512_castph_si512(x), 1));
}
#define _EIGEN_GENERATE_FP16_MATH_FUNCTION(func) \
template <> \
EIGEN_STRONG_INLINE Packet8h func<Packet8h>(const Packet8h& a) { \
return float2half(func(half2float(a))); \
} \
\
template <> \
EIGEN_STRONG_INLINE Packet16h func<Packet16h>(const Packet16h& a) { \
return float2half(func(half2float(a))); \
} \
\
template <> \
EIGEN_STRONG_INLINE Packet32h func<Packet32h>(const Packet32h& a) { \
Packet16h low; \
Packet16h high; \
extract2Packet16h(a, low, high); \
return combine2Packet16h(func(low), func(high)); \
}
_EIGEN_GENERATE_FP16_MATH_FUNCTION(psin)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pcos)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog2)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog1p)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexpm1)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp2)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(ptanh)
#undef _EIGEN_GENERATE_FP16_MATH_FUNCTION
// pfrexp
template <>
EIGEN_STRONG_INLINE Packet32h pfrexp<Packet32h>(const Packet32h& a, Packet32h& exponent) {
return pfrexp_generic(a, exponent);
}
// pldexp
template <>
EIGEN_STRONG_INLINE Packet32h pldexp<Packet32h>(const Packet32h& a, const Packet32h& exponent) {
return pldexp_generic(a, exponent);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
\ No newline at end of file
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Benoit Steiner (benoit.steiner.goog@gmail.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_PACKET_MATH_AVX512_H
#define EIGEN_PACKET_MATH_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
#endif
#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
#endif
#ifdef EIGEN_VECTORIZE_FMA
#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif
#endif
typedef __m512 Packet16f;
typedef __m512i Packet16i;
typedef __m512d Packet8d;
typedef eigen_packet_wrapper<__m512i, 1> Packet8l;
#ifndef EIGEN_VECTORIZE_AVX512FP16
typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
#endif
typedef eigen_packet_wrapper<__m256i, 2> Packet16bf;
typedef eigen_packet_wrapper<__m512i, 6> Packet32s;
typedef eigen_packet_wrapper<__m256i, 6> Packet16s;
typedef eigen_packet_wrapper<__m128i, 6> Packet8s;
template <>
struct is_arithmetic<__m512> {
enum { value = true };
};
template <>
struct is_arithmetic<__m512i> {
enum { value = true };
};
template <>
struct is_arithmetic<__m512d> {
enum { value = true };
};
template <>
struct is_arithmetic<Packet8l> {
enum { value = true };
};
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
struct is_arithmetic<Packet16h> {
enum { value = true };
};
template <>
struct packet_traits<half> : default_packet_traits {
typedef Packet16h type;
// There is no half-size packet for Packet16h.
typedef Packet16h half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 16,
HasCmp = 1,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasAbs = 1,
HasAbs2 = 0,
HasMin = 1,
HasMax = 1,
HasConj = 1,
HasSetLinear = 0,
HasSqrt = 1,
HasRsqrt = 1,
HasLog = 1,
HasLog1p = 1,
HasExp = 1,
HasExpm1 = 1,
HasBessel = 1,
HasNdtri = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasBlend = 0
};
};
#endif
template <>
struct packet_traits<float> : default_packet_traits {
typedef Packet16f type;
typedef Packet8f half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 16,
HasAbs = 1,
HasMin = 1,
HasMax = 1,
HasConj = 1,
HasBlend = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasACos = 1,
HasASin = 1,
HasATan = 1,
HasATanh = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
HasNdtri = 1,
HasBessel = 1,
HasExp = 1,
HasReciprocal = EIGEN_FAST_MATH,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasErfc = EIGEN_FAST_MATH,
HasCmp = 1,
HasDiv = 1
};
};
template <>
struct packet_traits<double> : default_packet_traits {
typedef Packet8d type;
typedef Packet4d half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 8,
HasBlend = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasExp = 1,
HasATan = 1,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasErfc = EIGEN_FAST_MATH,
HasATanh = 1,
HasCmp = 1,
HasDiv = 1
};
};
template <>
struct packet_traits<int> : default_packet_traits {
typedef Packet16i type;
typedef Packet8i half;
enum { Vectorizable = 1, AlignedOnScalar = 1, HasBlend = 0, HasCmp = 1, HasDiv = 1, size = 16 };
};
template <>
struct packet_traits<int64_t> : default_packet_traits {
typedef Packet8l type;
typedef Packet4l half;
enum { Vectorizable = 1, AlignedOnScalar = 1, HasCmp = 1, size = 8 };
};
template <>
struct unpacket_traits<Packet16f> {
typedef float type;
typedef Packet8f half;
typedef Packet16i integer_packet;
typedef uint16_t mask_t;
enum {
size = 16,
alignment = Aligned64,
vectorizable = true,
masked_load_available = true,
masked_store_available = true,
masked_fpops_available = true
};
};
template <>
struct unpacket_traits<Packet8d> {
typedef double type;
typedef Packet4d half;
typedef Packet8l integer_packet;
typedef uint8_t mask_t;
enum {
size = 8,
alignment = Aligned64,
vectorizable = true,
masked_load_available = true,
masked_store_available = true,
masked_fpops_available = true
};
};
template <>
struct unpacket_traits<Packet16i> {
typedef int type;
typedef Packet8i half;
enum {
size = 16,
alignment = Aligned64,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
template <>
struct unpacket_traits<Packet8l> {
typedef int64_t type;
typedef Packet4l half;
enum {
size = 8,
alignment = Aligned64,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
struct unpacket_traits<Packet16h> {
typedef Eigen::half type;
typedef Packet8h half;
enum {
size = 16,
alignment = Aligned32,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
#endif
template <>
struct unpacket_traits<Packet32s> {
typedef numext::int16_t type;
typedef Packet16s half;
enum {
size = 32,
alignment = Aligned64,
vectorizable = false,
};
};
template <>
struct unpacket_traits<Packet16s> {
typedef numext::int16_t type;
typedef Packet8s half;
enum {
size = 16,
alignment = Aligned32,
vectorizable = false,
};
};
template <>
struct unpacket_traits<Packet8s> {
typedef numext::int16_t type;
typedef Packet8s half;
enum {
size = 8,
alignment = Aligned16,
vectorizable = false,
};
};
template <>
EIGEN_STRONG_INLINE Packet16f pset1<Packet16f>(const float& from) {
return _mm512_set1_ps(from);
}
template <>
EIGEN_STRONG_INLINE Packet8d pset1<Packet8d>(const double& from) {
return _mm512_set1_pd(from);
}
template <>
EIGEN_STRONG_INLINE Packet16i pset1<Packet16i>(const int& from) {
return _mm512_set1_epi32(from);
}
template <>
EIGEN_STRONG_INLINE Packet8l pset1<Packet8l>(const int64_t& from) {
return _mm512_set1_epi64(from);
}
template <>
EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) {
return _mm512_castsi512_ps(_mm512_set1_epi32(from));
}
template <>
EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(const numext::uint64_t from) {
return _mm512_castsi512_pd(_mm512_set1_epi64(from));
}
template <>
EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) {
return _mm512_setzero_ps();
}
template <>
EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) {
return _mm512_setzero_pd();
}
template <>
EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) {
return _mm512_setzero_si512();
}
template <>
EIGEN_STRONG_INLINE Packet8l pzero(const Packet8l& /*a*/) {
return _mm512_setzero_si512();
}
template <>
EIGEN_STRONG_INLINE Packet16f peven_mask(const Packet16f& /*a*/) {
return _mm512_castsi512_ps(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1));
}
template <>
EIGEN_STRONG_INLINE Packet16i peven_mask(const Packet16i& /*a*/) {
return _mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1);
}
template <>
EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) {
return _mm512_castsi512_pd(_mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1));
}
template <>
EIGEN_STRONG_INLINE Packet8l peven_mask(const Packet8l& /*a*/) {
return _mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1);
}
template <>
EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
// Inline asm here helps reduce some register spilling in TRSM kernels.
// See note in unrolls::gemm::microKernel in TrsmKernel.h
Packet16f ret;
__asm__("vbroadcastss %[mem], %[dst]" : [dst] "=v"(ret) : [mem] "m"(*from));
return ret;
#else
return _mm512_broadcastss_ps(_mm_load_ps1(from));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pload1<Packet8d>(const double* from) {
#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
Packet8d ret;
__asm__("vbroadcastsd %[mem], %[dst]" : [dst] "=v"(ret) : [mem] "m"(*from));
return ret;
#else
return _mm512_set1_pd(*from);
#endif
}
template <>
EIGEN_STRONG_INLINE Packet16f plset<Packet16f>(const float& a) {
return _mm512_add_ps(_mm512_set1_ps(a), _mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f));
}
template <>
EIGEN_STRONG_INLINE Packet8d plset<Packet8d>(const double& a) {
return _mm512_add_pd(_mm512_set1_pd(a), _mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0));
}
template <>
EIGEN_STRONG_INLINE Packet16i plset<Packet16i>(const int& a) {
return _mm512_add_epi32(_mm512_set1_epi32(a), _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
}
template <>
EIGEN_STRONG_INLINE Packet8l plset<Packet8l>(const int64_t& a) {
return _mm512_add_epi64(_mm512_set1_epi64(a), _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0));
}
template <>
EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a, const Packet16f& b) {
return _mm512_add_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a, const Packet8d& b) {
return _mm512_add_pd(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16i padd<Packet16i>(const Packet16i& a, const Packet16i& b) {
return _mm512_add_epi32(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l padd<Packet8l>(const Packet8l& a, const Packet8l& b) {
return _mm512_add_epi64(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a, const Packet16f& b, uint16_t umask) {
__mmask16 mask = static_cast<__mmask16>(umask);
return _mm512_maskz_add_ps(mask, a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a, const Packet8d& b, uint8_t umask) {
__mmask8 mask = static_cast<__mmask8>(umask);
return _mm512_maskz_add_pd(mask, a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f psub<Packet16f>(const Packet16f& a, const Packet16f& b) {
return _mm512_sub_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d psub<Packet8d>(const Packet8d& a, const Packet8d& b) {
return _mm512_sub_pd(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16i psub<Packet16i>(const Packet16i& a, const Packet16i& b) {
return _mm512_sub_epi32(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l psub<Packet8l>(const Packet8l& a, const Packet8l& b) {
return _mm512_sub_epi64(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f pnegate(const Packet16f& a) {
// NOTE: MSVC seems to struggle with _mm512_set1_epi32, leading to random results.
// The intel docs give it a relatively high latency as well, so we're probably
// better off with using _mm512_set_epi32 directly anyways.
const __m512i mask =
_mm512_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000);
return _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(a), mask));
}
template <>
EIGEN_STRONG_INLINE Packet8d pnegate(const Packet8d& a) {
const __m512i mask =
_mm512_set_epi64(0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL,
0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL);
return _mm512_castsi512_pd(_mm512_xor_epi64(_mm512_castpd_si512(a), mask));
}
template <>
EIGEN_STRONG_INLINE Packet16i pnegate(const Packet16i& a) {
return _mm512_sub_epi32(_mm512_setzero_si512(), a);
}
template <>
EIGEN_STRONG_INLINE Packet8l pnegate(const Packet8l& a) {
return _mm512_sub_epi64(_mm512_setzero_si512(), a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pconj(const Packet16f& a) {
return a;
}
template <>
EIGEN_STRONG_INLINE Packet8d pconj(const Packet8d& a) {
return a;
}
template <>
EIGEN_STRONG_INLINE Packet16i pconj(const Packet16i& a) {
return a;
}
template <>
EIGEN_STRONG_INLINE Packet8l pconj(const Packet8l& a) {
return a;
}
template <>
EIGEN_STRONG_INLINE Packet16f pmul<Packet16f>(const Packet16f& a, const Packet16f& b) {
return _mm512_mul_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a, const Packet8d& b) {
return _mm512_mul_pd(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a, const Packet16i& b) {
return _mm512_mullo_epi32(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l pmul<Packet8l>(const Packet8l& a, const Packet8l& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_mullo_epi64(a, b);
#else
return _mm512_mullox_epi64(a, b);
#endif
}
template <>
EIGEN_STRONG_INLINE Packet16f pdiv<Packet16f>(const Packet16f& a, const Packet16f& b) {
return _mm512_div_ps(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8d pdiv<Packet8d>(const Packet8d& a, const Packet8d& b) {
return _mm512_div_pd(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16i pdiv<Packet16i>(const Packet16i& a, const Packet16i& b) {
Packet8i q_lo = pdiv<Packet8i>(_mm512_extracti64x4_epi64(a, 0), _mm512_extracti64x4_epi64(b, 0));
Packet8i q_hi = pdiv<Packet8i>(_mm512_extracti64x4_epi64(a, 1), _mm512_extracti64x4_epi64(b, 1));
return _mm512_inserti64x4(_mm512_castsi256_si512(q_lo), q_hi, 1);
}
#ifdef EIGEN_VECTORIZE_FMA
template <>
EIGEN_STRONG_INLINE Packet16f pmadd(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
return _mm512_fmadd_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
return _mm512_fmadd_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmsub(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
return _mm512_fmsub_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmsub(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
return _mm512_fmsub_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16f pnmadd(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
return _mm512_fnmadd_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pnmadd(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
return _mm512_fnmadd_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16f pnmsub(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
return _mm512_fnmsub_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pnmsub(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
return _mm512_fnmsub_pd(a, b, c);
}
#endif
template <>
EIGEN_DEVICE_FUNC inline Packet16f pselect(const Packet16f& mask, const Packet16f& a, const Packet16f& b) {
__mmask16 mask16 = _mm512_cmpeq_epi32_mask(_mm512_castps_si512(mask), _mm512_setzero_epi32());
return _mm512_mask_blend_ps(mask16, a, b);
}
template <>
EIGEN_DEVICE_FUNC inline Packet16i pselect(const Packet16i& mask, const Packet16i& a, const Packet16i& b) {
__mmask16 mask16 = _mm512_cmpeq_epi32_mask(mask, _mm512_setzero_epi32());
return _mm512_mask_blend_epi32(mask16, a, b);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8l pselect(const Packet8l& mask, const Packet8l& a, const Packet8l& b) {
__mmask8 mask8 = _mm512_cmpeq_epi64_mask(mask, _mm512_setzero_si512());
return _mm512_mask_blend_epi64(mask8, a, b);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8d pselect(const Packet8d& mask, const Packet8d& a, const Packet8d& b) {
__mmask8 mask8 = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
return _mm512_mask_blend_pd(mask8, a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmin<Packet16f>(const Packet16f& a, const Packet16f& b) {
// Arguments are reversed to match NaN propagation behavior of std::min.
return _mm512_min_ps(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmin<Packet8d>(const Packet8d& a, const Packet8d& b) {
// Arguments are reversed to match NaN propagation behavior of std::min.
return _mm512_min_pd(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet16i pmin<Packet16i>(const Packet16i& a, const Packet16i& b) {
return _mm512_min_epi32(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8l pmin<Packet8l>(const Packet8l& a, const Packet8l& b) {
return _mm512_min_epi64(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmax<Packet16f>(const Packet16f& a, const Packet16f& b) {
// Arguments are reversed to match NaN propagation behavior of std::max.
return _mm512_max_ps(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmax<Packet8d>(const Packet8d& a, const Packet8d& b) {
// Arguments are reversed to match NaN propagation behavior of std::max.
return _mm512_max_pd(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet16i pmax<Packet16i>(const Packet16i& a, const Packet16i& b) {
return _mm512_max_epi32(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8l pmax<Packet8l>(const Packet8l& a, const Packet8l& b) {
return _mm512_max_epi64(b, a);
}
// Add specializations for min/max with prescribed NaN propagation.
template <>
EIGEN_STRONG_INLINE Packet16f pmin<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) {
return pminmax_propagate_numbers(a, b, pmin<Packet16f>);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmin<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) {
return pminmax_propagate_numbers(a, b, pmin<Packet8d>);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmax<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) {
return pminmax_propagate_numbers(a, b, pmax<Packet16f>);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmax<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) {
return pminmax_propagate_numbers(a, b, pmax<Packet8d>);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmin<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) {
return pminmax_propagate_nan(a, b, pmin<Packet16f>);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmin<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) {
return pminmax_propagate_nan(a, b, pmin<Packet8d>);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmax<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) {
return pminmax_propagate_nan(a, b, pmax<Packet16f>);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmax<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) {
return pminmax_propagate_nan(a, b, pmax<Packet8d>);
}
#ifdef EIGEN_VECTORIZE_AVX512DQ
template <int I_>
EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) {
return _mm512_extractf32x8_ps(x, I_);
}
template <int I_>
EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) {
return _mm512_extractf64x2_pd(x, I_);
}
EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
}
EIGEN_STRONG_INLINE Packet16i cat256i(Packet8i a, Packet8i b) {
return _mm512_inserti32x8(_mm512_castsi256_si512(a), b, 1);
}
#else
// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
template <int I_>
EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) {
return _mm256_castsi256_ps(_mm512_extracti64x4_epi64(_mm512_castps_si512(x), I_));
}
// AVX512F does not define _mm512_extractf64x2_pd to extract _m128 from _m512
template <int I_>
EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) {
return _mm_castsi128_pd(_mm512_extracti32x4_epi32(_mm512_castpd_si512(x), I_));
}
EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
return _mm512_castsi512_ps(
_mm512_inserti64x4(_mm512_castsi256_si512(_mm256_castps_si256(a)), _mm256_castps_si256(b), 1));
}
EIGEN_STRONG_INLINE Packet16i cat256i(Packet8i a, Packet8i b) {
return _mm512_inserti64x4(_mm512_castsi256_si512(a), b, 1);
}
#endif
// Helper function for bit packing snippet of low precision comparison.
// It packs the flags from 32x16 to 16x16.
EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
// Split data into small pieces and handle with AVX instructions
// to guarantee internal order of vector.
// Operation:
// dst[15:0] := Saturate16(rf[31:0])
// dst[31:16] := Saturate16(rf[63:32])
// ...
// dst[255:240] := Saturate16(rf[255:224])
__m256i lo = _mm256_castps_si256(extract256<0>(rf));
__m256i hi = _mm256_castps_si256(extract256<1>(rf));
__m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0), _mm256_extractf128_si256(lo, 1));
__m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0), _mm256_extractf128_si256(hi, 1));
return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
}
template <>
EIGEN_STRONG_INLINE Packet16f pisnan(const Packet16f& a) {
__mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_UNORD_Q);
return _mm512_castsi512_ps(_mm512_maskz_set1_epi32(mask, int32_t(-1)));
}
template <>
EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
}
template <>
EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
}
template <>
EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
}
template <>
EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ);
return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
}
template <>
EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) {
__mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ);
return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet16i pcmp_le(const Packet16i& a, const Packet16i& b) {
__mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LE);
return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet16i pcmp_lt(const Packet16i& a, const Packet16i& b) {
__mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT);
return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet8l pcmp_eq(const Packet8l& a, const Packet8l& b) {
__mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_EQ);
return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet8l pcmp_le(const Packet8l& a, const Packet8l& b) {
__mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_LE);
return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet8l pcmp_lt(const Packet8l& a, const Packet8l& b) {
__mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_LT);
return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) {
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) {
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) {
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ);
return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) {
return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION);
}
template <>
EIGEN_STRONG_INLINE Packet8d print<Packet8d>(const Packet8d& a) {
return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION);
}
template <>
EIGEN_STRONG_INLINE Packet16f pceil<Packet16f>(const Packet16f& a) {
return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF);
}
template <>
EIGEN_STRONG_INLINE Packet8d pceil<Packet8d>(const Packet8d& a) {
return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF);
}
template <>
EIGEN_STRONG_INLINE Packet16f pfloor<Packet16f>(const Packet16f& a) {
return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF);
}
template <>
EIGEN_STRONG_INLINE Packet8d pfloor<Packet8d>(const Packet8d& a) {
return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF);
}
template <>
EIGEN_STRONG_INLINE Packet16f ptrunc<Packet16f>(const Packet16f& a) {
return _mm512_roundscale_ps(a, _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet8d ptrunc<Packet8d>(const Packet8d& a) {
return _mm512_roundscale_pd(a, _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {
return _mm512_set1_epi32(int32_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet8l ptrue<Packet8l>(const Packet8l& /*a*/) {
return _mm512_set1_epi64(int64_t(-1));
}
template <>
EIGEN_STRONG_INLINE Packet16f ptrue<Packet16f>(const Packet16f& a) {
return _mm512_castsi512_ps(ptrue<Packet16i>(_mm512_castps_si512(a)));
}
template <>
EIGEN_STRONG_INLINE Packet8d ptrue<Packet8d>(const Packet8d& a) {
return _mm512_castsi512_pd(ptrue<Packet16i>(_mm512_castpd_si512(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16i pand<Packet16i>(const Packet16i& a, const Packet16i& b) {
return _mm512_and_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l pand<Packet8l>(const Packet8l& a, const Packet8l& b) {
return _mm512_and_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_and_ps(a, b);
#else
return _mm512_castsi512_ps(pand(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pand<Packet8d>(const Packet8d& a, const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_and_pd(a, b);
#else
Packet8d res = _mm512_undefined_pd();
Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0);
Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0);
res = _mm512_insertf64x4(res, _mm256_and_pd(lane0_a, lane0_b), 0);
Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
return _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1);
#endif
}
template <>
EIGEN_STRONG_INLINE Packet16i por<Packet16i>(const Packet16i& a, const Packet16i& b) {
return _mm512_or_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l por<Packet8l>(const Packet8l& a, const Packet8l& b) {
return _mm512_or_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_or_ps(a, b);
#else
return _mm512_castsi512_ps(por(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d por<Packet8d>(const Packet8d& a, const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_or_pd(a, b);
#else
return _mm512_castsi512_pd(por(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet16i pxor<Packet16i>(const Packet16i& a, const Packet16i& b) {
return _mm512_xor_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8l pxor<Packet8l>(const Packet8l& a, const Packet8l& b) {
return _mm512_xor_si512(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_xor_ps(a, b);
#else
return _mm512_castsi512_ps(pxor(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pxor<Packet8d>(const Packet8d& a, const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_xor_pd(a, b);
#else
return _mm512_castsi512_pd(pxor(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet16i pandnot<Packet16i>(const Packet16i& a, const Packet16i& b) {
return _mm512_andnot_si512(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet8l pandnot<Packet8l>(const Packet8l& a, const Packet8l& b) {
return _mm512_andnot_si512(b, a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a, const Packet16f& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_andnot_ps(b, a);
#else
return _mm512_castsi512_ps(pandnot(_mm512_castps_si512(a), _mm512_castps_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a, const Packet8d& b) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_andnot_pd(b, a);
#else
return _mm512_castsi512_pd(pandnot(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet16f pround<Packet16f>(const Packet16f& a) {
// Work-around for default std::round rounding mode.
const Packet16f mask = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x80000000u));
const Packet16f prev0dot5 = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x3EFFFFFFu));
return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet8d pround<Packet8d>(const Packet8d& a) {
// Work-around for default std::round rounding mode.
const Packet8d mask = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x8000000000000000ull));
const Packet8d prev0dot5 = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull));
return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
template <int N>
EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) {
return _mm512_srai_epi32(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16i plogical_shift_right(Packet16i a) {
return _mm512_srli_epi32(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16i plogical_shift_left(Packet16i a) {
return _mm512_slli_epi32(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8l parithmetic_shift_right(Packet8l a) {
return _mm512_srai_epi64(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8l plogical_shift_right(Packet8l a) {
return _mm512_srli_epi64(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8l plogical_shift_left(Packet8l a) {
return _mm512_slli_epi64(a, N);
}
template <>
EIGEN_STRONG_INLINE Packet16f pload<Packet16f>(const float* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ps(from);
}
template <>
EIGEN_STRONG_INLINE Packet8d pload<Packet8d>(const double* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_pd(from);
}
template <>
EIGEN_STRONG_INLINE Packet16i pload<Packet16i>(const int* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_epi64(from);
}
template <>
EIGEN_STRONG_INLINE Packet8l pload<Packet8l>(const int64_t* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_epi64(from);
}
template <>
EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ps(from);
}
template <>
EIGEN_STRONG_INLINE Packet8d ploadu<Packet8d>(const double* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_pd(from);
}
template <>
EIGEN_STRONG_INLINE Packet16i ploadu<Packet16i>(const int* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_epi32(from);
}
template <>
EIGEN_STRONG_INLINE Packet8l ploadu<Packet8l>(const int64_t* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_epi64(from);
}
template <>
EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from, uint16_t umask) {
__mmask16 mask = static_cast<__mmask16>(umask);
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from);
}
template <>
EIGEN_STRONG_INLINE Packet8d ploadu<Packet8d>(const double* from, uint8_t umask) {
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_pd(mask, from);
}
// Loads 8 floats from memory a returns the packet
// {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
template <>
EIGEN_STRONG_INLINE Packet16f ploaddup<Packet16f>(const float* from) {
// an unaligned load is required here as there is no requirement
// on the alignment of input pointer 'from'
__m256i low_half = _mm256_castps_si256(_mm256_loadu_ps(from));
__m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
__m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
return pairs;
}
// Loads 4 doubles from memory a returns the packet {a0, a0, a1, a1, a2, a2, a3,
// a3}
template <>
EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
Packet8d tmp = _mm512_castpd256_pd512(ploadu<Packet4d>(from));
const Packet8l scatter_mask = _mm512_set_epi64(3, 3, 2, 2, 1, 1, 0, 0);
return _mm512_permutexvar_pd(scatter_mask, tmp);
}
// Loads 4 int64_t from memory a returns the packet {a0, a0, a1, a1, a2, a2, a3,
// a3}
template <>
EIGEN_STRONG_INLINE Packet8l ploaddup<Packet8l>(const int64_t* from) {
Packet8l tmp = _mm512_castsi256_si512(ploadu<Packet4l>(from));
const Packet8l scatter_mask = _mm512_set_epi64(3, 3, 2, 2, 1, 1, 0, 0);
return _mm512_permutexvar_epi64(scatter_mask, tmp);
}
// Loads 8 integers from memory and returns the packet
// {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
template <>
EIGEN_STRONG_INLINE Packet16i ploaddup<Packet16i>(const int* from) {
__m256i low_half = _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
__m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
__m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
return _mm512_castps_si512(pairs);
}
// Loads 4 floats from memory a returns the packet
// {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
template <>
EIGEN_STRONG_INLINE Packet16f ploadquad<Packet16f>(const float* from) {
Packet16f tmp = _mm512_castps128_ps512(ploadu<Packet4f>(from));
const Packet16i scatter_mask = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
return _mm512_permutexvar_ps(scatter_mask, tmp);
}
// Loads 2 doubles from memory a returns the packet
// {a0, a0 a0, a0, a1, a1, a1, a1}
template <>
EIGEN_STRONG_INLINE Packet8d ploadquad<Packet8d>(const double* from) {
__m256d lane0 = _mm256_set1_pd(*from);
__m256d lane1 = _mm256_set1_pd(*(from + 1));
__m512d tmp = _mm512_undefined_pd();
tmp = _mm512_insertf64x4(tmp, lane0, 0);
return _mm512_insertf64x4(tmp, lane1, 1);
}
// Loads 2 int64_t from memory a returns the packet
// {a0, a0 a0, a0, a1, a1, a1, a1}
template <>
EIGEN_STRONG_INLINE Packet8l ploadquad<Packet8l>(const int64_t* from) {
__m256i lane0 = _mm256_set1_epi64x(*from);
__m256i lane1 = _mm256_set1_epi64x(*(from + 1));
__m512i tmp = _mm512_undefined_epi32();
tmp = _mm512_inserti64x4(tmp, lane0, 0);
return _mm512_inserti64x4(tmp, lane1, 1);
}
// Loads 4 integers from memory and returns the packet
// {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
template <>
EIGEN_STRONG_INLINE Packet16i ploadquad<Packet16i>(const int* from) {
Packet16i tmp = _mm512_castsi128_si512(ploadu<Packet4i>(from));
const Packet16i scatter_mask = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
return _mm512_permutexvar_epi32(scatter_mask, tmp);
}
template <>
EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet16f& from) {
EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ps(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet8d& from) {
EIGEN_DEBUG_ALIGNED_STORE _mm512_store_pd(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet16i& from) {
EIGEN_DEBUG_ALIGNED_STORE _mm512_store_epi32(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<int64_t>(int64_t* to, const Packet8l& from) {
EIGEN_DEBUG_ALIGNED_STORE _mm512_store_epi64(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from) {
EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ps(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from) {
EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_pd(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet16i& from) {
EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_epi32(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<int64_t>(int64_t* to, const Packet8l& from) {
EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_epi64(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from, uint16_t umask) {
__mmask16 mask = static_cast<__mmask16>(umask);
EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from, uint8_t umask) {
__mmask8 mask = static_cast<__mmask8>(umask);
EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_pd(to, mask, from);
}
template <typename Scalar, typename Packet>
EIGEN_DEVICE_FUNC inline Packet pgather(const Packet& src, const Scalar* from, Index stride,
typename unpacket_traits<Packet>::mask_t umask);
template <>
EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const Packet16f& src, const float* from, Index stride,
uint16_t umask) {
Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
__mmask16 mask = static_cast<__mmask16>(umask);
return _mm512_mask_i32gather_ps(src, mask, indices, from, 4);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const Packet8d& src, const double* from, Index stride,
uint8_t umask) {
Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
__mmask8 mask = static_cast<__mmask8>(umask);
return _mm512_mask_i32gather_pd(src, mask, indices, from, 8);
}
template <>
EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from, Index stride) {
Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
return _mm512_i32gather_ps(indices, from, 4);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const double* from, Index stride) {
Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
return _mm512_i32gather_pd(indices, from, 8);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8l pgather<int64_t, Packet8l>(const int64_t* from, Index stride) {
Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
return _mm512_i32gather_epi64(indices, from, 8);
}
template <>
EIGEN_DEVICE_FUNC inline Packet16i pgather<int, Packet16i>(const int* from, Index stride) {
Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
return _mm512_i32gather_epi32(indices, from, 4);
}
template <typename Scalar, typename Packet>
EIGEN_DEVICE_FUNC inline void pscatter(Scalar* to, const Packet& from, Index stride,
typename unpacket_traits<Packet>::mask_t umask);
template <>
EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to, const Packet16f& from, Index stride,
uint16_t umask) {
Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
__mmask16 mask = static_cast<__mmask16>(umask);
_mm512_mask_i32scatter_ps(to, mask, indices, from, 4);
}
template <>
EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to, const Packet8d& from, Index stride,
uint8_t umask) {
Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
__mmask8 mask = static_cast<__mmask8>(umask);
_mm512_mask_i32scatter_pd(to, mask, indices, from, 8);
}
template <>
EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to, const Packet16f& from, Index stride) {
Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
_mm512_i32scatter_ps(to, indices, from, 4);
}
template <>
EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to, const Packet8d& from, Index stride) {
Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
_mm512_i32scatter_pd(to, indices, from, 8);
}
template <>
EIGEN_DEVICE_FUNC inline void pscatter<int64_t, Packet8l>(int64_t* to, const Packet8l& from, Index stride) {
Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
_mm512_i32scatter_epi64(to, indices, from, 8);
}
template <>
EIGEN_DEVICE_FUNC inline void pscatter<int, Packet16i>(int* to, const Packet16i& from, Index stride) {
Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
_mm512_i32scatter_epi32(to, indices, from, 4);
}
template <>
EIGEN_STRONG_INLINE void pstore1<Packet16f>(float* to, const float& a) {
Packet16f pa = pset1<Packet16f>(a);
pstore(to, pa);
}
template <>
EIGEN_STRONG_INLINE void pstore1<Packet8d>(double* to, const double& a) {
Packet8d pa = pset1<Packet8d>(a);
pstore(to, pa);
}
template <>
EIGEN_STRONG_INLINE void pstore1<Packet16i>(int* to, const int& a) {
Packet16i pa = pset1<Packet16i>(a);
pstore(to, pa);
}
template <>
EIGEN_STRONG_INLINE void pstore1<Packet8l>(int64_t* to, const int64_t& a) {
Packet8l pa = pset1<Packet8l>(a);
pstore(to, pa);
}
template <>
EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) {
_mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0);
}
template <>
EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) {
_mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0);
}
template <>
EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) {
_mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0);
}
template <>
EIGEN_STRONG_INLINE float pfirst<Packet16f>(const Packet16f& a) {
return _mm512_cvtss_f32(a);
}
template <>
EIGEN_STRONG_INLINE double pfirst<Packet8d>(const Packet8d& a) {
return _mm512_cvtsd_f64(a);
}
template <>
EIGEN_STRONG_INLINE int64_t pfirst<Packet8l>(const Packet8l& a) {
int64_t x = _mm_extract_epi64_0(_mm512_extracti32x4_epi32(a, 0));
return x;
}
template <>
EIGEN_STRONG_INLINE int pfirst<Packet16i>(const Packet16i& a) {
#if EIGEN_GNUC_STRICT_LESS_THAN(11, 0, 0)
return _mm_cvtsi128_si32(_mm512_castsi512_si128(a));
#else
return _mm512_cvtsi512_si32(a);
#endif
}
template <>
EIGEN_STRONG_INLINE Packet16f preverse(const Packet16f& a) {
return _mm512_permutexvar_ps(_mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), a);
}
template <>
EIGEN_STRONG_INLINE Packet8d preverse(const Packet8d& a) {
return _mm512_permutexvar_pd(_mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7), a);
}
template <>
EIGEN_STRONG_INLINE Packet16i preverse(const Packet16i& a) {
return _mm512_permutexvar_epi32(_mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), a);
}
template <>
EIGEN_STRONG_INLINE Packet8l preverse(const Packet8l& a) {
return _mm512_permutexvar_epi64(_mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7), a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pabs(const Packet16f& a) {
// _mm512_abs_ps intrinsic not found, so hack around it
return _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(a), _mm512_set1_epi32(0x7fffffff)));
}
template <>
EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
// _mm512_abs_ps intrinsic not found, so hack around it
return _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(a), _mm512_set1_epi64(0x7fffffffffffffff)));
}
template <>
EIGEN_STRONG_INLINE Packet16i pabs(const Packet16i& a) {
return _mm512_abs_epi32(a);
}
template <>
EIGEN_STRONG_INLINE Packet8l pabs(const Packet8l& a) {
return _mm512_abs_epi64(a);
}
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) {
return _mm256_srai_epi16(a, 15);
}
#endif // EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16bf psignbit(const Packet16bf& a) {
return _mm256_srai_epi16(a, 15);
}
template <>
EIGEN_STRONG_INLINE Packet16f psignbit(const Packet16f& a) {
return _mm512_castsi512_ps(_mm512_srai_epi32(_mm512_castps_si512(a), 31));
}
template <>
EIGEN_STRONG_INLINE Packet8d psignbit(const Packet8d& a) {
return _mm512_castsi512_pd(_mm512_srai_epi64(_mm512_castpd_si512(a), 63));
}
template <>
EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent) {
return pfrexp_generic(a, exponent);
}
// Extract exponent without existence of Packet8l.
template <>
EIGEN_STRONG_INLINE Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) {
const Packet8d cst_exp_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(0x7ff0000000000000ull));
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52));
#else
return _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52)));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent) {
return pfrexp_generic(a, exponent);
}
template <>
EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) {
return pldexp_generic(a, exponent);
}
template <>
EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) {
// Clamp exponent to [-2099, 2099]
const Packet8d max_exponent = pset1<Packet8d>(2099.0);
const Packet8i e = _mm512_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
// Split 2^e into four factors and multiply.
const Packet8i bias = pset1<Packet8i>(1023);
Packet8i b = parithmetic_shift_right<2>(e); // floor(e/4)
// 2^b
const Packet8i permute_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
Packet8i hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
Packet8i lo = _mm256_slli_epi64(hi, 52);
hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52);
Packet8d c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
Packet8d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
// 2^(e - 3b)
b = psub(psub(psub(e, b), b), b); // e - 3b
hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
lo = _mm256_slli_epi64(hi, 52);
hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52);
c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
out = pmul(out, c); // a * 2^e
return out;
}
#ifdef EIGEN_VECTORIZE_AVX512DQ
// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
__m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0); \
__m256 OUTPUT##_1 = _mm512_extractf32x8_ps(INPUT, 1)
// AVX512F does not define _mm512_extracti32x8_epi32 to extract _m256i from _m512i
#define EIGEN_EXTRACT_8i_FROM_16i(INPUT, OUTPUT) \
__m256i OUTPUT##_0 = _mm512_extracti32x8_epi32(INPUT, 0); \
__m256i OUTPUT##_1 = _mm512_extracti32x8_epi32(INPUT, 1)
#else
#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
__m256 OUTPUT##_0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 0)), \
_mm512_extractf32x4_ps(INPUT, 1), 1); \
__m256 OUTPUT##_1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 2)), \
_mm512_extractf32x4_ps(INPUT, 3), 1)
#define EIGEN_EXTRACT_8i_FROM_16i(INPUT, OUTPUT) \
__m256i OUTPUT##_0 = _mm256_insertf128_si256(_mm256_castsi128_si256(_mm512_extracti32x4_epi32(INPUT, 0)), \
_mm512_extracti32x4_epi32(INPUT, 1), 1); \
__m256i OUTPUT##_1 = _mm256_insertf128_si256(_mm256_castsi128_si256(_mm512_extracti32x4_epi32(INPUT, 2)), \
_mm512_extracti32x4_epi32(INPUT, 3), 1)
#endif
#ifdef EIGEN_VECTORIZE_AVX512DQ
#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
OUTPUT = _mm512_insertf32x8(_mm512_castps256_ps512(INPUTA), INPUTB, 1);
#define EIGEN_INSERT_8i_INTO_16i(OUTPUT, INPUTA, INPUTB) \
OUTPUT = _mm512_inserti32x8(_mm512_castsi256_si512(INPUTA), INPUTB, 1);
#else
#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
OUTPUT = _mm512_undefined_ps(); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 0), 0); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 1), 1); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 0), 2); \
OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 1), 3);
#define EIGEN_INSERT_8i_INTO_16i(OUTPUT, INPUTA, INPUTB) \
OUTPUT = _mm512_undefined_epi32(); \
OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTA, 0), 0); \
OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTA, 1), 1); \
OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTB, 0), 2); \
OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTB, 1), 3);
#endif
template <>
EIGEN_STRONG_INLINE float predux<Packet16f>(const Packet16f& a) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
__m256 lane0 = _mm512_extractf32x8_ps(a, 0);
__m256 lane1 = _mm512_extractf32x8_ps(a, 1);
Packet8f x = _mm256_add_ps(lane0, lane1);
return predux<Packet8f>(x);
#else
__m128 lane0 = _mm512_extractf32x4_ps(a, 0);
__m128 lane1 = _mm512_extractf32x4_ps(a, 1);
__m128 lane2 = _mm512_extractf32x4_ps(a, 2);
__m128 lane3 = _mm512_extractf32x4_ps(a, 3);
__m128 sum = _mm_add_ps(_mm_add_ps(lane0, lane1), _mm_add_ps(lane2, lane3));
return predux<Packet4f>(sum);
#endif
}
template <>
EIGEN_STRONG_INLINE double predux<Packet8d>(const Packet8d& a) {
__m256d lane0 = _mm512_extractf64x4_pd(a, 0);
__m256d lane1 = _mm512_extractf64x4_pd(a, 1);
__m256d sum = _mm256_add_pd(lane0, lane1);
return predux<Packet4d>(sum);
}
template <>
EIGEN_STRONG_INLINE int64_t predux<Packet8l>(const Packet8l& a) {
return _mm512_reduce_add_epi64(a);
}
template <>
EIGEN_STRONG_INLINE int predux<Packet16i>(const Packet16i& a) {
return _mm512_reduce_add_epi32(a);
}
template <>
EIGEN_STRONG_INLINE Packet8f predux_half_dowto4<Packet16f>(const Packet16f& a) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
__m256 lane0 = _mm512_extractf32x8_ps(a, 0);
__m256 lane1 = _mm512_extractf32x8_ps(a, 1);
return _mm256_add_ps(lane0, lane1);
#else
__m128 lane0 = _mm512_extractf32x4_ps(a, 0);
__m128 lane1 = _mm512_extractf32x4_ps(a, 1);
__m128 lane2 = _mm512_extractf32x4_ps(a, 2);
__m128 lane3 = _mm512_extractf32x4_ps(a, 3);
__m128 sum0 = _mm_add_ps(lane0, lane2);
__m128 sum1 = _mm_add_ps(lane1, lane3);
return _mm256_insertf128_ps(_mm256_castps128_ps256(sum0), sum1, 1);
#endif
}
template <>
EIGEN_STRONG_INLINE Packet4d predux_half_dowto4<Packet8d>(const Packet8d& a) {
__m256d lane0 = _mm512_extractf64x4_pd(a, 0);
__m256d lane1 = _mm512_extractf64x4_pd(a, 1);
return _mm256_add_pd(lane0, lane1);
}
template <>
EIGEN_STRONG_INLINE Packet8i predux_half_dowto4<Packet16i>(const Packet16i& a) {
#ifdef EIGEN_VECTORIZE_AVX512DQ
__m256i lane0 = _mm512_extracti32x8_epi32(a, 0);
__m256i lane1 = _mm512_extracti32x8_epi32(a, 1);
return _mm256_add_epi32(lane0, lane1);
#else
__m128i lane0 = _mm512_extracti32x4_epi32(a, 0);
__m128i lane1 = _mm512_extracti32x4_epi32(a, 1);
__m128i lane2 = _mm512_extracti32x4_epi32(a, 2);
__m128i lane3 = _mm512_extracti32x4_epi32(a, 3);
__m128i sum0 = _mm_add_epi32(lane0, lane2);
__m128i sum1 = _mm_add_epi32(lane1, lane3);
return _mm256_inserti128_si256(_mm256_castsi128_si256(sum0), sum1, 1);
#endif
}
template <>
EIGEN_STRONG_INLINE Packet4l predux_half_dowto4<Packet8l>(const Packet8l& a) {
__m256i lane0 = _mm512_extracti64x4_epi64(a, 0);
__m256i lane1 = _mm512_extracti64x4_epi64(a, 1);
return _mm256_add_epi64(lane0, lane1);
}
template <>
EIGEN_STRONG_INLINE float predux_mul<Packet16f>(const Packet16f& a) {
// #ifdef EIGEN_VECTORIZE_AVX512DQ
#if 0
Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
Packet8f res = pmul(lane0, lane1);
res = pmul(res, _mm256_permute2f128_ps(res, res, 1));
res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
#else
__m128 lane0 = _mm512_extractf32x4_ps(a, 0);
__m128 lane1 = _mm512_extractf32x4_ps(a, 1);
__m128 lane2 = _mm512_extractf32x4_ps(a, 2);
__m128 lane3 = _mm512_extractf32x4_ps(a, 3);
__m128 res = pmul(pmul(lane0, lane1), pmul(lane2, lane3));
res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
#endif
}
template <>
EIGEN_STRONG_INLINE double predux_mul<Packet8d>(const Packet8d& a) {
__m256d lane0 = _mm512_extractf64x4_pd(a, 0);
__m256d lane1 = _mm512_extractf64x4_pd(a, 1);
__m256d res = pmul(lane0, lane1);
res = pmul(res, _mm256_permute2f128_pd(res, res, 1));
return pfirst(pmul(res, _mm256_shuffle_pd(res, res, 1)));
}
template <>
EIGEN_STRONG_INLINE int predux_mul<Packet16i>(const Packet16i& a) {
return _mm512_reduce_mul_epi32(a);
}
#if EIGEN_COMP_MSVC
// MSVC's _mm512_reduce_mul_epi64 is borked, at least up to and including 1939.
// alignas(64) int64_t data[] = { 1,1,-1,-1,1,-1,-1,-1 };
// int64_t out = _mm512_reduce_mul_epi64(_mm512_load_epi64(data));
// produces garbage: 4294967295. It seems to happen whenever the output is supposed to be negative.
// Fall back to a manual approach:
template <>
EIGEN_STRONG_INLINE int64_t predux_mul<Packet8l>(const Packet8l& a) {
Packet4l lane0 = _mm512_extracti64x4_epi64(a, 0);
Packet4l lane1 = _mm512_extracti64x4_epi64(a, 1);
Packet4l res = pmul(lane0, lane1);
res = pmul(res, Packet4l(_mm256_permute2x128_si256(res, res, 1)));
res = pmul(res, Packet4l(_mm256_shuffle_epi32(res, 0xE)));
return pfirst(res);
}
#else
template <>
EIGEN_STRONG_INLINE int64_t predux_mul<Packet8l>(const Packet8l& a) {
return _mm512_reduce_mul_epi64(a);
}
#endif
template <>
EIGEN_STRONG_INLINE float predux_min<Packet16f>(const Packet16f& a) {
__m128 lane0 = _mm512_extractf32x4_ps(a, 0);
__m128 lane1 = _mm512_extractf32x4_ps(a, 1);
__m128 lane2 = _mm512_extractf32x4_ps(a, 2);
__m128 lane3 = _mm512_extractf32x4_ps(a, 3);
__m128 res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3));
res = _mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst(_mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
}
template <>
EIGEN_STRONG_INLINE double predux_min<Packet8d>(const Packet8d& a) {
__m256d lane0 = _mm512_extractf64x4_pd(a, 0);
__m256d lane1 = _mm512_extractf64x4_pd(a, 1);
__m256d res = _mm256_min_pd(lane0, lane1);
res = _mm256_min_pd(res, _mm256_permute2f128_pd(res, res, 1));
return pfirst(_mm256_min_pd(res, _mm256_shuffle_pd(res, res, 1)));
}
template <>
EIGEN_STRONG_INLINE int predux_min<Packet16i>(const Packet16i& a) {
return _mm512_reduce_min_epi32(a);
}
template <>
EIGEN_STRONG_INLINE int64_t predux_min<Packet8l>(const Packet8l& a) {
return _mm512_reduce_min_epi64(a);
}
template <>
EIGEN_STRONG_INLINE float predux_max<Packet16f>(const Packet16f& a) {
__m128 lane0 = _mm512_extractf32x4_ps(a, 0);
__m128 lane1 = _mm512_extractf32x4_ps(a, 1);
__m128 lane2 = _mm512_extractf32x4_ps(a, 2);
__m128 lane3 = _mm512_extractf32x4_ps(a, 3);
__m128 res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3));
res = _mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst(_mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
}
template <>
EIGEN_STRONG_INLINE double predux_max<Packet8d>(const Packet8d& a) {
__m256d lane0 = _mm512_extractf64x4_pd(a, 0);
__m256d lane1 = _mm512_extractf64x4_pd(a, 1);
__m256d res = _mm256_max_pd(lane0, lane1);
res = _mm256_max_pd(res, _mm256_permute2f128_pd(res, res, 1));
return pfirst(_mm256_max_pd(res, _mm256_shuffle_pd(res, res, 1)));
}
template <>
EIGEN_STRONG_INLINE int predux_max<Packet16i>(const Packet16i& a) {
return _mm512_reduce_max_epi32(a);
}
template <>
EIGEN_STRONG_INLINE int64_t predux_max<Packet8l>(const Packet8l& a) {
return _mm512_reduce_max_epi64(a);
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet16f& a) {
return _mm512_reduce_or_epi32(_mm512_castps_si512(a)) != 0;
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet16i& a) {
return _mm512_reduce_or_epi32(a) != 0;
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8d& a) {
return _mm512_reduce_or_epi64(_mm512_castpd_si512(a)) != 0;
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8l& a) {
return _mm512_reduce_or_epi64(a) != 0;
}
#define PACK_OUTPUT(OUTPUT, INPUT, INDEX, STRIDE) \
EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[INDEX], INPUT[INDEX + STRIDE]);
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 16>& kernel) {
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
__m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
__m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
__m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
__m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
__m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
__m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);
__m512 T8 = _mm512_unpacklo_ps(kernel.packet[8], kernel.packet[9]);
__m512 T9 = _mm512_unpackhi_ps(kernel.packet[8], kernel.packet[9]);
__m512 T10 = _mm512_unpacklo_ps(kernel.packet[10], kernel.packet[11]);
__m512 T11 = _mm512_unpackhi_ps(kernel.packet[10], kernel.packet[11]);
__m512 T12 = _mm512_unpacklo_ps(kernel.packet[12], kernel.packet[13]);
__m512 T13 = _mm512_unpackhi_ps(kernel.packet[12], kernel.packet[13]);
__m512 T14 = _mm512_unpacklo_ps(kernel.packet[14], kernel.packet[15]);
__m512 T15 = _mm512_unpackhi_ps(kernel.packet[14], kernel.packet[15]);
__m512 S0 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
__m512 S1 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
__m512 S2 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
__m512 S3 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));
__m512 S4 = _mm512_shuffle_ps(T4, T6, _MM_SHUFFLE(1, 0, 1, 0));
__m512 S5 = _mm512_shuffle_ps(T4, T6, _MM_SHUFFLE(3, 2, 3, 2));
__m512 S6 = _mm512_shuffle_ps(T5, T7, _MM_SHUFFLE(1, 0, 1, 0));
__m512 S7 = _mm512_shuffle_ps(T5, T7, _MM_SHUFFLE(3, 2, 3, 2));
__m512 S8 = _mm512_shuffle_ps(T8, T10, _MM_SHUFFLE(1, 0, 1, 0));
__m512 S9 = _mm512_shuffle_ps(T8, T10, _MM_SHUFFLE(3, 2, 3, 2));
__m512 S10 = _mm512_shuffle_ps(T9, T11, _MM_SHUFFLE(1, 0, 1, 0));
__m512 S11 = _mm512_shuffle_ps(T9, T11, _MM_SHUFFLE(3, 2, 3, 2));
__m512 S12 = _mm512_shuffle_ps(T12, T14, _MM_SHUFFLE(1, 0, 1, 0));
__m512 S13 = _mm512_shuffle_ps(T12, T14, _MM_SHUFFLE(3, 2, 3, 2));
__m512 S14 = _mm512_shuffle_ps(T13, T15, _MM_SHUFFLE(1, 0, 1, 0));
__m512 S15 = _mm512_shuffle_ps(T13, T15, _MM_SHUFFLE(3, 2, 3, 2));
EIGEN_EXTRACT_8f_FROM_16f(S0, S0);
EIGEN_EXTRACT_8f_FROM_16f(S1, S1);
EIGEN_EXTRACT_8f_FROM_16f(S2, S2);
EIGEN_EXTRACT_8f_FROM_16f(S3, S3);
EIGEN_EXTRACT_8f_FROM_16f(S4, S4);
EIGEN_EXTRACT_8f_FROM_16f(S5, S5);
EIGEN_EXTRACT_8f_FROM_16f(S6, S6);
EIGEN_EXTRACT_8f_FROM_16f(S7, S7);
EIGEN_EXTRACT_8f_FROM_16f(S8, S8);
EIGEN_EXTRACT_8f_FROM_16f(S9, S9);
EIGEN_EXTRACT_8f_FROM_16f(S10, S10);
EIGEN_EXTRACT_8f_FROM_16f(S11, S11);
EIGEN_EXTRACT_8f_FROM_16f(S12, S12);
EIGEN_EXTRACT_8f_FROM_16f(S13, S13);
EIGEN_EXTRACT_8f_FROM_16f(S14, S14);
EIGEN_EXTRACT_8f_FROM_16f(S15, S15);
PacketBlock<Packet8f, 32> tmp;
tmp.packet[0] = _mm256_permute2f128_ps(S0_0, S4_0, 0x20);
tmp.packet[1] = _mm256_permute2f128_ps(S1_0, S5_0, 0x20);
tmp.packet[2] = _mm256_permute2f128_ps(S2_0, S6_0, 0x20);
tmp.packet[3] = _mm256_permute2f128_ps(S3_0, S7_0, 0x20);
tmp.packet[4] = _mm256_permute2f128_ps(S0_0, S4_0, 0x31);
tmp.packet[5] = _mm256_permute2f128_ps(S1_0, S5_0, 0x31);
tmp.packet[6] = _mm256_permute2f128_ps(S2_0, S6_0, 0x31);
tmp.packet[7] = _mm256_permute2f128_ps(S3_0, S7_0, 0x31);
tmp.packet[8] = _mm256_permute2f128_ps(S0_1, S4_1, 0x20);
tmp.packet[9] = _mm256_permute2f128_ps(S1_1, S5_1, 0x20);
tmp.packet[10] = _mm256_permute2f128_ps(S2_1, S6_1, 0x20);
tmp.packet[11] = _mm256_permute2f128_ps(S3_1, S7_1, 0x20);
tmp.packet[12] = _mm256_permute2f128_ps(S0_1, S4_1, 0x31);
tmp.packet[13] = _mm256_permute2f128_ps(S1_1, S5_1, 0x31);
tmp.packet[14] = _mm256_permute2f128_ps(S2_1, S6_1, 0x31);
tmp.packet[15] = _mm256_permute2f128_ps(S3_1, S7_1, 0x31);
// Second set of _m256 outputs
tmp.packet[16] = _mm256_permute2f128_ps(S8_0, S12_0, 0x20);
tmp.packet[17] = _mm256_permute2f128_ps(S9_0, S13_0, 0x20);
tmp.packet[18] = _mm256_permute2f128_ps(S10_0, S14_0, 0x20);
tmp.packet[19] = _mm256_permute2f128_ps(S11_0, S15_0, 0x20);
tmp.packet[20] = _mm256_permute2f128_ps(S8_0, S12_0, 0x31);
tmp.packet[21] = _mm256_permute2f128_ps(S9_0, S13_0, 0x31);
tmp.packet[22] = _mm256_permute2f128_ps(S10_0, S14_0, 0x31);
tmp.packet[23] = _mm256_permute2f128_ps(S11_0, S15_0, 0x31);
tmp.packet[24] = _mm256_permute2f128_ps(S8_1, S12_1, 0x20);
tmp.packet[25] = _mm256_permute2f128_ps(S9_1, S13_1, 0x20);
tmp.packet[26] = _mm256_permute2f128_ps(S10_1, S14_1, 0x20);
tmp.packet[27] = _mm256_permute2f128_ps(S11_1, S15_1, 0x20);
tmp.packet[28] = _mm256_permute2f128_ps(S8_1, S12_1, 0x31);
tmp.packet[29] = _mm256_permute2f128_ps(S9_1, S13_1, 0x31);
tmp.packet[30] = _mm256_permute2f128_ps(S10_1, S14_1, 0x31);
tmp.packet[31] = _mm256_permute2f128_ps(S11_1, S15_1, 0x31);
// Pack them into the output
PACK_OUTPUT(kernel.packet, tmp.packet, 0, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 1, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 2, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 3, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 4, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 5, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 6, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 7, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 8, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 9, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 10, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 11, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 12, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 13, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 14, 16);
PACK_OUTPUT(kernel.packet, tmp.packet, 15, 16);
}
#define PACK_OUTPUT_2(OUTPUT, INPUT, INDEX, STRIDE) \
EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], INPUT[2 * INDEX + STRIDE]);
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
__m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
__m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
__m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
__m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
__m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
__m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);
kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
T0 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0x44);
T1 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0xee);
T2 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0x44);
T3 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0xee);
T4 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0x44);
T5 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0xee);
T6 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0x44);
T7 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0xee);
kernel.packet[0] = _mm512_shuffle_f32x4(T0, T2, 0x88);
kernel.packet[2] = _mm512_shuffle_f32x4(T0, T2, 0xdd);
kernel.packet[1] = _mm512_shuffle_f32x4(T4, T6, 0x88);
kernel.packet[3] = _mm512_shuffle_f32x4(T4, T6, 0xdd);
kernel.packet[4] = _mm512_shuffle_f32x4(T1, T3, 0x88);
kernel.packet[6] = _mm512_shuffle_f32x4(T1, T3, 0xdd);
kernel.packet[5] = _mm512_shuffle_f32x4(T5, T7, 0x88);
kernel.packet[7] = _mm512_shuffle_f32x4(T5, T7, 0xdd);
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
__m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
__m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
__m512 S0 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
__m512 S1 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
__m512 S2 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
__m512 S3 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));
EIGEN_EXTRACT_8f_FROM_16f(S0, S0);
EIGEN_EXTRACT_8f_FROM_16f(S1, S1);
EIGEN_EXTRACT_8f_FROM_16f(S2, S2);
EIGEN_EXTRACT_8f_FROM_16f(S3, S3);
PacketBlock<Packet8f, 8> tmp;
tmp.packet[0] = _mm256_permute2f128_ps(S0_0, S1_0, 0x20);
tmp.packet[1] = _mm256_permute2f128_ps(S2_0, S3_0, 0x20);
tmp.packet[2] = _mm256_permute2f128_ps(S0_0, S1_0, 0x31);
tmp.packet[3] = _mm256_permute2f128_ps(S2_0, S3_0, 0x31);
tmp.packet[4] = _mm256_permute2f128_ps(S0_1, S1_1, 0x20);
tmp.packet[5] = _mm256_permute2f128_ps(S2_1, S3_1, 0x20);
tmp.packet[6] = _mm256_permute2f128_ps(S0_1, S1_1, 0x31);
tmp.packet[7] = _mm256_permute2f128_ps(S2_1, S3_1, 0x31);
PACK_OUTPUT_2(kernel.packet, tmp.packet, 0, 1);
PACK_OUTPUT_2(kernel.packet, tmp.packet, 1, 1);
PACK_OUTPUT_2(kernel.packet, tmp.packet, 2, 1);
PACK_OUTPUT_2(kernel.packet, tmp.packet, 3, 1);
}
#define PACK_OUTPUT_SQ_D(OUTPUT, INPUT, INDEX, STRIDE) \
OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[INDEX], 0); \
OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[INDEX + STRIDE], 1);
#define PACK_OUTPUT_D(OUTPUT, INPUT, INDEX, STRIDE) \
OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX)], 0); \
OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX) + STRIDE], 1);
#define PACK_OUTPUT_L(OUTPUT, INPUT, INDEX, STRIDE) \
OUTPUT[INDEX] = _mm512_inserti64x4(OUTPUT[INDEX], INPUT[(2 * INDEX)], 0); \
OUTPUT[INDEX] = _mm512_inserti64x4(OUTPUT[INDEX], INPUT[(2 * INDEX) + STRIDE], 1);
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 4>& kernel) {
__m512d T0 = _mm512_shuffle_pd(kernel.packet[0], kernel.packet[1], 0);
__m512d T1 = _mm512_shuffle_pd(kernel.packet[0], kernel.packet[1], 0xff);
__m512d T2 = _mm512_shuffle_pd(kernel.packet[2], kernel.packet[3], 0);
__m512d T3 = _mm512_shuffle_pd(kernel.packet[2], kernel.packet[3], 0xff);
PacketBlock<Packet4d, 8> tmp;
tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0), _mm512_extractf64x4_pd(T2, 0), 0x20);
tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), _mm512_extractf64x4_pd(T3, 0), 0x20);
tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0), _mm512_extractf64x4_pd(T2, 0), 0x31);
tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), _mm512_extractf64x4_pd(T3, 0), 0x31);
tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), _mm512_extractf64x4_pd(T2, 1), 0x20);
tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), _mm512_extractf64x4_pd(T3, 1), 0x20);
tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), _mm512_extractf64x4_pd(T2, 1), 0x31);
tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), _mm512_extractf64x4_pd(T3, 1), 0x31);
PACK_OUTPUT_D(kernel.packet, tmp.packet, 0, 1);
PACK_OUTPUT_D(kernel.packet, tmp.packet, 1, 1);
PACK_OUTPUT_D(kernel.packet, tmp.packet, 2, 1);
PACK_OUTPUT_D(kernel.packet, tmp.packet, 3, 1);
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 8>& kernel) {
__m512d T0 = _mm512_unpacklo_pd(kernel.packet[0], kernel.packet[1]);
__m512d T1 = _mm512_unpackhi_pd(kernel.packet[0], kernel.packet[1]);
__m512d T2 = _mm512_unpacklo_pd(kernel.packet[2], kernel.packet[3]);
__m512d T3 = _mm512_unpackhi_pd(kernel.packet[2], kernel.packet[3]);
__m512d T4 = _mm512_unpacklo_pd(kernel.packet[4], kernel.packet[5]);
__m512d T5 = _mm512_unpackhi_pd(kernel.packet[4], kernel.packet[5]);
__m512d T6 = _mm512_unpacklo_pd(kernel.packet[6], kernel.packet[7]);
__m512d T7 = _mm512_unpackhi_pd(kernel.packet[6], kernel.packet[7]);
kernel.packet[0] = _mm512_permutex_pd(T2, 0x4E);
kernel.packet[0] = _mm512_mask_blend_pd(0xCC, T0, kernel.packet[0]);
kernel.packet[2] = _mm512_permutex_pd(T0, 0x4E);
kernel.packet[2] = _mm512_mask_blend_pd(0xCC, kernel.packet[2], T2);
kernel.packet[1] = _mm512_permutex_pd(T3, 0x4E);
kernel.packet[1] = _mm512_mask_blend_pd(0xCC, T1, kernel.packet[1]);
kernel.packet[3] = _mm512_permutex_pd(T1, 0x4E);
kernel.packet[3] = _mm512_mask_blend_pd(0xCC, kernel.packet[3], T3);
kernel.packet[4] = _mm512_permutex_pd(T6, 0x4E);
kernel.packet[4] = _mm512_mask_blend_pd(0xCC, T4, kernel.packet[4]);
kernel.packet[6] = _mm512_permutex_pd(T4, 0x4E);
kernel.packet[6] = _mm512_mask_blend_pd(0xCC, kernel.packet[6], T6);
kernel.packet[5] = _mm512_permutex_pd(T7, 0x4E);
kernel.packet[5] = _mm512_mask_blend_pd(0xCC, T5, kernel.packet[5]);
kernel.packet[7] = _mm512_permutex_pd(T5, 0x4E);
kernel.packet[7] = _mm512_mask_blend_pd(0xCC, kernel.packet[7], T7);
T0 = _mm512_shuffle_f64x2(kernel.packet[4], kernel.packet[4], 0x4E);
T0 = _mm512_mask_blend_pd(0xF0, kernel.packet[0], T0);
T4 = _mm512_shuffle_f64x2(kernel.packet[0], kernel.packet[0], 0x4E);
T4 = _mm512_mask_blend_pd(0xF0, T4, kernel.packet[4]);
T1 = _mm512_shuffle_f64x2(kernel.packet[5], kernel.packet[5], 0x4E);
T1 = _mm512_mask_blend_pd(0xF0, kernel.packet[1], T1);
T5 = _mm512_shuffle_f64x2(kernel.packet[1], kernel.packet[1], 0x4E);
T5 = _mm512_mask_blend_pd(0xF0, T5, kernel.packet[5]);
T2 = _mm512_shuffle_f64x2(kernel.packet[6], kernel.packet[6], 0x4E);
T2 = _mm512_mask_blend_pd(0xF0, kernel.packet[2], T2);
T6 = _mm512_shuffle_f64x2(kernel.packet[2], kernel.packet[2], 0x4E);
T6 = _mm512_mask_blend_pd(0xF0, T6, kernel.packet[6]);
T3 = _mm512_shuffle_f64x2(kernel.packet[7], kernel.packet[7], 0x4E);
T3 = _mm512_mask_blend_pd(0xF0, kernel.packet[3], T3);
T7 = _mm512_shuffle_f64x2(kernel.packet[3], kernel.packet[3], 0x4E);
T7 = _mm512_mask_blend_pd(0xF0, T7, kernel.packet[7]);
kernel.packet[0] = T0;
kernel.packet[1] = T1;
kernel.packet[2] = T2;
kernel.packet[3] = T3;
kernel.packet[4] = T4;
kernel.packet[5] = T5;
kernel.packet[6] = T6;
kernel.packet[7] = T7;
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8l, 4>& kernel) {
__m512i T0 = _mm512_castpd_si512(
_mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[0]), _mm512_castsi512_pd(kernel.packet[1]), 0));
__m512i T1 = _mm512_castpd_si512(
_mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[0]), _mm512_castsi512_pd(kernel.packet[1]), 0xff));
__m512i T2 = _mm512_castpd_si512(
_mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[2]), _mm512_castsi512_pd(kernel.packet[3]), 0));
__m512i T3 = _mm512_castpd_si512(
_mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[2]), _mm512_castsi512_pd(kernel.packet[3]), 0xff));
PacketBlock<Packet4l, 8> tmp;
tmp.packet[0] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 0), _mm512_extracti64x4_epi64(T2, 0), 0x20);
tmp.packet[1] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 0), _mm512_extracti64x4_epi64(T3, 0), 0x20);
tmp.packet[2] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 0), _mm512_extracti64x4_epi64(T2, 0), 0x31);
tmp.packet[3] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 0), _mm512_extracti64x4_epi64(T3, 0), 0x31);
tmp.packet[4] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 1), _mm512_extracti64x4_epi64(T2, 1), 0x20);
tmp.packet[5] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 1), _mm512_extracti64x4_epi64(T3, 1), 0x20);
tmp.packet[6] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 1), _mm512_extracti64x4_epi64(T2, 1), 0x31);
tmp.packet[7] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 1), _mm512_extracti64x4_epi64(T3, 1), 0x31);
PACK_OUTPUT_L(kernel.packet, tmp.packet, 0, 1);
PACK_OUTPUT_L(kernel.packet, tmp.packet, 1, 1);
PACK_OUTPUT_L(kernel.packet, tmp.packet, 2, 1);
PACK_OUTPUT_L(kernel.packet, tmp.packet, 3, 1);
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8l, 8>& kernel) {
__m512i T0 = _mm512_unpacklo_epi64(kernel.packet[0], kernel.packet[1]);
__m512i T1 = _mm512_unpackhi_epi64(kernel.packet[0], kernel.packet[1]);
__m512i T2 = _mm512_unpacklo_epi64(kernel.packet[2], kernel.packet[3]);
__m512i T3 = _mm512_unpackhi_epi64(kernel.packet[2], kernel.packet[3]);
__m512i T4 = _mm512_unpacklo_epi64(kernel.packet[4], kernel.packet[5]);
__m512i T5 = _mm512_unpackhi_epi64(kernel.packet[4], kernel.packet[5]);
__m512i T6 = _mm512_unpacklo_epi64(kernel.packet[6], kernel.packet[7]);
__m512i T7 = _mm512_unpackhi_epi64(kernel.packet[6], kernel.packet[7]);
kernel.packet[0] = _mm512_permutex_epi64(T2, 0x4E);
kernel.packet[0] = _mm512_mask_blend_epi64(0xCC, T0, kernel.packet[0]);
kernel.packet[2] = _mm512_permutex_epi64(T0, 0x4E);
kernel.packet[2] = _mm512_mask_blend_epi64(0xCC, kernel.packet[2], T2);
kernel.packet[1] = _mm512_permutex_epi64(T3, 0x4E);
kernel.packet[1] = _mm512_mask_blend_epi64(0xCC, T1, kernel.packet[1]);
kernel.packet[3] = _mm512_permutex_epi64(T1, 0x4E);
kernel.packet[3] = _mm512_mask_blend_epi64(0xCC, kernel.packet[3], T3);
kernel.packet[4] = _mm512_permutex_epi64(T6, 0x4E);
kernel.packet[4] = _mm512_mask_blend_epi64(0xCC, T4, kernel.packet[4]);
kernel.packet[6] = _mm512_permutex_epi64(T4, 0x4E);
kernel.packet[6] = _mm512_mask_blend_epi64(0xCC, kernel.packet[6], T6);
kernel.packet[5] = _mm512_permutex_epi64(T7, 0x4E);
kernel.packet[5] = _mm512_mask_blend_epi64(0xCC, T5, kernel.packet[5]);
kernel.packet[7] = _mm512_permutex_epi64(T5, 0x4E);
kernel.packet[7] = _mm512_mask_blend_epi64(0xCC, kernel.packet[7], T7);
T0 = _mm512_shuffle_i64x2(kernel.packet[4], kernel.packet[4], 0x4E);
T0 = _mm512_mask_blend_epi64(0xF0, kernel.packet[0], T0);
T4 = _mm512_shuffle_i64x2(kernel.packet[0], kernel.packet[0], 0x4E);
T4 = _mm512_mask_blend_epi64(0xF0, T4, kernel.packet[4]);
T1 = _mm512_shuffle_i64x2(kernel.packet[5], kernel.packet[5], 0x4E);
T1 = _mm512_mask_blend_epi64(0xF0, kernel.packet[1], T1);
T5 = _mm512_shuffle_i64x2(kernel.packet[1], kernel.packet[1], 0x4E);
T5 = _mm512_mask_blend_epi64(0xF0, T5, kernel.packet[5]);
T2 = _mm512_shuffle_i64x2(kernel.packet[6], kernel.packet[6], 0x4E);
T2 = _mm512_mask_blend_epi64(0xF0, kernel.packet[2], T2);
T6 = _mm512_shuffle_i64x2(kernel.packet[2], kernel.packet[2], 0x4E);
T6 = _mm512_mask_blend_epi64(0xF0, T6, kernel.packet[6]);
T3 = _mm512_shuffle_i64x2(kernel.packet[7], kernel.packet[7], 0x4E);
T3 = _mm512_mask_blend_epi64(0xF0, kernel.packet[3], T3);
T7 = _mm512_shuffle_i64x2(kernel.packet[3], kernel.packet[3], 0x4E);
T7 = _mm512_mask_blend_epi64(0xF0, T7, kernel.packet[7]);
kernel.packet[0] = T0;
kernel.packet[1] = T1;
kernel.packet[2] = T2;
kernel.packet[3] = T3;
kernel.packet[4] = T4;
kernel.packet[5] = T5;
kernel.packet[6] = T6;
kernel.packet[7] = T7;
}
#define PACK_OUTPUT_I32(OUTPUT, INPUT, INDEX, STRIDE) \
EIGEN_INSERT_8i_INTO_16i(OUTPUT[INDEX], INPUT[INDEX], INPUT[INDEX + STRIDE]);
#define PACK_OUTPUT_I32_2(OUTPUT, INPUT, INDEX, STRIDE) \
EIGEN_INSERT_8i_INTO_16i(OUTPUT[INDEX], INPUT[2 * INDEX], INPUT[2 * INDEX + STRIDE]);
#define SHUFFLE_EPI32(A, B, M) _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(A), _mm512_castsi512_ps(B), M))
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16i, 16>& kernel) {
__m512i T0 = _mm512_unpacklo_epi32(kernel.packet[0], kernel.packet[1]);
__m512i T1 = _mm512_unpackhi_epi32(kernel.packet[0], kernel.packet[1]);
__m512i T2 = _mm512_unpacklo_epi32(kernel.packet[2], kernel.packet[3]);
__m512i T3 = _mm512_unpackhi_epi32(kernel.packet[2], kernel.packet[3]);
__m512i T4 = _mm512_unpacklo_epi32(kernel.packet[4], kernel.packet[5]);
__m512i T5 = _mm512_unpackhi_epi32(kernel.packet[4], kernel.packet[5]);
__m512i T6 = _mm512_unpacklo_epi32(kernel.packet[6], kernel.packet[7]);
__m512i T7 = _mm512_unpackhi_epi32(kernel.packet[6], kernel.packet[7]);
__m512i T8 = _mm512_unpacklo_epi32(kernel.packet[8], kernel.packet[9]);
__m512i T9 = _mm512_unpackhi_epi32(kernel.packet[8], kernel.packet[9]);
__m512i T10 = _mm512_unpacklo_epi32(kernel.packet[10], kernel.packet[11]);
__m512i T11 = _mm512_unpackhi_epi32(kernel.packet[10], kernel.packet[11]);
__m512i T12 = _mm512_unpacklo_epi32(kernel.packet[12], kernel.packet[13]);
__m512i T13 = _mm512_unpackhi_epi32(kernel.packet[12], kernel.packet[13]);
__m512i T14 = _mm512_unpacklo_epi32(kernel.packet[14], kernel.packet[15]);
__m512i T15 = _mm512_unpackhi_epi32(kernel.packet[14], kernel.packet[15]);
__m512i S0 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
__m512i S1 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
__m512i S2 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
__m512i S3 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));
__m512i S4 = SHUFFLE_EPI32(T4, T6, _MM_SHUFFLE(1, 0, 1, 0));
__m512i S5 = SHUFFLE_EPI32(T4, T6, _MM_SHUFFLE(3, 2, 3, 2));
__m512i S6 = SHUFFLE_EPI32(T5, T7, _MM_SHUFFLE(1, 0, 1, 0));
__m512i S7 = SHUFFLE_EPI32(T5, T7, _MM_SHUFFLE(3, 2, 3, 2));
__m512i S8 = SHUFFLE_EPI32(T8, T10, _MM_SHUFFLE(1, 0, 1, 0));
__m512i S9 = SHUFFLE_EPI32(T8, T10, _MM_SHUFFLE(3, 2, 3, 2));
__m512i S10 = SHUFFLE_EPI32(T9, T11, _MM_SHUFFLE(1, 0, 1, 0));
__m512i S11 = SHUFFLE_EPI32(T9, T11, _MM_SHUFFLE(3, 2, 3, 2));
__m512i S12 = SHUFFLE_EPI32(T12, T14, _MM_SHUFFLE(1, 0, 1, 0));
__m512i S13 = SHUFFLE_EPI32(T12, T14, _MM_SHUFFLE(3, 2, 3, 2));
__m512i S14 = SHUFFLE_EPI32(T13, T15, _MM_SHUFFLE(1, 0, 1, 0));
__m512i S15 = SHUFFLE_EPI32(T13, T15, _MM_SHUFFLE(3, 2, 3, 2));
EIGEN_EXTRACT_8i_FROM_16i(S0, S0);
EIGEN_EXTRACT_8i_FROM_16i(S1, S1);
EIGEN_EXTRACT_8i_FROM_16i(S2, S2);
EIGEN_EXTRACT_8i_FROM_16i(S3, S3);
EIGEN_EXTRACT_8i_FROM_16i(S4, S4);
EIGEN_EXTRACT_8i_FROM_16i(S5, S5);
EIGEN_EXTRACT_8i_FROM_16i(S6, S6);
EIGEN_EXTRACT_8i_FROM_16i(S7, S7);
EIGEN_EXTRACT_8i_FROM_16i(S8, S8);
EIGEN_EXTRACT_8i_FROM_16i(S9, S9);
EIGEN_EXTRACT_8i_FROM_16i(S10, S10);
EIGEN_EXTRACT_8i_FROM_16i(S11, S11);
EIGEN_EXTRACT_8i_FROM_16i(S12, S12);
EIGEN_EXTRACT_8i_FROM_16i(S13, S13);
EIGEN_EXTRACT_8i_FROM_16i(S14, S14);
EIGEN_EXTRACT_8i_FROM_16i(S15, S15);
PacketBlock<Packet8i, 32> tmp;
tmp.packet[0] = _mm256_permute2f128_si256(S0_0, S4_0, 0x20);
tmp.packet[1] = _mm256_permute2f128_si256(S1_0, S5_0, 0x20);
tmp.packet[2] = _mm256_permute2f128_si256(S2_0, S6_0, 0x20);
tmp.packet[3] = _mm256_permute2f128_si256(S3_0, S7_0, 0x20);
tmp.packet[4] = _mm256_permute2f128_si256(S0_0, S4_0, 0x31);
tmp.packet[5] = _mm256_permute2f128_si256(S1_0, S5_0, 0x31);
tmp.packet[6] = _mm256_permute2f128_si256(S2_0, S6_0, 0x31);
tmp.packet[7] = _mm256_permute2f128_si256(S3_0, S7_0, 0x31);
tmp.packet[8] = _mm256_permute2f128_si256(S0_1, S4_1, 0x20);
tmp.packet[9] = _mm256_permute2f128_si256(S1_1, S5_1, 0x20);
tmp.packet[10] = _mm256_permute2f128_si256(S2_1, S6_1, 0x20);
tmp.packet[11] = _mm256_permute2f128_si256(S3_1, S7_1, 0x20);
tmp.packet[12] = _mm256_permute2f128_si256(S0_1, S4_1, 0x31);
tmp.packet[13] = _mm256_permute2f128_si256(S1_1, S5_1, 0x31);
tmp.packet[14] = _mm256_permute2f128_si256(S2_1, S6_1, 0x31);
tmp.packet[15] = _mm256_permute2f128_si256(S3_1, S7_1, 0x31);
// Second set of _m256 outputs
tmp.packet[16] = _mm256_permute2f128_si256(S8_0, S12_0, 0x20);
tmp.packet[17] = _mm256_permute2f128_si256(S9_0, S13_0, 0x20);
tmp.packet[18] = _mm256_permute2f128_si256(S10_0, S14_0, 0x20);
tmp.packet[19] = _mm256_permute2f128_si256(S11_0, S15_0, 0x20);
tmp.packet[20] = _mm256_permute2f128_si256(S8_0, S12_0, 0x31);
tmp.packet[21] = _mm256_permute2f128_si256(S9_0, S13_0, 0x31);
tmp.packet[22] = _mm256_permute2f128_si256(S10_0, S14_0, 0x31);
tmp.packet[23] = _mm256_permute2f128_si256(S11_0, S15_0, 0x31);
tmp.packet[24] = _mm256_permute2f128_si256(S8_1, S12_1, 0x20);
tmp.packet[25] = _mm256_permute2f128_si256(S9_1, S13_1, 0x20);
tmp.packet[26] = _mm256_permute2f128_si256(S10_1, S14_1, 0x20);
tmp.packet[27] = _mm256_permute2f128_si256(S11_1, S15_1, 0x20);
tmp.packet[28] = _mm256_permute2f128_si256(S8_1, S12_1, 0x31);
tmp.packet[29] = _mm256_permute2f128_si256(S9_1, S13_1, 0x31);
tmp.packet[30] = _mm256_permute2f128_si256(S10_1, S14_1, 0x31);
tmp.packet[31] = _mm256_permute2f128_si256(S11_1, S15_1, 0x31);
// Pack them into the output
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 0, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 1, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 2, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 3, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 4, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 5, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 6, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 7, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 8, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 9, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 10, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 11, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 12, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 13, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 14, 16);
PACK_OUTPUT_I32(kernel.packet, tmp.packet, 15, 16);
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16i, 4>& kernel) {
__m512i T0 = _mm512_unpacklo_epi32(kernel.packet[0], kernel.packet[1]);
__m512i T1 = _mm512_unpackhi_epi32(kernel.packet[0], kernel.packet[1]);
__m512i T2 = _mm512_unpacklo_epi32(kernel.packet[2], kernel.packet[3]);
__m512i T3 = _mm512_unpackhi_epi32(kernel.packet[2], kernel.packet[3]);
__m512i S0 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
__m512i S1 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
__m512i S2 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
__m512i S3 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));
EIGEN_EXTRACT_8i_FROM_16i(S0, S0);
EIGEN_EXTRACT_8i_FROM_16i(S1, S1);
EIGEN_EXTRACT_8i_FROM_16i(S2, S2);
EIGEN_EXTRACT_8i_FROM_16i(S3, S3);
PacketBlock<Packet8i, 8> tmp;
tmp.packet[0] = _mm256_permute2f128_si256(S0_0, S1_0, 0x20);
tmp.packet[1] = _mm256_permute2f128_si256(S2_0, S3_0, 0x20);
tmp.packet[2] = _mm256_permute2f128_si256(S0_0, S1_0, 0x31);
tmp.packet[3] = _mm256_permute2f128_si256(S2_0, S3_0, 0x31);
tmp.packet[4] = _mm256_permute2f128_si256(S0_1, S1_1, 0x20);
tmp.packet[5] = _mm256_permute2f128_si256(S2_1, S3_1, 0x20);
tmp.packet[6] = _mm256_permute2f128_si256(S0_1, S1_1, 0x31);
tmp.packet[7] = _mm256_permute2f128_si256(S2_1, S3_1, 0x31);
PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 0, 1);
PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 1, 1);
PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 2, 1);
PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 3, 1);
}
template <size_t N>
EIGEN_STRONG_INLINE int avx512_blend_mask(const Selector<N>& ifPacket) {
alignas(__m128i) uint8_t aux[sizeof(__m128i)];
for (size_t i = 0; i < N; i++) aux[i] = static_cast<uint8_t>(ifPacket.select[i]);
__m128i paux = _mm_sub_epi8(_mm_setzero_si128(), _mm_load_si128(reinterpret_cast<const __m128i*>(aux)));
return _mm_movemask_epi8(paux);
}
template <>
EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& ifPacket, const Packet16f& thenPacket,
const Packet16f& elsePacket) {
__mmask16 m = avx512_blend_mask(ifPacket);
return _mm512_mask_blend_ps(m, elsePacket, thenPacket);
}
template <>
EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, const Packet8d& thenPacket,
const Packet8d& elsePacket) {
__mmask8 m = avx512_blend_mask(ifPacket);
return _mm512_mask_blend_pd(m, elsePacket, thenPacket);
}
// Packet math for Eigen::half
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
return _mm256_set1_epi16(from.x);
}
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet16h>(const Packet16h& from) {
return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm256_extract_epi16(from, 0)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pload<Packet16h>(const Eigen::half* from) {
return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
}
template <>
EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from) {
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
}
template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
// (void*) -> workaround clang warning:
// cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
EIGEN_DEBUG_ALIGNED_STORE
_mm256_store_si256((__m256i*)(void*)to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
// (void*) -> workaround clang warning:
// cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
EIGEN_DEBUG_UNALIGNED_STORE
_mm256_storeu_si256((__m256i*)(void*)to, from);
}
template <>
EIGEN_STRONG_INLINE Packet16h ploaddup<Packet16h>(const Eigen::half* from) {
unsigned short a = from[0].x;
unsigned short b = from[1].x;
unsigned short c = from[2].x;
unsigned short d = from[3].x;
unsigned short e = from[4].x;
unsigned short f = from[5].x;
unsigned short g = from[6].x;
unsigned short h = from[7].x;
return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
}
template <>
EIGEN_STRONG_INLINE Packet16h ploadquad(const Eigen::half* from) {
unsigned short a = from[0].x;
unsigned short b = from[1].x;
unsigned short c = from[2].x;
unsigned short d = from[3].x;
return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
}
EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) { return _mm512_cvtph_ps(a); }
EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) {
return _mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
}
template <>
EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) {
return Packet16h(ptrue(Packet8i(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) {
const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
return _mm256_andnot_si256(sign_mask, a);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a, const Packet16h& b) {
return float2half(pmin<Packet16f>(half2float(a), half2float(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a, const Packet16h& b) {
return float2half(pmax<Packet16f>(half2float(a), half2float(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) {
return float2half(plset<Packet16f>(static_cast<float>(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a, const Packet16h& b) {
// in some cases Packet8i is a wrapper around __m256i, so we need to
// cast to Packet8i to call the correct overload.
return Packet16h(por(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a, const Packet16h& b) {
return Packet16h(pxor(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a, const Packet16h& b) {
return Packet16h(pand(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a, const Packet16h& b) {
return Packet16h(pandnot(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) {
return _mm256_blendv_epi8(b, a, mask);
}
template <>
EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) {
return float2half(pround<Packet16f>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) {
return float2half(print<Packet16f>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) {
return float2half(pceil<Packet16f>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) {
return float2half(pfloor<Packet16f>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h ptrunc<Packet16h>(const Packet16h& a) {
return float2half(ptrunc<Packet16f>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a, const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
return Pack32To16(pcmp_eq(af, bf));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a, const Packet16h& b) {
return Pack32To16(pcmp_le(half2float(a), half2float(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a, const Packet16h& b) {
return Pack32To16(pcmp_lt(half2float(a), half2float(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a, const Packet16h& b) {
return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) {
return a;
}
template <>
EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
Packet16h sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
return _mm256_xor_si256(a, sign_mask);
}
template <>
EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
Packet16f rf = padd(af, bf);
return float2half(rf);
}
template <>
EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
Packet16f rf = psub(af, bf);
return float2half(rf);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
Packet16f rf = pmul(af, bf);
return float2half(rf);
}
template <>
EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
Packet16f af = half2float(a);
Packet16f bf = half2float(b);
Packet16f rf = pdiv(af, bf);
return float2half(rf);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmadd<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return float2half(pmadd(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pmsub<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return float2half(pmsub(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pnmadd<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return float2half(pnmadd(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pnmsub<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return float2half(pnmsub(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
Packet16f from_float = half2float(from);
return half(predux(from_float));
}
template <>
EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
Packet8h lane0 = _mm256_extractf128_si256(a, 0);
Packet8h lane1 = _mm256_extractf128_si256(a, 1);
return padd<Packet8h>(lane0, lane1);
}
template <>
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet16h>(const Packet16h& a) {
Packet16f af = half2float(a);
float reduced = predux_max<Packet16f>(af);
return Eigen::half(reduced);
}
template <>
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet16h>(const Packet16h& a) {
Packet16f af = half2float(a);
float reduced = predux_min<Packet16f>(af);
return Eigen::half(reduced);
}
template <>
EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& from) {
Packet16f from_float = half2float(from);
return half(predux_mul(from_float));
}
template <>
EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a) {
__m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
return _mm256_insertf128_si256(_mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(a, 1), m)),
_mm_shuffle_epi8(_mm256_extractf128_si256(a, 0), m), 1);
}
template <>
EIGEN_STRONG_INLINE Packet16h pgather<Eigen::half, Packet16h>(const Eigen::half* from, Index stride) {
return _mm256_set_epi16(from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x,
from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
}
template <>
EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Packet16h& from, Index stride) {
EIGEN_ALIGN64 half aux[16];
pstore(aux, from);
to[stride * 0] = aux[0];
to[stride * 1] = aux[1];
to[stride * 2] = aux[2];
to[stride * 3] = aux[3];
to[stride * 4] = aux[4];
to[stride * 5] = aux[5];
to[stride * 6] = aux[6];
to[stride * 7] = aux[7];
to[stride * 8] = aux[8];
to[stride * 9] = aux[9];
to[stride * 10] = aux[10];
to[stride * 11] = aux[11];
to[stride * 12] = aux[12];
to[stride * 13] = aux[13];
to[stride * 14] = aux[14];
to[stride * 15] = aux[15];
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 16>& kernel) {
__m256i a = kernel.packet[0];
__m256i b = kernel.packet[1];
__m256i c = kernel.packet[2];
__m256i d = kernel.packet[3];
__m256i e = kernel.packet[4];
__m256i f = kernel.packet[5];
__m256i g = kernel.packet[6];
__m256i h = kernel.packet[7];
__m256i i = kernel.packet[8];
__m256i j = kernel.packet[9];
__m256i k = kernel.packet[10];
__m256i l = kernel.packet[11];
__m256i m = kernel.packet[12];
__m256i n = kernel.packet[13];
__m256i o = kernel.packet[14];
__m256i p = kernel.packet[15];
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
__m256i ef_07 = _mm256_unpacklo_epi16(e, f);
__m256i gh_07 = _mm256_unpacklo_epi16(g, h);
__m256i ij_07 = _mm256_unpacklo_epi16(i, j);
__m256i kl_07 = _mm256_unpacklo_epi16(k, l);
__m256i mn_07 = _mm256_unpacklo_epi16(m, n);
__m256i op_07 = _mm256_unpacklo_epi16(o, p);
__m256i ab_8f = _mm256_unpackhi_epi16(a, b);
__m256i cd_8f = _mm256_unpackhi_epi16(c, d);
__m256i ef_8f = _mm256_unpackhi_epi16(e, f);
__m256i gh_8f = _mm256_unpackhi_epi16(g, h);
__m256i ij_8f = _mm256_unpackhi_epi16(i, j);
__m256i kl_8f = _mm256_unpackhi_epi16(k, l);
__m256i mn_8f = _mm256_unpackhi_epi16(m, n);
__m256i op_8f = _mm256_unpackhi_epi16(o, p);
__m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
__m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
__m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
__m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
__m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
__m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
__m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
__m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
__m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
__m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
__m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
__m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
__m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
__m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
__m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
__m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
__m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
__m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
__m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
__m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
__m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
__m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
__m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
__m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
__m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
__m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
__m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
__m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
__m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
__m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
__m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
__m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
__m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
__m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
__m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
__m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
__m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
__m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
__m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
__m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
__m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
__m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
__m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
__m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
__m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
__m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
__m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
kernel.packet[0] = a_p_0;
kernel.packet[1] = a_p_1;
kernel.packet[2] = a_p_2;
kernel.packet[3] = a_p_3;
kernel.packet[4] = a_p_4;
kernel.packet[5] = a_p_5;
kernel.packet[6] = a_p_6;
kernel.packet[7] = a_p_7;
kernel.packet[8] = a_p_8;
kernel.packet[9] = a_p_9;
kernel.packet[10] = a_p_a;
kernel.packet[11] = a_p_b;
kernel.packet[12] = a_p_c;
kernel.packet[13] = a_p_d;
kernel.packet[14] = a_p_e;
kernel.packet[15] = a_p_f;
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 8>& kernel) {
EIGEN_ALIGN64 half in[8][16];
pstore<half>(in[0], kernel.packet[0]);
pstore<half>(in[1], kernel.packet[1]);
pstore<half>(in[2], kernel.packet[2]);
pstore<half>(in[3], kernel.packet[3]);
pstore<half>(in[4], kernel.packet[4]);
pstore<half>(in[5], kernel.packet[5]);
pstore<half>(in[6], kernel.packet[6]);
pstore<half>(in[7], kernel.packet[7]);
EIGEN_ALIGN64 half out[8][16];
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
out[i][j] = in[j][2 * i];
}
for (int j = 0; j < 8; ++j) {
out[i][j + 8] = in[j][2 * i + 1];
}
}
kernel.packet[0] = pload<Packet16h>(out[0]);
kernel.packet[1] = pload<Packet16h>(out[1]);
kernel.packet[2] = pload<Packet16h>(out[2]);
kernel.packet[3] = pload<Packet16h>(out[3]);
kernel.packet[4] = pload<Packet16h>(out[4]);
kernel.packet[5] = pload<Packet16h>(out[5]);
kernel.packet[6] = pload<Packet16h>(out[6]);
kernel.packet[7] = pload<Packet16h>(out[7]);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 4>& kernel) {
EIGEN_ALIGN64 half in[4][16];
pstore<half>(in[0], kernel.packet[0]);
pstore<half>(in[1], kernel.packet[1]);
pstore<half>(in[2], kernel.packet[2]);
pstore<half>(in[3], kernel.packet[3]);
EIGEN_ALIGN64 half out[4][16];
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
out[i][j] = in[j][4 * i];
}
for (int j = 0; j < 4; ++j) {
out[i][j + 4] = in[j][4 * i + 1];
}
for (int j = 0; j < 4; ++j) {
out[i][j + 8] = in[j][4 * i + 2];
}
for (int j = 0; j < 4; ++j) {
out[i][j + 12] = in[j][4 * i + 3];
}
}
kernel.packet[0] = pload<Packet16h>(out[0]);
kernel.packet[1] = pload<Packet16h>(out[1]);
kernel.packet[2] = pload<Packet16h>(out[2]);
kernel.packet[3] = pload<Packet16h>(out[3]);
}
#endif // EIGEN_VECTORIZE_AVX512FP16
template <>
struct is_arithmetic<Packet16bf> {
enum { value = true };
};
template <>
struct packet_traits<bfloat16> : default_packet_traits {
typedef Packet16bf type;
typedef Packet8bf half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 16,
HasBlend = 0,
HasInsert = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasSqrt = 1,
HasRsqrt = 1,
#ifdef EIGEN_VECTORIZE_AVX512DQ
HasLog = 1, // Currently fails test with bad accuracy.
HasLog1p = 1,
HasExpm1 = 1,
HasNdtri = 1,
HasBessel = 1,
#endif
HasExp = 1,
HasTanh = EIGEN_FAST_MATH,
HasErf = EIGEN_FAST_MATH,
HasCmp = 1,
HasDiv = 1
};
};
template <>
struct unpacket_traits<Packet16bf> {
typedef bfloat16 type;
enum {
size = 16,
alignment = Aligned32,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
typedef Packet8bf half;
};
template <>
EIGEN_STRONG_INLINE Packet16bf pset1<Packet16bf>(const bfloat16& from) {
return _mm256_set1_epi16(from.value);
}
template <>
EIGEN_STRONG_INLINE bfloat16 pfirst<Packet16bf>(const Packet16bf& from) {
bfloat16 t;
t.value = static_cast<unsigned short>(_mm256_extract_epi16(from, 0));
return t;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pload<Packet16bf>(const bfloat16* from) {
return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
}
template <>
EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
}
template <>
EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet16bf& from) {
EIGEN_DEBUG_ALIGNED_STORE
_mm256_store_si256(reinterpret_cast<__m256i*>(to), from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet16bf& from) {
EIGEN_DEBUG_UNALIGNED_STORE
_mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
}
template <>
EIGEN_STRONG_INLINE Packet16bf ploaddup<Packet16bf>(const bfloat16* from) {
unsigned short a = from[0].value;
unsigned short b = from[1].value;
unsigned short c = from[2].value;
unsigned short d = from[3].value;
unsigned short e = from[4].value;
unsigned short f = from[5].value;
unsigned short g = from[6].value;
unsigned short h = from[7].value;
return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
}
template <>
EIGEN_STRONG_INLINE Packet16bf ploadquad(const bfloat16* from) {
unsigned short a = from[0].value;
unsigned short b = from[1].value;
unsigned short c = from[2].value;
unsigned short d = from[3].value;
return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
}
EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
}
// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
Packet16bf r;
#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_STRICT_AT_LEAST(10, 1, 0)
// Since GCC 10.1 supports avx512bf16 and C style explicit cast
// (C++ static_cast is not supported yet), do conversion via intrinsic
// and register path for performance.
r = (__m256i)(_mm512_cvtneps_pbh(a));
#else
__m512i t;
__m512i input = _mm512_castps_si512(a);
__m512i nan = _mm512_set1_epi32(0x7fc0);
// uint32_t lsb = (input >> 16) & 1;
t = _mm512_and_si512(_mm512_srli_epi32(input, 16), _mm512_set1_epi32(1));
// uint32_t rounding_bias = 0x7fff + lsb;
t = _mm512_add_epi32(t, _mm512_set1_epi32(0x7fff));
// input += rounding_bias;
t = _mm512_add_epi32(t, input);
// input = input >> 16;
t = _mm512_srli_epi32(t, 16);
// Check NaN before converting back to bf16
__mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
t = _mm512_mask_blend_epi32(mask, nan, t);
// output.value = static_cast<uint16_t>(input);
r = _mm512_cvtepi32_epi16(t);
#endif // EIGEN_VECTORIZE_AVX512BF16
return r;
}
template <>
EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) {
return Packet16bf(ptrue<Packet8i>(Packet8i(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) {
return Packet16bf(por<Packet8i>(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
return Packet16bf(pxor<Packet8i>(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
return Packet16bf(pand<Packet8i>(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a, const Packet16bf& b) {
return Packet16bf(pandnot<Packet8i>(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask, const Packet16bf& a, const Packet16bf& b) {
// Input mask is expected to be all 0/1, handle it with 8-bit
// intrinsic for performance.
return _mm256_blendv_epi8(b, a, mask);
}
template <>
EIGEN_STRONG_INLINE Packet16bf pround<Packet16bf>(const Packet16bf& a) {
return F32ToBf16(pround<Packet16f>(Bf16ToF32(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf print<Packet16bf>(const Packet16bf& a) {
return F32ToBf16(print<Packet16f>(Bf16ToF32(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pceil<Packet16bf>(const Packet16bf& a) {
return F32ToBf16(pceil<Packet16f>(Bf16ToF32(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pfloor<Packet16bf>(const Packet16bf& a) {
return F32ToBf16(pfloor<Packet16f>(Bf16ToF32(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf ptrunc<Packet16bf>(const Packet16bf& a) {
return F32ToBf16(ptrunc<Packet16f>(Bf16ToF32(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a, const Packet16bf& b) {
return Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a, const Packet16bf& b) {
return Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a, const Packet16bf& b) {
return Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a, const Packet16bf& b) {
return Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
Packet16bf sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
return _mm256_xor_si256(a, sign_mask);
}
template <>
EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) {
return a;
}
template <>
EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) {
const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
return _mm256_andnot_si256(sign_mask, a);
}
template <>
EIGEN_STRONG_INLINE Packet16bf padd<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
return F32ToBf16(padd<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf psub<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
return F32ToBf16(psub<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pmul<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
return F32ToBf16(pmul(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pmadd<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
return F32ToBf16(pmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pmsub<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
return F32ToBf16(pmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pnmadd<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
return F32ToBf16(pnmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pnmsub<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
return F32ToBf16(pnmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pdiv<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
return F32ToBf16(pdiv<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pmin<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
return F32ToBf16(pmin<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
return F32ToBf16(pmax<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf plset<Packet16bf>(const bfloat16& a) {
return F32ToBf16(plset<Packet16f>(static_cast<float>(a)));
}
template <>
EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
Packet8bf lane0 = _mm256_extractf128_si256(a, 0);
Packet8bf lane1 = _mm256_extractf128_si256(a, 1);
return padd<Packet8bf>(lane0, lane1);
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet16bf>(const Packet16bf& from) {
return static_cast<bfloat16>(predux_mul<Packet16f>(Bf16ToF32(from)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_min<Packet16bf>(const Packet16bf& from) {
return static_cast<bfloat16>(predux_min<Packet16f>(Bf16ToF32(from)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_max<Packet16bf>(const Packet16bf& from) {
return static_cast<bfloat16>(predux_max<Packet16f>(Bf16ToF32(from)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
__m256i m = _mm256_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1, 14, 15, 12, 13, 10, 11, 8, 9, 6, 7,
4, 5, 2, 3, 0, 1);
Packet16bf res;
// Swap hi and lo first because shuffle is in 128-bit lanes.
res = _mm256_permute2x128_si256(a, a, 1);
// Shuffle 8-bit values in src within 2*128-bit lanes.
return _mm256_shuffle_epi8(res, m);
}
template <>
EIGEN_STRONG_INLINE Packet16bf pgather<bfloat16, Packet16bf>(const bfloat16* from, Index stride) {
return _mm256_set_epi16(
from[15 * stride].value, from[14 * stride].value, from[13 * stride].value, from[12 * stride].value,
from[11 * stride].value, from[10 * stride].value, from[9 * stride].value, from[8 * stride].value,
from[7 * stride].value, from[6 * stride].value, from[5 * stride].value, from[4 * stride].value,
from[3 * stride].value, from[2 * stride].value, from[1 * stride].value, from[0 * stride].value);
}
template <>
EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to, const Packet16bf& from, Index stride) {
EIGEN_ALIGN64 bfloat16 aux[16];
pstore(aux, from);
to[stride * 0] = aux[0];
to[stride * 1] = aux[1];
to[stride * 2] = aux[2];
to[stride * 3] = aux[3];
to[stride * 4] = aux[4];
to[stride * 5] = aux[5];
to[stride * 6] = aux[6];
to[stride * 7] = aux[7];
to[stride * 8] = aux[8];
to[stride * 9] = aux[9];
to[stride * 10] = aux[10];
to[stride * 11] = aux[11];
to[stride * 12] = aux[12];
to[stride * 13] = aux[13];
to[stride * 14] = aux[14];
to[stride * 15] = aux[15];
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf, 16>& kernel) {
__m256i a = kernel.packet[0];
__m256i b = kernel.packet[1];
__m256i c = kernel.packet[2];
__m256i d = kernel.packet[3];
__m256i e = kernel.packet[4];
__m256i f = kernel.packet[5];
__m256i g = kernel.packet[6];
__m256i h = kernel.packet[7];
__m256i i = kernel.packet[8];
__m256i j = kernel.packet[9];
__m256i k = kernel.packet[10];
__m256i l = kernel.packet[11];
__m256i m = kernel.packet[12];
__m256i n = kernel.packet[13];
__m256i o = kernel.packet[14];
__m256i p = kernel.packet[15];
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
__m256i ef_07 = _mm256_unpacklo_epi16(e, f);
__m256i gh_07 = _mm256_unpacklo_epi16(g, h);
__m256i ij_07 = _mm256_unpacklo_epi16(i, j);
__m256i kl_07 = _mm256_unpacklo_epi16(k, l);
__m256i mn_07 = _mm256_unpacklo_epi16(m, n);
__m256i op_07 = _mm256_unpacklo_epi16(o, p);
__m256i ab_8f = _mm256_unpackhi_epi16(a, b);
__m256i cd_8f = _mm256_unpackhi_epi16(c, d);
__m256i ef_8f = _mm256_unpackhi_epi16(e, f);
__m256i gh_8f = _mm256_unpackhi_epi16(g, h);
__m256i ij_8f = _mm256_unpackhi_epi16(i, j);
__m256i kl_8f = _mm256_unpackhi_epi16(k, l);
__m256i mn_8f = _mm256_unpackhi_epi16(m, n);
__m256i op_8f = _mm256_unpackhi_epi16(o, p);
__m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
__m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
__m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
__m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
__m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
__m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
__m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
__m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
__m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
__m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
__m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
__m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
__m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
__m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
__m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
__m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
__m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
__m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
__m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
__m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
__m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
__m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
__m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
__m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
__m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
__m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
__m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
__m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
__m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
__m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
__m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
kernel.packet[0] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
kernel.packet[1] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
kernel.packet[2] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
kernel.packet[3] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
kernel.packet[4] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
kernel.packet[5] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
kernel.packet[6] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
kernel.packet[7] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
kernel.packet[8] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
kernel.packet[9] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
kernel.packet[10] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
kernel.packet[11] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
kernel.packet[12] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
kernel.packet[13] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
kernel.packet[14] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
kernel.packet[15] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf, 4>& kernel) {
__m256i a = kernel.packet[0];
__m256i b = kernel.packet[1];
__m256i c = kernel.packet[2];
__m256i d = kernel.packet[3];
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
__m256i ab_8f = _mm256_unpackhi_epi16(a, b);
__m256i cd_8f = _mm256_unpackhi_epi16(c, d);
__m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
__m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
__m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
__m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
kernel.packet[0] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
kernel.packet[1] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
kernel.packet[2] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
}
// Minimal implementation of 16-bit int packets for use in pfrexp, pldexp.
template <>
EIGEN_STRONG_INLINE Packet32s pset1<Packet32s>(const numext::int16_t& x) {
return _mm512_set1_epi16(x);
}
template <>
EIGEN_STRONG_INLINE Packet16s pset1<Packet16s>(const numext::int16_t& x) {
return _mm256_set1_epi16(x);
}
template <>
EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const numext::int16_t& x) {
return _mm_set1_epi16(x);
}
template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
EIGEN_DEBUG_ALIGNED_STORE
_mm512_store_epi32(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
EIGEN_DEBUG_ALIGNED_STORE
#if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
_mm256_store_epi32(out, x);
#else
_mm256_store_si256(reinterpret_cast<__m256i*>(out), x);
#endif
}
template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
EIGEN_DEBUG_ALIGNED_STORE
#if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
_mm256_store_epi32(out, x);
#else
_mm_store_si128(reinterpret_cast<__m128i*>(out), x);
#endif
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
EIGEN_DEBUG_UNALIGNED_STORE
_mm512_storeu_epi32(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
EIGEN_DEBUG_UNALIGNED_STORE
_mm256_storeu_epi32(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
EIGEN_DEBUG_UNALIGNED_STORE
_mm_storeu_epi32(out, x);
}
template <>
EIGEN_STRONG_INLINE Packet32s padd(const Packet32s& a, const Packet32s& b) {
return _mm512_add_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16s padd(const Packet16s& a, const Packet16s& b) {
return _mm256_add_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8s padd(const Packet8s& a, const Packet8s& b) {
return _mm_add_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet32s psub(const Packet32s& a, const Packet32s& b) {
return _mm512_sub_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16s psub(const Packet16s& a, const Packet16s& b) {
return _mm256_sub_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8s psub(const Packet8s& a, const Packet8s& b) {
return _mm_sub_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet32s pmul(const Packet32s& a, const Packet32s& b) {
return _mm512_mullo_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16s pmul(const Packet16s& a, const Packet16s& b) {
return _mm256_mullo_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8s pmul(const Packet8s& a, const Packet8s& b) {
return _mm_mullo_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet32s pnegate(const Packet32s& a) {
return _mm512_sub_epi16(_mm512_setzero_si512(), a);
}
template <>
EIGEN_STRONG_INLINE Packet16s pnegate(const Packet16s& a) {
return _mm256_sub_epi16(_mm256_setzero_si256(), a);
}
template <>
EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a) {
return _mm_sub_epi16(_mm_setzero_si128(), a);
}
template <int N>
EIGEN_STRONG_INLINE Packet32s parithmetic_shift_right(Packet32s a) {
return _mm512_srai_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16s parithmetic_shift_right(Packet16s a) {
return _mm256_srai_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8s parithmetic_shift_right(Packet8s a) {
return _mm_srai_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet32s plogical_shift_left(Packet32s a) {
return _mm512_slli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16s plogical_shift_left(Packet16s a) {
return _mm256_slli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8s plogical_shift_left(Packet8s a) {
return _mm_slli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet32s plogical_shift_right(Packet32s a) {
return _mm512_srli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16s plogical_shift_right(Packet16s a) {
return _mm256_srli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8s plogical_shift_right(Packet8s a) {
return _mm_srli_epi16(a, N);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_PACKET_MATH_AVX512_H
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 The Eigen Authors.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_PACKET_MATH_FP16_AVX512_H
#define EIGEN_PACKET_MATH_FP16_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
typedef __m512h Packet32h;
typedef __m256h Packet16h;
typedef __m128h Packet8h;
template <>
struct is_arithmetic<Packet8h> {
enum { value = true };
};
template <>
struct packet_traits<half> : default_packet_traits {
typedef Packet32h type;
typedef Packet16h half;
enum {
Vectorizable = 1,
AlignedOnScalar = 1,
size = 32,
HasCmp = 1,
HasAdd = 1,
HasSub = 1,
HasMul = 1,
HasDiv = 1,
HasNegate = 1,
HasAbs = 1,
HasAbs2 = 0,
HasMin = 1,
HasMax = 1,
HasConj = 1,
HasSetLinear = 0,
HasLog = 1,
HasLog1p = 1,
HasExp = 1,
HasExpm1 = 1,
HasSqrt = 1,
HasRsqrt = 1,
// These ones should be implemented in future
HasBessel = 0,
HasNdtri = 0,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasTanh = EIGEN_FAST_MATH,
HasErf = 0, // EIGEN_FAST_MATH,
HasBlend = 0
};
};
template <>
struct unpacket_traits<Packet32h> {
typedef Eigen::half type;
typedef Packet16h half;
typedef Packet32s integer_packet;
enum {
size = 32,
alignment = Aligned64,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
template <>
struct unpacket_traits<Packet16h> {
typedef Eigen::half type;
typedef Packet8h half;
typedef Packet16s integer_packet;
enum {
size = 16,
alignment = Aligned32,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
template <>
struct unpacket_traits<Packet8h> {
typedef Eigen::half type;
typedef Packet8h half;
typedef Packet8s integer_packet;
enum {
size = 8,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
// Conversions
EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) { return _mm512_cvtxph_ps(a); }
EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) { return _mm256_cvtxph_ps(a); }
EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) { return _mm512_cvtxps_ph(a); }
EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) { return _mm256_cvtxps_ph(a); }
// Memory functions
// pset1
template <>
EIGEN_STRONG_INLINE Packet32h pset1<Packet32h>(const Eigen::half& from) {
return _mm512_set1_ph(from.x);
}
template <>
EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
return _mm256_set1_ph(from.x);
}
template <>
EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
return _mm_set1_ph(from.x);
}
template <>
EIGEN_STRONG_INLINE Packet32h pzero(const Packet32h& /*a*/) {
return _mm512_setzero_ph();
}
template <>
EIGEN_STRONG_INLINE Packet16h pzero(const Packet16h& /*a*/) {
return _mm256_setzero_ph();
}
template <>
EIGEN_STRONG_INLINE Packet8h pzero(const Packet8h& /*a*/) {
return _mm_setzero_ph();
}
// pset1frombits
template <>
EIGEN_STRONG_INLINE Packet32h pset1frombits<Packet32h>(unsigned short from) {
return _mm512_castsi512_ph(_mm512_set1_epi16(from));
}
template <>
EIGEN_STRONG_INLINE Packet16h pset1frombits<Packet16h>(unsigned short from) {
return _mm256_castsi256_ph(_mm256_set1_epi16(from));
}
template <>
EIGEN_STRONG_INLINE Packet8h pset1frombits<Packet8h>(unsigned short from) {
return _mm_castsi128_ph(_mm_set1_epi16(from));
}
// pfirst
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet32h>(const Packet32h& from) {
return Eigen::half(_mm512_cvtsh_h(from));
}
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet16h>(const Packet16h& from) {
return Eigen::half(_mm256_cvtsh_h(from));
}
template <>
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) {
return Eigen::half(_mm_cvtsh_h(from));
}
// pload
template <>
EIGEN_STRONG_INLINE Packet32h pload<Packet32h>(const Eigen::half* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ph(from);
}
template <>
EIGEN_STRONG_INLINE Packet16h pload<Packet16h>(const Eigen::half* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_ph(from);
}
template <>
EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_ph(from);
}
// ploadu
template <>
EIGEN_STRONG_INLINE Packet32h ploadu<Packet32h>(const Eigen::half* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ph(from);
}
template <>
EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_ph(from);
}
template <>
EIGEN_STRONG_INLINE Packet8h ploadu<Packet8h>(const Eigen::half* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm_loadu_ph(from);
}
// pstore
template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet32h& from) {
EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ph(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
EIGEN_DEBUG_ALIGNED_STORE _mm256_store_ph(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet8h& from) {
EIGEN_DEBUG_ALIGNED_STORE _mm_store_ph(to, from);
}
// pstoreu
template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet32h& from) {
EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ph(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_ph(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet8h& from) {
EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_ph(to, from);
}
// ploaddup
template <>
EIGEN_STRONG_INLINE Packet32h ploaddup<Packet32h>(const Eigen::half* from) {
__m512h a = _mm512_castph256_ph512(_mm256_loadu_ph(from));
return _mm512_permutexvar_ph(_mm512_set_epi16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6,
5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0),
a);
}
template <>
EIGEN_STRONG_INLINE Packet16h ploaddup<Packet16h>(const Eigen::half* from) {
__m256h a = _mm256_castph128_ph256(_mm_loadu_ph(from));
return _mm256_permutexvar_ph(_mm256_set_epi16(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0), a);
}
template <>
EIGEN_STRONG_INLINE Packet8h ploaddup<Packet8h>(const Eigen::half* from) {
return _mm_set_ph(from[3].x, from[3].x, from[2].x, from[2].x, from[1].x, from[1].x, from[0].x, from[0].x);
}
// ploadquad
template <>
EIGEN_STRONG_INLINE Packet32h ploadquad<Packet32h>(const Eigen::half* from) {
__m512h a = _mm512_castph128_ph512(_mm_loadu_ph(from));
return _mm512_permutexvar_ph(
_mm512_set_epi16(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0),
a);
}
template <>
EIGEN_STRONG_INLINE Packet16h ploadquad<Packet16h>(const Eigen::half* from) {
return _mm256_set_ph(from[3].x, from[3].x, from[3].x, from[3].x, from[2].x, from[2].x, from[2].x, from[2].x,
from[1].x, from[1].x, from[1].x, from[1].x, from[0].x, from[0].x, from[0].x, from[0].x);
}
template <>
EIGEN_STRONG_INLINE Packet8h ploadquad<Packet8h>(const Eigen::half* from) {
return _mm_set_ph(from[1].x, from[1].x, from[1].x, from[1].x, from[0].x, from[0].x, from[0].x, from[0].x);
}
// pabs
template <>
EIGEN_STRONG_INLINE Packet32h pabs<Packet32h>(const Packet32h& a) {
return _mm512_abs_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet16h pabs<Packet16h>(const Packet16h& a) {
return _mm256_abs_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet8h pabs<Packet8h>(const Packet8h& a) {
return _mm_abs_ph(a);
}
// psignbit
template <>
EIGEN_STRONG_INLINE Packet32h psignbit<Packet32h>(const Packet32h& a) {
return _mm512_castsi512_ph(_mm512_srai_epi16(_mm512_castph_si512(a), 15));
}
template <>
EIGEN_STRONG_INLINE Packet16h psignbit<Packet16h>(const Packet16h& a) {
return _mm256_castsi256_ph(_mm256_srai_epi16(_mm256_castph_si256(a), 15));
}
template <>
EIGEN_STRONG_INLINE Packet8h psignbit<Packet8h>(const Packet8h& a) {
return _mm_castsi128_ph(_mm_srai_epi16(_mm_castph_si128(a), 15));
}
// pmin
template <>
EIGEN_STRONG_INLINE Packet32h pmin<Packet32h>(const Packet32h& a, const Packet32h& b) {
return _mm512_min_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a, const Packet16h& b) {
return _mm256_min_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmin<Packet8h>(const Packet8h& a, const Packet8h& b) {
return _mm_min_ph(a, b);
}
// pmax
template <>
EIGEN_STRONG_INLINE Packet32h pmax<Packet32h>(const Packet32h& a, const Packet32h& b) {
return _mm512_max_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a, const Packet16h& b) {
return _mm256_max_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmax<Packet8h>(const Packet8h& a, const Packet8h& b) {
return _mm_max_ph(a, b);
}
// plset
template <>
EIGEN_STRONG_INLINE Packet32h plset<Packet32h>(const half& a) {
return _mm512_add_ph(pset1<Packet32h>(a), _mm512_set_ph(31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17,
16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
}
template <>
EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) {
return _mm256_add_ph(pset1<Packet16h>(a), _mm256_set_ph(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
}
template <>
EIGEN_STRONG_INLINE Packet8h plset<Packet8h>(const half& a) {
return _mm_add_ph(pset1<Packet8h>(a), _mm_set_ph(7, 6, 5, 4, 3, 2, 1, 0));
}
// por
template <>
EIGEN_STRONG_INLINE Packet32h por(const Packet32h& a, const Packet32h& b) {
return _mm512_castsi512_ph(_mm512_or_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a, const Packet16h& b) {
return _mm256_castsi256_ph(_mm256_or_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
}
template <>
EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a, const Packet8h& b) {
return _mm_castsi128_ph(_mm_or_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
}
// pxor
template <>
EIGEN_STRONG_INLINE Packet32h pxor(const Packet32h& a, const Packet32h& b) {
return _mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a, const Packet16h& b) {
return _mm256_castsi256_ph(_mm256_xor_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pxor(const Packet8h& a, const Packet8h& b) {
return _mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
}
// pand
template <>
EIGEN_STRONG_INLINE Packet32h pand(const Packet32h& a, const Packet32h& b) {
return _mm512_castsi512_ph(_mm512_and_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a, const Packet16h& b) {
return _mm256_castsi256_ph(_mm256_and_si256(_mm256_castph_si256(a), _mm256_castph_si256(b)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pand(const Packet8h& a, const Packet8h& b) {
return _mm_castsi128_ph(_mm_and_si128(_mm_castph_si128(a), _mm_castph_si128(b)));
}
// pandnot
template <>
EIGEN_STRONG_INLINE Packet32h pandnot(const Packet32h& a, const Packet32h& b) {
return _mm512_castsi512_ph(_mm512_andnot_si512(_mm512_castph_si512(b), _mm512_castph_si512(a)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a, const Packet16h& b) {
return _mm256_castsi256_ph(_mm256_andnot_si256(_mm256_castph_si256(b), _mm256_castph_si256(a)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a, const Packet8h& b) {
return _mm_castsi128_ph(_mm_andnot_si128(_mm_castph_si128(b), _mm_castph_si128(a)));
}
// pselect
template <>
EIGEN_DEVICE_FUNC inline Packet32h pselect(const Packet32h& mask, const Packet32h& a, const Packet32h& b) {
__mmask32 mask32 = _mm512_cmp_epi16_mask(_mm512_castph_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
return _mm512_mask_blend_ph(mask32, a, b);
}
template <>
EIGEN_DEVICE_FUNC inline Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) {
__mmask16 mask16 = _mm256_cmp_epi16_mask(_mm256_castph_si256(mask), _mm256_setzero_si256(), _MM_CMPINT_EQ);
return _mm256_mask_blend_ph(mask16, a, b);
}
template <>
EIGEN_DEVICE_FUNC inline Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) {
__mmask8 mask8 = _mm_cmp_epi16_mask(_mm_castph_si128(mask), _mm_setzero_si128(), _MM_CMPINT_EQ);
return _mm_mask_blend_ph(mask8, a, b);
}
// pcmp_eq
template <>
EIGEN_STRONG_INLINE Packet32h pcmp_eq(const Packet32h& a, const Packet32h& b) {
__mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_EQ_OQ);
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a, const Packet16h& b) {
__mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_EQ_OQ);
return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a, const Packet8h& b) {
__mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_EQ_OQ);
return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
// pcmp_le
template <>
EIGEN_STRONG_INLINE Packet32h pcmp_le(const Packet32h& a, const Packet32h& b) {
__mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LE_OQ);
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a, const Packet16h& b) {
__mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_LE_OQ);
return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a, const Packet8h& b) {
__mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_LE_OQ);
return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
// pcmp_lt
template <>
EIGEN_STRONG_INLINE Packet32h pcmp_lt(const Packet32h& a, const Packet32h& b) {
__mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LT_OQ);
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a, const Packet16h& b) {
__mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_LT_OQ);
return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a, const Packet8h& b) {
__mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_LT_OQ);
return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
// pcmp_lt_or_nan
template <>
EIGEN_STRONG_INLINE Packet32h pcmp_lt_or_nan(const Packet32h& a, const Packet32h& b) {
__mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_NGE_UQ);
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi16(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a, const Packet16h& b) {
__mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_NGE_UQ);
return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a, const Packet8h& b) {
__mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_NGE_UQ);
return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast<short>(0xffffu)));
}
// padd
template <>
EIGEN_STRONG_INLINE Packet32h padd<Packet32h>(const Packet32h& a, const Packet32h& b) {
return _mm512_add_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
return _mm256_add_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
return _mm_add_ph(a, b);
}
// psub
template <>
EIGEN_STRONG_INLINE Packet32h psub<Packet32h>(const Packet32h& a, const Packet32h& b) {
return _mm512_sub_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
return _mm256_sub_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h psub<Packet8h>(const Packet8h& a, const Packet8h& b) {
return _mm_sub_ph(a, b);
}
// pmul
template <>
EIGEN_STRONG_INLINE Packet32h pmul<Packet32h>(const Packet32h& a, const Packet32h& b) {
return _mm512_mul_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
return _mm256_mul_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) {
return _mm_mul_ph(a, b);
}
// pdiv
template <>
EIGEN_STRONG_INLINE Packet32h pdiv<Packet32h>(const Packet32h& a, const Packet32h& b) {
return _mm512_div_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
return _mm256_div_ph(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
return _mm_div_ph(a, b);
;
}
// pround
template <>
EIGEN_STRONG_INLINE Packet32h pround<Packet32h>(const Packet32h& a) {
// Work-around for default std::round rounding mode.
// Mask for the sign bit.
const Packet32h signMask =
pset1frombits<Packet32h>(static_cast<numext::uint16_t>(static_cast<std::uint16_t>(0x8000u)));
// The largest half-precision float less than 0.5.
const Packet32h prev0dot5 = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x37FFu));
return _mm512_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) {
// Work-around for default std::round rounding mode.
// Mask for the sign bit.
const Packet16h signMask =
pset1frombits<Packet16h>(static_cast<numext::uint16_t>(static_cast<std::uint16_t>(0x8000u)));
// The largest half-precision float less than 0.5.
const Packet16h prev0dot5 = pset1frombits<Packet16h>(static_cast<numext::uint16_t>(0x37FFu));
return _mm256_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet8h pround<Packet8h>(const Packet8h& a) {
// Work-around for default std::round rounding mode.
// Mask for the sign bit.
const Packet8h signMask = pset1frombits<Packet8h>(static_cast<numext::uint16_t>(static_cast<std::uint16_t>(0x8000u)));
// The largest half-precision float less than 0.5.
const Packet8h prev0dot5 = pset1frombits<Packet8h>(static_cast<numext::uint16_t>(0x37FFu));
return _mm_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
}
// print
template <>
EIGEN_STRONG_INLINE Packet32h print<Packet32h>(const Packet32h& a) {
return _mm512_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
}
template <>
EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) {
return _mm256_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
}
template <>
EIGEN_STRONG_INLINE Packet8h print<Packet8h>(const Packet8h& a) {
return _mm_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
}
// pceil
template <>
EIGEN_STRONG_INLINE Packet32h pceil<Packet32h>(const Packet32h& a) {
return _mm512_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
}
template <>
EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) {
return _mm256_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
}
template <>
EIGEN_STRONG_INLINE Packet8h pceil<Packet8h>(const Packet8h& a) {
return _mm_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
}
// pfloor
template <>
EIGEN_STRONG_INLINE Packet32h pfloor<Packet32h>(const Packet32h& a) {
return _mm512_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
}
template <>
EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) {
return _mm256_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
}
template <>
EIGEN_STRONG_INLINE Packet8h pfloor<Packet8h>(const Packet8h& a) {
return _mm_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
}
// ptrunc
template <>
EIGEN_STRONG_INLINE Packet32h ptrunc<Packet32h>(const Packet32h& a) {
return _mm512_roundscale_ph(a, _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet16h ptrunc<Packet16h>(const Packet16h& a) {
return _mm256_roundscale_ph(a, _MM_FROUND_TO_ZERO);
}
template <>
EIGEN_STRONG_INLINE Packet8h ptrunc<Packet8h>(const Packet8h& a) {
return _mm_roundscale_ph(a, _MM_FROUND_TO_ZERO);
}
// predux
template <>
EIGEN_STRONG_INLINE half predux<Packet32h>(const Packet32h& a) {
return half(_mm512_reduce_add_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& a) {
return half(_mm256_reduce_add_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux<Packet8h>(const Packet8h& a) {
return half(_mm_reduce_add_ph(a));
}
// predux_half_dowto4
template <>
EIGEN_STRONG_INLINE Packet16h predux_half_dowto4<Packet32h>(const Packet32h& a) {
const __m512i bits = _mm512_castph_si512(a);
Packet16h lo = _mm256_castsi256_ph(_mm512_castsi512_si256(bits));
Packet16h hi = _mm256_castsi256_ph(_mm512_extracti64x4_epi64(bits, 1));
return padd(lo, hi);
}
template <>
EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
Packet8h lo = _mm_castsi128_ph(_mm256_castsi256_si128(_mm256_castph_si256(a)));
Packet8h hi = _mm_castps_ph(_mm256_extractf128_ps(_mm256_castph_ps(a), 1));
return padd(lo, hi);
}
// predux_max
template <>
EIGEN_STRONG_INLINE half predux_max<Packet32h>(const Packet32h& a) {
return half(_mm512_reduce_max_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_max<Packet16h>(const Packet16h& a) {
return half(_mm256_reduce_max_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_max<Packet8h>(const Packet8h& a) {
return half(_mm_reduce_max_ph(a));
}
// predux_min
template <>
EIGEN_STRONG_INLINE half predux_min<Packet32h>(const Packet32h& a) {
return half(_mm512_reduce_min_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_min<Packet16h>(const Packet16h& a) {
return half(_mm256_reduce_min_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_min<Packet8h>(const Packet8h& a) {
return half(_mm_reduce_min_ph(a));
}
// predux_mul
template <>
EIGEN_STRONG_INLINE half predux_mul<Packet32h>(const Packet32h& a) {
return half(_mm512_reduce_mul_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& a) {
return half(_mm256_reduce_mul_ph(a));
}
template <>
EIGEN_STRONG_INLINE half predux_mul<Packet8h>(const Packet8h& a) {
return half(_mm_reduce_mul_ph(a));
}
#ifdef EIGEN_VECTORIZE_FMA
// pmadd
template <>
EIGEN_STRONG_INLINE Packet32h pmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
return _mm512_fmadd_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return _mm256_fmadd_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
return _mm_fmadd_ph(a, b, c);
}
// pmsub
template <>
EIGEN_STRONG_INLINE Packet32h pmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
return _mm512_fmsub_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return _mm256_fmsub_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
return _mm_fmsub_ph(a, b, c);
}
// pnmadd
template <>
EIGEN_STRONG_INLINE Packet32h pnmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
return _mm512_fnmadd_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16h pnmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return _mm256_fnmadd_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pnmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
return _mm_fnmadd_ph(a, b, c);
}
// pnmsub
template <>
EIGEN_STRONG_INLINE Packet32h pnmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
return _mm512_fnmsub_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16h pnmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return _mm256_fnmsub_ph(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8h pnmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
return _mm_fnmsub_ph(a, b, c);
}
#endif
// pnegate
template <>
EIGEN_STRONG_INLINE Packet32h pnegate<Packet32h>(const Packet32h& a) {
return _mm512_castsi512_ph(
_mm512_xor_si512(_mm512_castph_si512(a), _mm512_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
}
template <>
EIGEN_STRONG_INLINE Packet16h pnegate<Packet16h>(const Packet16h& a) {
return _mm256_castsi256_ph(
_mm256_xor_si256(_mm256_castph_si256(a), _mm256_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
}
template <>
EIGEN_STRONG_INLINE Packet8h pnegate<Packet8h>(const Packet8h& a) {
return _mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(a), _mm_set1_epi16(static_cast<std::uint16_t>(0x8000u))));
}
// pconj
// Nothing, packets are real.
// psqrt
template <>
EIGEN_STRONG_INLINE Packet32h psqrt<Packet32h>(const Packet32h& a) {
return generic_sqrt_newton_step<Packet32h>::run(a, _mm512_rsqrt_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet16h psqrt<Packet16h>(const Packet16h& a) {
return generic_sqrt_newton_step<Packet16h>::run(a, _mm256_rsqrt_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet8h psqrt<Packet8h>(const Packet8h& a) {
return generic_sqrt_newton_step<Packet8h>::run(a, _mm_rsqrt_ph(a));
}
// prsqrt
template <>
EIGEN_STRONG_INLINE Packet32h prsqrt<Packet32h>(const Packet32h& a) {
return generic_rsqrt_newton_step<Packet32h, /*Steps=*/1>::run(a, _mm512_rsqrt_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet16h prsqrt<Packet16h>(const Packet16h& a) {
return generic_rsqrt_newton_step<Packet16h, /*Steps=*/1>::run(a, _mm256_rsqrt_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet8h prsqrt<Packet8h>(const Packet8h& a) {
return generic_rsqrt_newton_step<Packet8h, /*Steps=*/1>::run(a, _mm_rsqrt_ph(a));
}
// preciprocal
template <>
EIGEN_STRONG_INLINE Packet32h preciprocal<Packet32h>(const Packet32h& a) {
return generic_reciprocal_newton_step<Packet32h, /*Steps=*/1>::run(a, _mm512_rcp_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet16h preciprocal<Packet16h>(const Packet16h& a) {
return generic_reciprocal_newton_step<Packet16h, /*Steps=*/1>::run(a, _mm256_rcp_ph(a));
}
template <>
EIGEN_STRONG_INLINE Packet8h preciprocal<Packet8h>(const Packet8h& a) {
return generic_reciprocal_newton_step<Packet8h, /*Steps=*/1>::run(a, _mm_rcp_ph(a));
}
// ptranspose
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 32>& a) {
__m512i t[32];
EIGEN_UNROLL_LOOP
for (int i = 0; i < 16; i++) {
t[2 * i] = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
t[2 * i + 1] =
_mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
}
__m512i p[32];
EIGEN_UNROLL_LOOP
for (int i = 0; i < 8; i++) {
p[4 * i] = _mm512_unpacklo_epi32(t[4 * i], t[4 * i + 2]);
p[4 * i + 1] = _mm512_unpackhi_epi32(t[4 * i], t[4 * i + 2]);
p[4 * i + 2] = _mm512_unpacklo_epi32(t[4 * i + 1], t[4 * i + 3]);
p[4 * i + 3] = _mm512_unpackhi_epi32(t[4 * i + 1], t[4 * i + 3]);
}
__m512i q[32];
EIGEN_UNROLL_LOOP
for (int i = 0; i < 4; i++) {
q[8 * i] = _mm512_unpacklo_epi64(p[8 * i], p[8 * i + 4]);
q[8 * i + 1] = _mm512_unpackhi_epi64(p[8 * i], p[8 * i + 4]);
q[8 * i + 2] = _mm512_unpacklo_epi64(p[8 * i + 1], p[8 * i + 5]);
q[8 * i + 3] = _mm512_unpackhi_epi64(p[8 * i + 1], p[8 * i + 5]);
q[8 * i + 4] = _mm512_unpacklo_epi64(p[8 * i + 2], p[8 * i + 6]);
q[8 * i + 5] = _mm512_unpackhi_epi64(p[8 * i + 2], p[8 * i + 6]);
q[8 * i + 6] = _mm512_unpacklo_epi64(p[8 * i + 3], p[8 * i + 7]);
q[8 * i + 7] = _mm512_unpackhi_epi64(p[8 * i + 3], p[8 * i + 7]);
}
__m512i f[32];
#define PACKET32H_TRANSPOSE_HELPER(X, Y) \
do { \
f[Y * 8] = _mm512_inserti32x4(f[Y * 8], _mm512_extracti32x4_epi32(q[X * 8], Y), X); \
f[Y * 8 + 1] = _mm512_inserti32x4(f[Y * 8 + 1], _mm512_extracti32x4_epi32(q[X * 8 + 1], Y), X); \
f[Y * 8 + 2] = _mm512_inserti32x4(f[Y * 8 + 2], _mm512_extracti32x4_epi32(q[X * 8 + 2], Y), X); \
f[Y * 8 + 3] = _mm512_inserti32x4(f[Y * 8 + 3], _mm512_extracti32x4_epi32(q[X * 8 + 3], Y), X); \
f[Y * 8 + 4] = _mm512_inserti32x4(f[Y * 8 + 4], _mm512_extracti32x4_epi32(q[X * 8 + 4], Y), X); \
f[Y * 8 + 5] = _mm512_inserti32x4(f[Y * 8 + 5], _mm512_extracti32x4_epi32(q[X * 8 + 5], Y), X); \
f[Y * 8 + 6] = _mm512_inserti32x4(f[Y * 8 + 6], _mm512_extracti32x4_epi32(q[X * 8 + 6], Y), X); \
f[Y * 8 + 7] = _mm512_inserti32x4(f[Y * 8 + 7], _mm512_extracti32x4_epi32(q[X * 8 + 7], Y), X); \
} while (false);
PACKET32H_TRANSPOSE_HELPER(0, 0);
PACKET32H_TRANSPOSE_HELPER(1, 1);
PACKET32H_TRANSPOSE_HELPER(2, 2);
PACKET32H_TRANSPOSE_HELPER(3, 3);
PACKET32H_TRANSPOSE_HELPER(1, 0);
PACKET32H_TRANSPOSE_HELPER(2, 0);
PACKET32H_TRANSPOSE_HELPER(3, 0);
PACKET32H_TRANSPOSE_HELPER(2, 1);
PACKET32H_TRANSPOSE_HELPER(3, 1);
PACKET32H_TRANSPOSE_HELPER(3, 2);
PACKET32H_TRANSPOSE_HELPER(0, 1);
PACKET32H_TRANSPOSE_HELPER(0, 2);
PACKET32H_TRANSPOSE_HELPER(0, 3);
PACKET32H_TRANSPOSE_HELPER(1, 2);
PACKET32H_TRANSPOSE_HELPER(1, 3);
PACKET32H_TRANSPOSE_HELPER(2, 3);
#undef PACKET32H_TRANSPOSE_HELPER
EIGEN_UNROLL_LOOP
for (int i = 0; i < 32; i++) {
a.packet[i] = _mm512_castsi512_ph(f[i]);
}
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 4>& a) {
__m512i p0, p1, p2, p3, t0, t1, t2, t3, a0, a1, a2, a3;
t0 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
t1 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
t2 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));
t3 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));
p0 = _mm512_unpacklo_epi32(t0, t2);
p1 = _mm512_unpackhi_epi32(t0, t2);
p2 = _mm512_unpacklo_epi32(t1, t3);
p3 = _mm512_unpackhi_epi32(t1, t3);
a0 = p0;
a1 = p1;
a2 = p2;
a3 = p3;
a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p1, 0), 1);
a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p0, 1), 0);
a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p2, 0), 2);
a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p0, 2), 0);
a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p3, 0), 3);
a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p0, 3), 0);
a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p2, 1), 2);
a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p1, 2), 1);
a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p3, 2), 3);
a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p2, 3), 2);
a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p3, 1), 3);
a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p1, 3), 1);
a.packet[0] = _mm512_castsi512_ph(a0);
a.packet[1] = _mm512_castsi512_ph(a1);
a.packet[2] = _mm512_castsi512_ph(a2);
a.packet[3] = _mm512_castsi512_ph(a3);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 16>& kernel) {
__m256i a = _mm256_castph_si256(kernel.packet[0]);
__m256i b = _mm256_castph_si256(kernel.packet[1]);
__m256i c = _mm256_castph_si256(kernel.packet[2]);
__m256i d = _mm256_castph_si256(kernel.packet[3]);
__m256i e = _mm256_castph_si256(kernel.packet[4]);
__m256i f = _mm256_castph_si256(kernel.packet[5]);
__m256i g = _mm256_castph_si256(kernel.packet[6]);
__m256i h = _mm256_castph_si256(kernel.packet[7]);
__m256i i = _mm256_castph_si256(kernel.packet[8]);
__m256i j = _mm256_castph_si256(kernel.packet[9]);
__m256i k = _mm256_castph_si256(kernel.packet[10]);
__m256i l = _mm256_castph_si256(kernel.packet[11]);
__m256i m = _mm256_castph_si256(kernel.packet[12]);
__m256i n = _mm256_castph_si256(kernel.packet[13]);
__m256i o = _mm256_castph_si256(kernel.packet[14]);
__m256i p = _mm256_castph_si256(kernel.packet[15]);
__m256i ab_07 = _mm256_unpacklo_epi16(a, b);
__m256i cd_07 = _mm256_unpacklo_epi16(c, d);
__m256i ef_07 = _mm256_unpacklo_epi16(e, f);
__m256i gh_07 = _mm256_unpacklo_epi16(g, h);
__m256i ij_07 = _mm256_unpacklo_epi16(i, j);
__m256i kl_07 = _mm256_unpacklo_epi16(k, l);
__m256i mn_07 = _mm256_unpacklo_epi16(m, n);
__m256i op_07 = _mm256_unpacklo_epi16(o, p);
__m256i ab_8f = _mm256_unpackhi_epi16(a, b);
__m256i cd_8f = _mm256_unpackhi_epi16(c, d);
__m256i ef_8f = _mm256_unpackhi_epi16(e, f);
__m256i gh_8f = _mm256_unpackhi_epi16(g, h);
__m256i ij_8f = _mm256_unpackhi_epi16(i, j);
__m256i kl_8f = _mm256_unpackhi_epi16(k, l);
__m256i mn_8f = _mm256_unpackhi_epi16(m, n);
__m256i op_8f = _mm256_unpackhi_epi16(o, p);
__m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
__m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
__m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
__m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
__m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
__m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
__m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
__m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
__m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
__m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
__m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
__m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
__m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
__m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
__m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
__m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
__m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
__m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
__m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
__m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
__m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
__m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
__m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
__m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
__m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
__m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
__m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
__m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
__m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
__m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
__m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
__m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
// NOTE: no unpacklo/hi instr in this case, so using permute instr.
__m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
__m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
__m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
__m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
__m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
__m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
__m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
__m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
__m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
__m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
__m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
__m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
__m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
__m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
__m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
__m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
kernel.packet[0] = _mm256_castsi256_ph(a_p_0);
kernel.packet[1] = _mm256_castsi256_ph(a_p_1);
kernel.packet[2] = _mm256_castsi256_ph(a_p_2);
kernel.packet[3] = _mm256_castsi256_ph(a_p_3);
kernel.packet[4] = _mm256_castsi256_ph(a_p_4);
kernel.packet[5] = _mm256_castsi256_ph(a_p_5);
kernel.packet[6] = _mm256_castsi256_ph(a_p_6);
kernel.packet[7] = _mm256_castsi256_ph(a_p_7);
kernel.packet[8] = _mm256_castsi256_ph(a_p_8);
kernel.packet[9] = _mm256_castsi256_ph(a_p_9);
kernel.packet[10] = _mm256_castsi256_ph(a_p_a);
kernel.packet[11] = _mm256_castsi256_ph(a_p_b);
kernel.packet[12] = _mm256_castsi256_ph(a_p_c);
kernel.packet[13] = _mm256_castsi256_ph(a_p_d);
kernel.packet[14] = _mm256_castsi256_ph(a_p_e);
kernel.packet[15] = _mm256_castsi256_ph(a_p_f);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 8>& kernel) {
EIGEN_ALIGN64 half in[8][16];
pstore<half>(in[0], kernel.packet[0]);
pstore<half>(in[1], kernel.packet[1]);
pstore<half>(in[2], kernel.packet[2]);
pstore<half>(in[3], kernel.packet[3]);
pstore<half>(in[4], kernel.packet[4]);
pstore<half>(in[5], kernel.packet[5]);
pstore<half>(in[6], kernel.packet[6]);
pstore<half>(in[7], kernel.packet[7]);
EIGEN_ALIGN64 half out[8][16];
for (int i = 0; i < 8; ++i) {
for (int j = 0; j < 8; ++j) {
out[i][j] = in[j][2 * i];
}
for (int j = 0; j < 8; ++j) {
out[i][j + 8] = in[j][2 * i + 1];
}
}
kernel.packet[0] = pload<Packet16h>(out[0]);
kernel.packet[1] = pload<Packet16h>(out[1]);
kernel.packet[2] = pload<Packet16h>(out[2]);
kernel.packet[3] = pload<Packet16h>(out[3]);
kernel.packet[4] = pload<Packet16h>(out[4]);
kernel.packet[5] = pload<Packet16h>(out[5]);
kernel.packet[6] = pload<Packet16h>(out[6]);
kernel.packet[7] = pload<Packet16h>(out[7]);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 4>& kernel) {
EIGEN_ALIGN64 half in[4][16];
pstore<half>(in[0], kernel.packet[0]);
pstore<half>(in[1], kernel.packet[1]);
pstore<half>(in[2], kernel.packet[2]);
pstore<half>(in[3], kernel.packet[3]);
EIGEN_ALIGN64 half out[4][16];
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
out[i][j] = in[j][4 * i];
}
for (int j = 0; j < 4; ++j) {
out[i][j + 4] = in[j][4 * i + 1];
}
for (int j = 0; j < 4; ++j) {
out[i][j + 8] = in[j][4 * i + 2];
}
for (int j = 0; j < 4; ++j) {
out[i][j + 12] = in[j][4 * i + 3];
}
}
kernel.packet[0] = pload<Packet16h>(out[0]);
kernel.packet[1] = pload<Packet16h>(out[1]);
kernel.packet[2] = pload<Packet16h>(out[2]);
kernel.packet[3] = pload<Packet16h>(out[3]);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8h, 8>& kernel) {
__m128i a = _mm_castph_si128(kernel.packet[0]);
__m128i b = _mm_castph_si128(kernel.packet[1]);
__m128i c = _mm_castph_si128(kernel.packet[2]);
__m128i d = _mm_castph_si128(kernel.packet[3]);
__m128i e = _mm_castph_si128(kernel.packet[4]);
__m128i f = _mm_castph_si128(kernel.packet[5]);
__m128i g = _mm_castph_si128(kernel.packet[6]);
__m128i h = _mm_castph_si128(kernel.packet[7]);
__m128i a03b03 = _mm_unpacklo_epi16(a, b);
__m128i c03d03 = _mm_unpacklo_epi16(c, d);
__m128i e03f03 = _mm_unpacklo_epi16(e, f);
__m128i g03h03 = _mm_unpacklo_epi16(g, h);
__m128i a47b47 = _mm_unpackhi_epi16(a, b);
__m128i c47d47 = _mm_unpackhi_epi16(c, d);
__m128i e47f47 = _mm_unpackhi_epi16(e, f);
__m128i g47h47 = _mm_unpackhi_epi16(g, h);
__m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03);
__m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03);
__m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03);
__m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03);
__m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47);
__m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47);
__m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47);
__m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47);
__m128i a0b0c0d0e0f0g0h0 = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01);
__m128i a1b1c1d1e1f1g1h1 = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01);
__m128i a2b2c2d2e2f2g2h2 = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23);
__m128i a3b3c3d3e3f3g3h3 = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23);
__m128i a4b4c4d4e4f4g4h4 = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45);
__m128i a5b5c5d5e5f5g5h5 = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45);
__m128i a6b6c6d6e6f6g6h6 = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67);
__m128i a7b7c7d7e7f7g7h7 = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67);
kernel.packet[0] = _mm_castsi128_ph(a0b0c0d0e0f0g0h0);
kernel.packet[1] = _mm_castsi128_ph(a1b1c1d1e1f1g1h1);
kernel.packet[2] = _mm_castsi128_ph(a2b2c2d2e2f2g2h2);
kernel.packet[3] = _mm_castsi128_ph(a3b3c3d3e3f3g3h3);
kernel.packet[4] = _mm_castsi128_ph(a4b4c4d4e4f4g4h4);
kernel.packet[5] = _mm_castsi128_ph(a5b5c5d5e5f5g5h5);
kernel.packet[6] = _mm_castsi128_ph(a6b6c6d6e6f6g6h6);
kernel.packet[7] = _mm_castsi128_ph(a7b7c7d7e7f7g7h7);
}
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8h, 4>& kernel) {
EIGEN_ALIGN32 Eigen::half in[4][8];
pstore<Eigen::half>(in[0], kernel.packet[0]);
pstore<Eigen::half>(in[1], kernel.packet[1]);
pstore<Eigen::half>(in[2], kernel.packet[2]);
pstore<Eigen::half>(in[3], kernel.packet[3]);
EIGEN_ALIGN32 Eigen::half out[4][8];
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
out[i][j] = in[j][2 * i];
}
for (int j = 0; j < 4; ++j) {
out[i][j + 4] = in[j][2 * i + 1];
}
}
kernel.packet[0] = pload<Packet8h>(out[0]);
kernel.packet[1] = pload<Packet8h>(out[1]);
kernel.packet[2] = pload<Packet8h>(out[2]);
kernel.packet[3] = pload<Packet8h>(out[3]);
}
// preverse
template <>
EIGEN_STRONG_INLINE Packet32h preverse(const Packet32h& a) {
return _mm512_permutexvar_ph(_mm512_set_epi16(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31),
a);
}
template <>
EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a) {
__m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
return _mm256_castsi256_ph(_mm256_insertf128_si256(
_mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castph_si256(a), 1), m)),
_mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castph_si256(a), 0), m), 1));
}
template <>
EIGEN_STRONG_INLINE Packet8h preverse(const Packet8h& a) {
__m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
return _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a), m));
}
// pscatter
template <>
EIGEN_STRONG_INLINE void pscatter<half, Packet32h>(half* to, const Packet32h& from, Index stride) {
EIGEN_ALIGN64 half aux[32];
pstore(aux, from);
EIGEN_UNROLL_LOOP
for (int i = 0; i < 32; i++) {
to[stride * i] = aux[i];
}
}
template <>
EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Packet16h& from, Index stride) {
EIGEN_ALIGN64 half aux[16];
pstore(aux, from);
to[stride * 0] = aux[0];
to[stride * 1] = aux[1];
to[stride * 2] = aux[2];
to[stride * 3] = aux[3];
to[stride * 4] = aux[4];
to[stride * 5] = aux[5];
to[stride * 6] = aux[6];
to[stride * 7] = aux[7];
to[stride * 8] = aux[8];
to[stride * 9] = aux[9];
to[stride * 10] = aux[10];
to[stride * 11] = aux[11];
to[stride * 12] = aux[12];
to[stride * 13] = aux[13];
to[stride * 14] = aux[14];
to[stride * 15] = aux[15];
}
template <>
EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride) {
EIGEN_ALIGN32 Eigen::half aux[8];
pstore(aux, from);
to[stride * 0] = aux[0];
to[stride * 1] = aux[1];
to[stride * 2] = aux[2];
to[stride * 3] = aux[3];
to[stride * 4] = aux[4];
to[stride * 5] = aux[5];
to[stride * 6] = aux[6];
to[stride * 7] = aux[7];
}
// pgather
template <>
EIGEN_STRONG_INLINE Packet32h pgather<Eigen::half, Packet32h>(const Eigen::half* from, Index stride) {
return _mm512_set_ph(from[31 * stride].x, from[30 * stride].x, from[29 * stride].x, from[28 * stride].x,
from[27 * stride].x, from[26 * stride].x, from[25 * stride].x, from[24 * stride].x,
from[23 * stride].x, from[22 * stride].x, from[21 * stride].x, from[20 * stride].x,
from[19 * stride].x, from[18 * stride].x, from[17 * stride].x, from[16 * stride].x,
from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x,
from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
}
template <>
EIGEN_STRONG_INLINE Packet16h pgather<Eigen::half, Packet16h>(const Eigen::half* from, Index stride) {
return _mm256_set_ph(from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x,
from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
}
template <>
EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride) {
return _mm_set_ph(from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x, from[3 * stride].x,
from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_PACKET_MATH_FP16_AVX512_H
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2022 Intel Corporation
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
#define EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
#if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
#define EIGEN_USE_AVX512_TRSM_KERNELS 1
#endif
// TRSM kernels currently unconditionally rely on malloc with AVX512.
// Disable them if malloc is explicitly disabled at compile-time.
#ifdef EIGEN_NO_MALLOC
#undef EIGEN_USE_AVX512_TRSM_KERNELS
#define EIGEN_USE_AVX512_TRSM_KERNELS 0
#endif
#if EIGEN_USE_AVX512_TRSM_KERNELS
#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
#endif
#if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
#endif
#else // EIGEN_USE_AVX512_TRSM_KERNELS == 0
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
#endif
// Need this for some std::min calls.
#ifdef min
#undef min
#endif
namespace Eigen {
namespace internal {
#define EIGEN_AVX_MAX_NUM_ACC (int64_t(24))
#define EIGEN_AVX_MAX_NUM_ROW (int64_t(8)) // Denoted L in code.
#define EIGEN_AVX_MAX_K_UNROL (int64_t(4))
#define EIGEN_AVX_B_LOAD_SETS (int64_t(2))
#define EIGEN_AVX_MAX_A_BCAST (int64_t(2))
typedef Packet16f vecFullFloat;
typedef Packet8d vecFullDouble;
typedef Packet8f vecHalfFloat;
typedef Packet4d vecHalfDouble;
// Compile-time unrolls are implemented here.
// Note: this depends on macros and typedefs above.
#include "TrsmUnrolls.inc"
#if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
/**
* For smaller problem sizes, and certain compilers, using the optimized kernels trsmKernelL/R directly
* is faster than the packed versions in TriangularSolverMatrix.h.
*
* The current heuristic is based on having having all arrays used in the largest gemm-update
* in triSolve fit in roughly L2Cap (percentage) of the L2 cache. These cutoffs are a bit conservative and could be
* larger for some trsm cases.
* The formula:
*
* (L*M + M*N + L*N)*sizeof(Scalar) < L2Cache*L2Cap
*
* L = number of rows to solve at a time
* N = number of rhs
* M = Dimension of triangular matrix
*
*/
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
#endif
#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
#if EIGEN_USE_AVX512_TRSM_R_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
#endif // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
#endif
#if EIGEN_USE_AVX512_TRSM_L_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
#endif
#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
#else // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
#endif // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
template <typename Scalar>
int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) {
const int64_t U3 = 3 * packet_traits<Scalar>::size;
const int64_t MaxNb = 5 * U3;
int64_t Nb = std::min(MaxNb, N);
double cutoff_d =
(((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb);
int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
}
#else // !(EIGEN_USE_AVX512_TRSM_KERNELS) || !(EIGEN_COMP_CLANG != 0)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
#endif
/**
* Used by gemmKernel for the case A/B row-major and C col-major.
*/
template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, Scalar *C_arr,
int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
EIGEN_UNUSED_VARIABLE(remN_);
EIGEN_UNUSED_VARIABLE(remM_);
using urolls = unrolls::trans<Scalar>;
constexpr int64_t U3 = urolls::PacketSize * 3;
constexpr int64_t U2 = urolls::PacketSize * 2;
constexpr int64_t U1 = urolls::PacketSize * 1;
static_assert(unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize");
static_assert(unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
urolls::template transpose<unrollN, 0>(zmm);
EIGEN_IF_CONSTEXPR(unrollN > U2) urolls::template transpose<unrollN, 2>(zmm);
EIGEN_IF_CONSTEXPR(unrollN > U1) urolls::template transpose<unrollN, 1>(zmm);
static_assert((remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1");
EIGEN_IF_CONSTEXPR(!remN) {
urolls::template storeC<std::min(unrollN, U1), unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
EIGEN_IF_CONSTEXPR(unrollN > U1) {
constexpr int64_t unrollN_ = std::min(unrollN - U1, U1);
urolls::template storeC<unrollN_, unrollN, 1, remM>(C_arr + U1 * LDC, LDC, zmm, remM_);
}
EIGEN_IF_CONSTEXPR(unrollN > U2) {
constexpr int64_t unrollN_ = std::min(unrollN - U2, U1);
urolls::template storeC<unrollN_, unrollN, 2, remM>(C_arr + U2 * LDC, LDC, zmm, remM_);
}
}
else {
EIGEN_IF_CONSTEXPR((std::is_same<Scalar, float>::value)) {
// Note: without "if constexpr" this section of code will also be
// parsed by the compiler so each of the storeC will still be instantiated.
// We use enable_if in aux_storeC to set it to an empty function for
// these cases.
if (remN_ == 15)
urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 14)
urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 13)
urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 12)
urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 11)
urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 10)
urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 9)
urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 8)
urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 7)
urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 6)
urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 5)
urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 4)
urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 3)
urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 2)
urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 1)
urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
}
else {
if (remN_ == 7)
urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 6)
urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 5)
urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 4)
urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 3)
urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 2)
urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
else if (remN_ == 1)
urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
}
}
}
/**
* GEMM like operation for trsm panel updates.
* Computes: C -= A*B
* K must be multiple of 4.
*
* Unrolls used are {1,2,4,8}x{U1,U2,U3};
* For good performance we want K to be large with M/N relatively small, but also large enough
* to use the {8,U3} unroll block.
*
* isARowMajor: is A_arr row-major?
* isCRowMajor: is C_arr row-major? (B_arr is assumed to be row-major).
* isAdd: C += A*B or C -= A*B (used by trsm)
* handleKRem: Handle arbitrary K? This is not needed for trsm.
*/
template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB,
int64_t LDC) {
using urolls = unrolls::gemm<Scalar, isAdd>;
constexpr int64_t U3 = urolls::PacketSize * 3;
constexpr int64_t U2 = urolls::PacketSize * 2;
constexpr int64_t U1 = urolls::PacketSize * 1;
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
int64_t N_ = (N / U3) * U3;
int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
int64_t j = 0;
for (; j < N_; j += U3) {
constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 3;
int64_t i = 0;
for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 3,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, false, false>(zmm, &C_arr[i + j * LDC], LDC);
}
}
if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<3, 4>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 3, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
B_t, A_t, LDB, LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
}
i += 4;
}
if (M - i >= 2) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<3, 2>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 3, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
B_t, A_t, LDB, LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
}
i += 2;
}
if (M - i > 0) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<3, 1>(zmm);
{
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
}
}
}
}
if (N - j >= U2) {
constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 2;
int64_t i = 0;
for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, false, false>(zmm, &C_arr[i + j * LDC], LDC);
}
}
if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<2, 4>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 2, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
}
i += 4;
}
if (M - i >= 2) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<2, 2>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 2, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
}
i += 2;
}
if (M - i > 0) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<2, 1>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
}
}
j += U2;
}
if (N - j >= U1) {
constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
int64_t i = 0;
for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 1,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, false>(zmm, &C_arr[i + j * LDC], LDC);
}
}
if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, 4>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
}
i += 4;
}
if (M - i >= 2) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, 2>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
}
i += 2;
}
if (M - i > 0) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, 1>(zmm);
{
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
LDA, zmm);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
urolls::template storeC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
}
}
}
j += U1;
}
if (N - j > 0) {
constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
int64_t i = 0;
for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, true>(zmm, &C_arr[i + j * LDC], LDC, 0, N - j);
}
}
if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, 4>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
B_t, A_t, LDB, LDA, zmm, N - j);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
urolls::template storeC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 4, N - j);
}
i += 4;
}
if (M - i >= 2) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, 2>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
B_t, A_t, LDB, LDA, zmm, N - j);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
urolls::template storeC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 2, N - j);
}
i += 2;
}
if (M - i > 0) {
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
Scalar *B_t = &B_arr[0 * LDB + j];
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
urolls::template setzero<1, 1>(zmm);
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
B_t, A_t, LDB, LDA, zmm, N - j);
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
}
EIGEN_IF_CONSTEXPR(handleKRem) {
for (int64_t k = K_; k < K; k++) {
urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
N - j);
B_t += LDB;
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
else A_t += LDA;
}
}
EIGEN_IF_CONSTEXPR(isCRowMajor) {
urolls::template updateC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
urolls::template storeC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
}
else {
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 1, N - j);
}
}
}
}
/**
* Triangular solve kernel with A on left with K number of rhs. dim(A) = unrollM
*
* unrollM: dimension of A matrix (triangular matrix). unrollM should be <= EIGEN_AVX_MAX_NUM_ROW
* isFWDSolve: is forward solve?
* isUnitDiag: is the diagonal of A all ones?
* The B matrix (RHS) is assumed to be row-major
*/
template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
using urolls = unrolls::trsm<Scalar>;
constexpr int64_t U3 = urolls::PacketSize * 3;
constexpr int64_t U2 = urolls::PacketSize * 2;
constexpr int64_t U1 = urolls::PacketSize * 1;
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
int64_t k = 0;
while (K - k >= U3) {
urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
AInPacket);
urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
k += U3;
}
if (K - k >= U2) {
urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
AInPacket);
urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
k += U2;
}
if (K - k >= U1) {
urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
AInPacket);
urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
k += U1;
}
if (K - k > 0) {
// Handle remaining number of RHS
urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
AInPacket);
urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
}
}
/**
* Triangular solve routine with A on left and dimension of at most L with K number of rhs. This is essentially
* a wrapper for triSolveMicrokernel for M = {1,2,3,4,5,6,7,8}.
*
* isFWDSolve: is forward solve?
* isUnitDiag: is the diagonal of A all ones?
* The B matrix (RHS) is assumed to be row-major
*/
template <typename Scalar, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) {
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
if (M == 8)
triSolveKernel<Scalar, vec, 8, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
else if (M == 7)
triSolveKernel<Scalar, vec, 7, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
else if (M == 6)
triSolveKernel<Scalar, vec, 6, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
else if (M == 5)
triSolveKernel<Scalar, vec, 5, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
else if (M == 4)
triSolveKernel<Scalar, vec, 4, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
else if (M == 3)
triSolveKernel<Scalar, vec, 3, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
else if (M == 2)
triSolveKernel<Scalar, vec, 2, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
else if (M == 1)
triSolveKernel<Scalar, vec, 1, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
return;
}
/**
* This routine is used to copy B to/from a temporary array (row-major) for cases where B is column-major.
*
* toTemp: true => copy to temporary array, false => copy from temporary array
* remM: true = need to handle remainder values for M (M < EIGEN_AVX_MAX_NUM_ROW)
*
*/
template <typename Scalar, bool toTemp = true, bool remM = false>
EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
int64_t remM_ = 0) {
EIGEN_UNUSED_VARIABLE(remM_);
using urolls = unrolls::transB<Scalar>;
using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> ymm;
constexpr int64_t U3 = urolls::PacketSize * 3;
constexpr int64_t U2 = urolls::PacketSize * 2;
constexpr int64_t U1 = urolls::PacketSize * 1;
int64_t K_ = K / U3 * U3;
int64_t k = 0;
for (; k < K_; k += U3) {
urolls::template transB_kernel<U3, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
B_temp += U3;
}
if (K - k >= U2) {
urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
B_temp += U2;
k += U2;
}
if (K - k >= U1) {
urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
B_temp += U1;
k += U1;
}
EIGEN_IF_CONSTEXPR(U1 > 8) {
// Note: without "if constexpr" this section of code will also be
// parsed by the compiler so there is an additional check in {load/store}BBlock
// to make sure the counter is not non-negative.
if (K - k >= 8) {
urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
B_temp += 8;
k += 8;
}
}
EIGEN_IF_CONSTEXPR(U1 > 4) {
// Note: without "if constexpr" this section of code will also be
// parsed by the compiler so there is an additional check in {load/store}BBlock
// to make sure the counter is not non-negative.
if (K - k >= 4) {
urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
B_temp += 4;
k += 4;
}
}
if (K - k >= 2) {
urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
B_temp += 2;
k += 2;
}
if (K - k >= 1) {
urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
B_temp += 1;
k += 1;
}
}
/**
* Main triangular solve driver
*
* Triangular solve with A on the left.
* Scalar: Scalar precision, only float/double is supported.
* isARowMajor: is A row-major?
* isBRowMajor: is B row-major?
* isFWDSolve: is this forward solve or backward (true => forward)?
* isUnitDiag: is diagonal of A unit or nonunit (true => A has unit diagonal)?
*
* M: dimension of A
* numRHS: number of right hand sides (coincides with K dimension for gemm updates)
*
* Here are the mapping between the different TRSM cases (col-major) and triSolve:
*
* LLN (left , lower, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=true
* LUT (left , upper, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=true
* LUN (left , upper, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=false
* LLT (left , lower, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=false
* RUN (right, upper, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=true
* RLT (right, lower, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=true
* RUT (right, upper, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=false
* RLN (right, lower, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=false
*
* Note: For RXX cases M,numRHS should be swapped.
*
*/
template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
bool isUnitDiag = false>
void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
constexpr int64_t psize = packet_traits<Scalar>::size;
/**
* The values for kB, numM were determined experimentally.
* kB: Number of RHS we process at a time.
* numM: number of rows of B we will store in a temporary array (see below.) This should be a multiple of L.
*
* kB was determined by initially setting kB = numRHS and benchmarking triSolve (TRSM-RUN case)
* performance with M=numRHS.
* It was observed that performance started to drop around M=numRHS=240. This is likely machine dependent.
*
* numM was chosen "arbitrarily". It should be relatively small so B_temp is not too large, but it should be
* large enough to allow GEMM updates to have larger "K"s (see below.) No benchmarking has been done so far to
* determine optimal values for numM.
*/
constexpr int64_t kB = (3 * psize) * 5; // 5*U3
constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
int64_t sizeBTemp = 0;
Scalar *B_temp = NULL;
EIGEN_IF_CONSTEXPR(!isBRowMajor) {
/**
* If B is col-major, we copy it to a fixed-size temporary array of size at most ~numM*kB and
* transpose it to row-major. Call the solve routine, and copy+transpose it back to the original array.
* The updated row-major copy of B is reused in the GEMM updates.
*/
sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
}
EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);
for (int64_t k = 0; k < numRHS; k += kB) {
int64_t bK = numRHS - k > kB ? kB : numRHS - k;
int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;
// bK rounded up to next multiple of L=EIGEN_AVX_MAX_NUM_ROW. When B_temp is used, we solve for bkL RHS
// instead of bK RHS in triSolveKernelLxK.
int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW - 1)) / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
const int64_t numScalarPerCache = 64 / sizeof(Scalar);
// Leading dimension of B_temp, will be a multiple of the cache line size.
int64_t LDT = ((bkL + (numScalarPerCache - 1)) / numScalarPerCache) * numScalarPerCache;
int64_t offsetBTemp = 0;
for (int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
EIGEN_IF_CONSTEXPR(!isBRowMajor) {
int64_t indA_i = isFWDSolve ? i : M - 1 - i;
int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW);
int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW * LDT - offsetBTemp;
int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp;
// Copy values from B to B_temp.
copyBToRowMajor<Scalar, true, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
// Triangular solve with a small block of A and long horizontal blocks of B (or B_temp if B col-major)
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT);
// Copy values from B_temp back to B. B_temp will be reused in gemm call below.
copyBToRowMajor<Scalar, false, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
offsetBTemp += EIGEN_AVX_MAX_NUM_ROW * LDT;
}
else {
int64_t ind = isFWDSolve ? i : M - 1 - i;
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
&A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind * LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB);
}
if (i + EIGEN_AVX_MAX_NUM_ROW < M_) {
/**
* For the GEMM updates, we want "K" (K=i+8 in this case) to be large as soon as possible
* to reuse the accumulators in GEMM as much as possible. So we only update 8xbK blocks of
* B as follows:
*
* A B
* __
* |__|__ |__|
* |__|__|__ |__|
* |__|__|__|__ |__|
* |********|__| |**|
*/
EIGEN_IF_CONSTEXPR(isBRowMajor) {
int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB);
}
else {
if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) {
/**
* Similar idea as mentioned above, but here we are limited by the number of updated values of B
* that can be stored (row-major) in B_temp.
*
* If there is not enough space to store the next batch of 8xbK of B in B_temp, we call GEMM
* update and partially update the remaining old values of B which depends on the new values
* of B stored in B_temp. These values are then no longer needed and can be overwritten.
*/
int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
offsetBTemp = 0;
gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
} else {
/**
* If there is enough space in B_temp, we only update the next 8xbK values of B.
*/
int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
}
}
}
}
// Handle M remainder..
int64_t bM = M - M_;
if (bM > 0) {
if (M_ > 0) {
EIGEN_IF_CONSTEXPR(isBRowMajor) {
int64_t indA_i = isFWDSolve ? M_ : 0;
int64_t indA_j = isFWDSolve ? 0 : bM;
int64_t indB_i = isFWDSolve ? 0 : bM;
int64_t indB_i2 = isFWDSolve ? M_ : 0;
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
bK, M_, LDA, LDB, LDB);
}
else {
int64_t indA_i = isFWDSolve ? M_ : 0;
int64_t indA_j = isFWDSolve ? gemmOff : bM;
int64_t indB_i = isFWDSolve ? M_ : 0;
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
B_temp + offB_1, B_arr + indB_i + (k)*LDB, bM, bK,
M_ - gemmOff, LDA, LDT, LDB);
}
}
EIGEN_IF_CONSTEXPR(!isBRowMajor) {
int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_;
int64_t indB_i = isFWDSolve ? M_ : 0;
int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
B_temp + offB_1, bM, bkL, LDA, bkL);
copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
}
else {
int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
B_arr + k + ind * LDB, bM, bK, LDA, LDB);
}
}
}
EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
}
// Template specializations of trsmKernelL/R for float/double and inner strides of 1.
#if (EIGEN_USE_AVX512_TRSM_KERNELS)
#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
bool Specialized>
struct trsmKernelR;
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true> {
static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true> {
static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
Index otherStride) {
EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
if (!is_malloc_allowed()) {
trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
return;
}
#endif
triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
Index otherStride) {
EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
if (!is_malloc_allowed()) {
trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
return;
}
#endif
triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif // (EIGEN_USE_AVX512_TRSM_R_KERNELS)
// These trsm kernels require temporary memory allocation
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
bool Specialized = true>
struct trsmKernelL;
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true> {
static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true> {
static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
Index otherStride) {
EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
if (!is_malloc_allowed()) {
trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
return;
}
#endif
triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
Index otherStride) {
EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
if (!is_malloc_allowed()) {
trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
return;
}
#endif
triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
#endif // EIGEN_USE_AVX512_TRSM_KERNELS
} // namespace internal
} // namespace Eigen
#endif // EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2022 Intel Corporation
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
template <bool isARowMajor = true>
EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
else return i + j * LDA;
}
/**
* This namespace contains various classes used to generate compile-time unrolls which are
* used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested
* for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion
*
* Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop.
*
* for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++)
* for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ)
* func(startI,startJ) startJ = (startC)%(endJ)
* func(...)
*
* The 1-D loop can be unrolled recursively by using enable_if and defining an auxiliary function
* with a template parameter used as a counter.
*
* template <endI, endJ, counter>
* std::enable_if_t<(counter <= 0)> <---- tail case.
* aux_func {}
*
* template <endI, endJ, counter>
* std::enable_if_t<(counter > 0)> <---- actual for-loop
* aux_func {
* startC = endI*endJ - counter
* startI = (startC)/(endJ)
* startJ = (startC)%(endJ)
* func(startI, startJ)
* aux_func<endI, endJ, counter-1>()
* }
*
* Note: Additional wrapper functions are provided for aux_func which hides the counter template
* parameter since counter usually depends on endI, endJ, etc...
*
* Conventions:
* 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++))
*
* 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for
* handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW)
*/
namespace unrolls {
template <int64_t N>
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
else EIGEN_IF_CONSTEXPR(N == 8) {
return 0xFF >> (8 - m);
}
else EIGEN_IF_CONSTEXPR(N == 4) {
return 0x0F >> (4 - m);
}
return 0;
}
template <typename Packet>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8> &kernel);
template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8> &kernel) {
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
__m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
__m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
__m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
__m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
__m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
__m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);
kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
kernel.packet[0] = T0;
kernel.packet[1] = T1;
kernel.packet[2] = T2;
kernel.packet[3] = T3;
kernel.packet[4] = T4;
kernel.packet[5] = T5;
kernel.packet[6] = T6;
kernel.packet[7] = T7;
}
template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8> &kernel) {
ptranspose(kernel);
}
/***
* Unrolls for transposed C stores
*/
template <typename Scalar>
class trans {
public:
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
/***********************************
* Auxiliary Functions for:
* - storeC
***********************************
*/
/**
* aux_storeC
*
* 1-D unroll
* for(startN = 0; startN < endN; startN++)
*
* (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC
*
**/
template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN - counter;
constexpr int64_t startN = counterReverse;
EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
EIGEN_IF_CONSTEXPR(remM) {
pstoreu<Scalar>(
C_arr + LDC * startN,
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]),
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
}
else {
pstoreu<Scalar>(C_arr + LDC * startN,
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
}
}
else { // This block is only needed for fp32 case
// Reinterpret as __m512 for _mm512_shuffle_f32x4
vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]);
// Swap lower and upper half of avx register.
zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] =
preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
EIGEN_IF_CONSTEXPR(remM) {
pstoreu<Scalar>(
C_arr + LDC * startN,
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
preinterpret<vecHalf>(
zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
}
else {
pstoreu<Scalar>(
C_arr + LDC * startN,
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
preinterpret<vecHalf>(
zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
}
}
aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
}
template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
EIGEN_UNUSED_VARIABLE(C_arr);
EIGEN_UNUSED_VARIABLE(LDC);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(remM_);
}
template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
int64_t remM_ = 0) {
aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
}
/**
* Transposes LxunrollN row major block of matrices stored `EIGEN_AVX_MAX_NUM_ACC` zmm registers to
* "unrollN"xL ymm registers to be stored col-major into C.
*
* For 8x48, the 8x48 block (row-major) is stored in zmm as follows:
*
* ```
* row0: zmm0 zmm1 zmm2
* row1: zmm3 zmm4 zmm5
* .
* .
* row7: zmm21 zmm22 zmm23
*
* For 8x32, the 8x32 block (row-major) is stored in zmm as follows:
*
* row0: zmm0 zmm1
* row1: zmm2 zmm3
* .
* .
* row7: zmm14 zmm15
* ```
*
* In general we will have {1,2,3} groups of avx registers each of size
* `EIGEN_AVX_MAX_NUM_ROW`. packetIndexOffset is used to select which "block" of
* avx registers are being transposed.
*/
template <int64_t unrollN, int64_t packetIndexOffset>
static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
constexpr int64_t zmmStride = unrollN / PacketSize;
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> r;
r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0];
r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1];
r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2];
r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3];
r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4];
r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5];
r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6];
r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7];
trans8x8blocks(r);
zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0];
zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1];
zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2];
zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3];
zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4];
zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5];
zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6];
zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7];
}
};
/**
* Unrolls for copyBToRowMajor
*
* Idea:
* 1) Load a block of right-hand sides to registers (using loadB).
* 2) Convert the block from column-major to row-major (transposeLxL)
* 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false).
*
* We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are
* used as temps for transposing.
*
* Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
* For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm).
*/
template <typename Scalar>
class transB {
public:
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
/***********************************
* Auxiliary Functions for:
* - loadB
* - storeB
* - loadBBlock
* - storeBBlock
***********************************
*/
/**
* aux_loadB
*
* 1-D unroll
* for(startN = 0; startN < endN; startN++)
**/
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN - counter;
constexpr int64_t startN = counterReverse;
EIGEN_IF_CONSTEXPR(remM) {
ymm.packet[packetIndexOffset + startN] =
ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
}
else {
EIGEN_IF_CONSTEXPR(remN_ == 0) {
ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
}
else ymm.packet[packetIndexOffset + startN] =
ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
}
aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
}
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) {
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(ymm);
EIGEN_UNUSED_VARIABLE(remM_);
}
/**
* aux_storeB
*
* 1-D unroll
* for(startN = 0; startN < endN; startN++)
**/
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
constexpr int64_t counterReverse = endN - counter;
constexpr int64_t startN = counterReverse;
EIGEN_IF_CONSTEXPR(remK || remM) {
pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
}
else {
pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]);
}
aux_storeB<endN, counter - 1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
}
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(ymm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/**
* aux_loadBBlock
*
* 1-D unroll
* for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
**/
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN - counter;
constexpr int64_t startN = counterReverse;
transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(B_temp);
EIGEN_UNUSED_VARIABLE(LDB_);
EIGEN_UNUSED_VARIABLE(ymm);
EIGEN_UNUSED_VARIABLE(remM_);
}
/**
* aux_storeBBlock
*
* 1-D unroll
* for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
**/
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
constexpr int64_t counterReverse = endN - counter;
constexpr int64_t startN = counterReverse;
EIGEN_IF_CONSTEXPR(toTemp) {
transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
}
else {
transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
ymm, remM_);
}
aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(B_temp);
EIGEN_UNUSED_VARIABLE(LDB_);
EIGEN_UNUSED_VARIABLE(ymm);
EIGEN_UNUSED_VARIABLE(remM_);
}
/********************************************************
* Wrappers for aux_XXXX to hide counter parameter
********************************************************/
template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) {
aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
}
template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t rem_ = 0) {
aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
}
template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) {
EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
else {
aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
}
template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) {
aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
template <int64_t packetIndexOffset>
static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
r.packet[0] = ymm.packet[packetIndexOffset + 0];
r.packet[1] = ymm.packet[packetIndexOffset + 1];
r.packet[2] = ymm.packet[packetIndexOffset + 2];
r.packet[3] = ymm.packet[packetIndexOffset + 3];
r.packet[4] = ymm.packet[packetIndexOffset + 4];
r.packet[5] = ymm.packet[packetIndexOffset + 5];
r.packet[6] = ymm.packet[packetIndexOffset + 6];
r.packet[7] = ymm.packet[packetIndexOffset + 7];
ptranspose(r);
ymm.packet[packetIndexOffset + 0] = r.packet[0];
ymm.packet[packetIndexOffset + 1] = r.packet[1];
ymm.packet[packetIndexOffset + 2] = r.packet[2];
ymm.packet[packetIndexOffset + 3] = r.packet[3];
ymm.packet[packetIndexOffset + 4] = r.packet[4];
ymm.packet[packetIndexOffset + 5] = r.packet[5];
ymm.packet[packetIndexOffset + 6] = r.packet[6];
ymm.packet[packetIndexOffset + 7] = r.packet[7];
}
template <int64_t unrollN, bool toTemp, bool remM>
static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
int64_t remM_ = 0) {
constexpr int64_t U3 = PacketSize * 3;
constexpr int64_t U2 = PacketSize * 2;
constexpr int64_t U1 = PacketSize * 1;
/**
* Unrolls needed for each case:
* - AVX512 fp32 48 32 16 8 4 2 1
* - AVX512 fp64 24 16 8 4 2 1
*
* For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up.
*/
EIGEN_IF_CONSTEXPR(unrollN == U3) {
// load LxU3 B col major, transpose LxU3 row major
constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3);
transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
ymm, remM_);
transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
ymm, remM_);
}
}
else EIGEN_IF_CONSTEXPR(unrollN == U2) {
// load LxU2 B col major, transpose LxU2 row major
constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2);
transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
&B_temp[maxUBlock], LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
&B_temp[maxUBlock], LDB_, ymm, remM_);
}
}
else EIGEN_IF_CONSTEXPR(unrollN == U1) {
// load LxU1 B col major, transpose LxU1 row major
transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
// load Lx4 B col major, transpose Lx4 row major
transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
// load Lx4 B col major, transpose Lx4 row major
transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
else EIGEN_IF_CONSTEXPR(unrollN == 2) {
// load Lx2 B col major, transpose Lx2 row major
transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
else EIGEN_IF_CONSTEXPR(unrollN == 1) {
// load Lx1 B col major, transpose Lx1 row major
transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
transB::template transposeLxL<0>(ymm);
transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
}
}
};
/**
* Unrolls for triSolveKernel
*
* Idea:
* 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS).
* 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix)
* stored in AInPacket (using triSolveMicroKernel).
* 3) Store final results (in avx registers) back into memory (using storeRHS).
*
* RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most
* EIGEN_AVX_MAX_NUM_ROW registers.
*/
template <typename Scalar>
class trsm {
public:
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
/***********************************
* Auxiliary Functions for:
* - loadRHS
* - storeRHS
* - divRHSByDiag
* - updateRHS
* - triSolveMicroKernel
************************************/
/**
* aux_loadRHS
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startK = 0; startK < endK; startK++)
**/
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
constexpr int64_t counterReverse = endM * endK - counter;
constexpr int64_t startM = counterReverse / (endK);
constexpr int64_t startK = counterReverse % endK;
constexpr int64_t packetIndex = startM * endK + startK;
constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
EIGEN_IF_CONSTEXPR(krem) {
RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
}
else {
RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
}
aux_loadRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
}
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(RHSInPacket);
EIGEN_UNUSED_VARIABLE(rem);
}
/**
* aux_storeRHS
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startK = 0; startK < endK; startK++)
**/
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
constexpr int64_t counterReverse = endM * endK - counter;
constexpr int64_t startM = counterReverse / (endK);
constexpr int64_t startK = counterReverse % endK;
constexpr int64_t packetIndex = startM * endK + startK;
constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
EIGEN_IF_CONSTEXPR(krem) {
pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
}
else {
pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
}
aux_storeRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
}
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
EIGEN_UNUSED_VARIABLE(B_arr);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(RHSInPacket);
EIGEN_UNUSED_VARIABLE(rem);
}
/**
* aux_divRHSByDiag
*
* currM may be -1, (currM >=0) in enable_if checks for this
*
* 1-D unroll
* for(startK = 0; startK < endK; startK++)
**/
template <int64_t currM, int64_t endK, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
constexpr int64_t counterReverse = endK - counter;
constexpr int64_t startK = counterReverse;
constexpr int64_t packetIndex = currM * endK + startK;
RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
aux_divRHSByDiag<currM, endK, counter - 1>(RHSInPacket, AInPacket);
}
template <int64_t currM, int64_t endK, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
EIGEN_UNUSED_VARIABLE(RHSInPacket);
EIGEN_UNUSED_VARIABLE(AInPacket);
}
/**
* aux_updateRHS
*
* 2-D unroll
* for(startM = initM; startM < endM; startM++)
* for(startK = 0; startK < endK; startK++)
**/
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
int64_t counter, int64_t currentM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
constexpr int64_t counterReverse = (endM - initM) * endK - counter;
constexpr int64_t startM = initM + counterReverse / (endK);
constexpr int64_t startK = counterReverse % endK;
// For each row of A, first update all corresponding RHS
constexpr int64_t packetIndex = startM * endK + startK;
EIGEN_IF_CONSTEXPR(currentM > 0) {
RHSInPacket.packet[packetIndex] =
pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
RHSInPacket.packet[packetIndex]);
}
EIGEN_IF_CONSTEXPR(startK == endK - 1) {
// Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}.
EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
// If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM].
// This will be used in divRHSByDiag
EIGEN_IF_CONSTEXPR(isFWDSolve)
AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
}
else {
// Broadcast next off diagonal element of A
EIGEN_IF_CONSTEXPR(isFWDSolve)
AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
}
}
aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
A_arr, LDA, RHSInPacket, AInPacket);
}
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
int64_t counter, int64_t currentM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
EIGEN_UNUSED_VARIABLE(A_arr);
EIGEN_UNUSED_VARIABLE(LDA);
EIGEN_UNUSED_VARIABLE(RHSInPacket);
EIGEN_UNUSED_VARIABLE(AInPacket);
}
/**
* aux_triSolverMicroKernel
*
* 1-D unroll
* for(startM = 0; startM < endM; startM++)
**/
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
constexpr int64_t counterReverse = endM - counter;
constexpr int64_t startM = counterReverse;
constexpr int64_t currentM = startM;
// Divides the right-hand side in row startM, by digonal value of A
// broadcasted to AInPacket.packet[startM-1] in the previous iteration.
//
// Without "if constexpr" the compiler instantiates the case <-1, numK>
// this is handled with enable_if to prevent out-of-bound warnings
// from the compiler
EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
trsm::template divRHSByDiag<startM - 1, numK>(RHSInPacket, AInPacket);
// After division, the rhs corresponding to subsequent rows of A can be partially updated
// We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
// to be used in the next iteration.
trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
AInPacket);
// Handle division for the RHS corresponding to the final row of A.
EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);
aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
AInPacket);
}
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
EIGEN_UNUSED_VARIABLE(A_arr);
EIGEN_UNUSED_VARIABLE(LDA);
EIGEN_UNUSED_VARIABLE(RHSInPacket);
EIGEN_UNUSED_VARIABLE(AInPacket);
}
/********************************************************
* Wrappers for aux_XXXX to hide counter parameter
********************************************************/
/**
* Load endMxendK block of B to RHSInPacket
* Masked loads are used for cases where endK is not a multiple of PacketSize
*/
template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
}
/**
* Load endMxendK block of B to RHSInPacket
* Masked loads are used for cases where endK is not a multiple of PacketSize
*/
template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
}
/**
* Only used if Triangular matrix has non-unit diagonal values
*/
template <int64_t currM, int64_t endK>
static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
}
/**
* Update right-hand sides (stored in avx registers)
* Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket.
**/
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
int64_t currentM>
static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
A_arr, LDA, RHSInPacket, AInPacket);
}
/**
* endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW
* numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= endK <= 3.
* isFWDSolve: true => forward substitution, false => backwards substitution
* isUnitDiag: true => triangular matrix has unit diagonal.
*/
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
static_assert(numK >= 1 && numK <= 3, "numK out of range");
aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
}
};
/**
* Unrolls for gemm kernel
*
* isAdd: true => C += A*B, false => C -= A*B
*/
template <typename Scalar, bool isAdd>
class gemm {
public:
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
/***********************************
* Auxiliary Functions for:
* - setzero
* - updateC
* - storeC
* - startLoadB
* - triSolveMicroKernel
************************************/
/**
* aux_setzero
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
**/
template <int64_t endM, int64_t endN, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
constexpr int64_t counterReverse = endM * endN - counter;
constexpr int64_t startM = counterReverse / (endN);
constexpr int64_t startN = counterReverse % endN;
zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]);
aux_setzero<endM, endN, counter - 1>(zmm);
}
template <int64_t endM, int64_t endN, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
EIGEN_UNUSED_VARIABLE(zmm);
}
/**
* aux_updateC
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
**/
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endM * endN - counter;
constexpr int64_t startM = counterReverse / (endN);
constexpr int64_t startN = counterReverse % endN;
EIGEN_IF_CONSTEXPR(rem)
zmm.packet[startN * endM + startM] =
padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
else zmm.packet[startN * endM + startM] =
padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
}
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(C_arr);
EIGEN_UNUSED_VARIABLE(LDC);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/**
* aux_storeC
*
* 2-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
**/
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endM * endN - counter;
constexpr int64_t startM = counterReverse / (endN);
constexpr int64_t startN = counterReverse % endN;
EIGEN_IF_CONSTEXPR(rem)
pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
remMask<PacketSize>(rem_));
else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
}
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(C_arr);
EIGEN_UNUSED_VARIABLE(LDC);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/**
* aux_startLoadB
*
* 1-D unroll
* for(startL = 0; startL < endL; startL++)
**/
template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endL - counter;
constexpr int64_t startL = counterReverse;
EIGEN_IF_CONSTEXPR(rem)
zmm.packet[unrollM * unrollN + startL] =
ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
else zmm.packet[unrollM * unrollN + startL] =
ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);
aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
}
template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(B_t);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/**
* aux_startBCastA
*
* 1-D unroll
* for(startB = 0; startB < endB; startB++)
**/
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
constexpr int64_t counterReverse = endB - counter;
constexpr int64_t startB = counterReverse;
zmm.packet[unrollM * unrollN + numLoad + startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0, LDA)]);
aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter - 1, numLoad>(A_t, LDA, zmm);
}
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
EIGEN_UNUSED_VARIABLE(A_t);
EIGEN_UNUSED_VARIABLE(LDA);
EIGEN_UNUSED_VARIABLE(zmm);
}
/**
* aux_loadB
* currK: current K
*
* 1-D unroll
* for(startM = 0; startM < endM; startM++)
**/
template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
if ((numLoad / endM + currK < unrollK)) {
constexpr int64_t counterReverse = endM - counter;
constexpr int64_t startM = counterReverse;
EIGEN_IF_CONSTEXPR(rem) {
zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask<PacketSize>(rem_));
}
else {
zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]);
}
aux_loadB<endM, counter - 1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
}
}
template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(B_t);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/**
* aux_microKernel
*
* 3-D unroll
* for(startM = 0; startM < endM; startM++)
* for(startN = 0; startN < endN; startN++)
* for(startK = 0; startK < endK; startK++)
**/
template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
constexpr int64_t counterReverse = endM * endN * endK - counter;
constexpr int startK = counterReverse / (endM * endN);
constexpr int startN = (counterReverse / (endM)) % endN;
constexpr int startM = counterReverse % endM;
EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
}
{
// Interleave FMA and Bcast
EIGEN_IF_CONSTEXPR(isAdd) {
zmm.packet[startN * endM + startM] =
pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
}
else {
zmm.packet[startN * endM + startM] =
pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
}
// Bcast
EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
(numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
}
}
// We have updated all accumulators, time to load next set of B's
EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) {
gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
}
aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
}
template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(B_t);
EIGEN_UNUSED_VARIABLE(A_t);
EIGEN_UNUSED_VARIABLE(LDB);
EIGEN_UNUSED_VARIABLE(LDA);
EIGEN_UNUSED_VARIABLE(zmm);
EIGEN_UNUSED_VARIABLE(rem_);
}
/********************************************************
* Wrappers for aux_XXXX to hide counter parameter
********************************************************/
template <int64_t endM, int64_t endN>
static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
aux_setzero<endM, endN, endM * endN>(zmm);
}
/**
* Ideally the compiler folds these into vaddp{s,d} with an embedded memory load.
*/
template <int64_t endM, int64_t endN, bool rem = false>
static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
}
template <int64_t endM, int64_t endN, bool rem = false>
static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
}
/**
* Use numLoad registers for loading B at start of microKernel
*/
template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
}
/**
* Use numBCast registers for broadcasting A at start of microKernel
*/
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
}
/**
* Loads next set of B into vector registers between each K unroll.
*/
template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
}
/**
* Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute C -= A*B.
* A matrix can be row/col-major. B matrix is assumed row-major.
*
* isARowMajor: is A row major
* endM: Number registers per row
* endN: Number of rows
* endK: Loop unroll for K.
* numLoad: Number of registers for loading B.
* numBCast: Number of registers for broadcasting A.
*
* Ex: microkernel<isARowMajor,0,3,0,4,0,4,6,2>: 8x48 unroll (24 accumulators), k unrolled 4 times,
* 6 register for loading B, 2 for broadcasting A.
*
* Note: Ideally the microkernel should not have any register spilling.
* The avx instruction counts should be:
* - endK*endN vbroadcasts{s,d}
* - endK*endM vmovup{s,d}
* - endK*endN*endM FMAs
*
* From testing, there are no register spills with clang. There are register spills with GNU, which
* causes a performance hit.
*/
template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
bool rem = false>
static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
int64_t rem_ = 0) {
EIGEN_UNUSED_VARIABLE(rem_);
aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
rem_);
}
};
} // namespace unrolls
#endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_TYPE_CASTING_AVX512_H
#define EIGEN_TYPE_CASTING_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
template <>
struct type_casting_traits<float, bool> : vectorized_type_casting_traits<float, bool> {};
template <>
struct type_casting_traits<bool, float> : vectorized_type_casting_traits<bool, float> {};
template <>
struct type_casting_traits<float, int> : vectorized_type_casting_traits<float, int> {};
template <>
struct type_casting_traits<int, float> : vectorized_type_casting_traits<int, float> {};
template <>
struct type_casting_traits<float, double> : vectorized_type_casting_traits<float, double> {};
template <>
struct type_casting_traits<double, float> : vectorized_type_casting_traits<double, float> {};
template <>
struct type_casting_traits<double, int> : vectorized_type_casting_traits<double, int> {};
template <>
struct type_casting_traits<int, double> : vectorized_type_casting_traits<int, double> {};
template <>
struct type_casting_traits<double, int64_t> : vectorized_type_casting_traits<double, int64_t> {};
template <>
struct type_casting_traits<int64_t, double> : vectorized_type_casting_traits<int64_t, double> {};
template <>
struct type_casting_traits<half, float> : vectorized_type_casting_traits<half, float> {};
template <>
struct type_casting_traits<float, half> : vectorized_type_casting_traits<float, half> {};
template <>
struct type_casting_traits<bfloat16, float> : vectorized_type_casting_traits<bfloat16, float> {};
template <>
struct type_casting_traits<float, bfloat16> : vectorized_type_casting_traits<float, bfloat16> {};
template <>
EIGEN_STRONG_INLINE Packet16b pcast<Packet16f, Packet16b>(const Packet16f& a) {
__mmask16 mask = _mm512_cmpneq_ps_mask(a, pzero(a));
return _mm512_maskz_cvtepi32_epi8(mask, _mm512_set1_epi32(1));
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16b, Packet16f>(const Packet16b& a) {
return _mm512_cvtepi32_ps(_mm512_and_si512(_mm512_cvtepi8_epi32(a), _mm512_set1_epi32(1)));
}
template <>
EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
return _mm512_cvttps_epi32(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pcast<Packet16f, Packet8d>(const Packet16f& a) {
return _mm512_cvtps_pd(_mm512_castps512_ps256(a));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcast<Packet8f, Packet8d>(const Packet8f& a) {
return _mm512_cvtps_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8l pcast<Packet8d, Packet8l>(const Packet8d& a) {
#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
return _mm512_cvttpd_epi64(a);
#else
constexpr int kTotalBits = sizeof(double) * CHAR_BIT, kMantissaBits = std::numeric_limits<double>::digits - 1,
kExponentBits = kTotalBits - kMantissaBits - 1, kBias = (1 << (kExponentBits - 1)) - 1;
const __m512i cst_one = _mm512_set1_epi64(1);
const __m512i cst_total_bits = _mm512_set1_epi64(kTotalBits);
const __m512i cst_bias = _mm512_set1_epi64(kBias);
__m512i a_bits = _mm512_castpd_si512(a);
// shift left by 1 to clear the sign bit, and shift right by kMantissaBits + 1 to recover biased exponent
__m512i biased_e = _mm512_srli_epi64(_mm512_slli_epi64(a_bits, 1), kMantissaBits + 1);
__m512i e = _mm512_sub_epi64(biased_e, cst_bias);
// shift to the left by kExponentBits + 1 to clear the sign and exponent bits
__m512i shifted_mantissa = _mm512_slli_epi64(a_bits, kExponentBits + 1);
// shift to the right by kTotalBits - e to convert the significand to an integer
__m512i result_significand = _mm512_srlv_epi64(shifted_mantissa, _mm512_sub_epi64(cst_total_bits, e));
// add the implied bit
__m512i result_exponent = _mm512_sllv_epi64(cst_one, e);
// e <= 0 is interpreted as a large positive shift (2's complement), which also conveniently results in zero
__m512i result = _mm512_add_epi64(result_significand, result_exponent);
// handle negative arguments
__mmask8 sign_mask = _mm512_cmplt_epi64_mask(a_bits, _mm512_setzero_si512());
result = _mm512_mask_sub_epi64(result, sign_mask, _mm512_setzero_si512(), result);
return result;
#endif
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
return _mm512_cvtepi32_ps(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pcast<Packet16i, Packet8d>(const Packet16i& a) {
return _mm512_cvtepi32_pd(_mm512_castsi512_si256(a));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcast<Packet8i, Packet8d>(const Packet8i& a) {
return _mm512_cvtepi32_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d pcast<Packet8l, Packet8d>(const Packet8l& a) {
#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
return _mm512_cvtepi64_pd(a);
#else
EIGEN_ALIGN64 int64_t aux[8];
pstore(aux, a);
return _mm512_set_pd(static_cast<double>(aux[7]), static_cast<double>(aux[6]), static_cast<double>(aux[5]),
static_cast<double>(aux[4]), static_cast<double>(aux[3]), static_cast<double>(aux[2]),
static_cast<double>(aux[1]), static_cast<double>(aux[0]));
#endif
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet8d, Packet16f>(const Packet8d& a, const Packet8d& b) {
return cat256(_mm512_cvtpd_ps(a), _mm512_cvtpd_ps(b));
}
template <>
EIGEN_STRONG_INLINE Packet16i pcast<Packet8d, Packet16i>(const Packet8d& a, const Packet8d& b) {
return cat256i(_mm512_cvttpd_epi32(a), _mm512_cvttpd_epi32(b));
}
template <>
EIGEN_STRONG_INLINE Packet8i pcast<Packet8d, Packet8i>(const Packet8d& a) {
return _mm512_cvtpd_epi32(a);
}
template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet8d, Packet8f>(const Packet8d& a) {
return _mm512_cvtpd_ps(a);
}
template <>
EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
return _mm512_castps_si512(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet16i>(const Packet16i& a) {
return _mm512_castsi512_ps(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet16f>(const Packet16f& a) {
return _mm512_castps_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8l>(const Packet8l& a) {
return _mm512_castsi512_pd(a);
}
template <>
EIGEN_STRONG_INLINE Packet8l preinterpret<Packet8l, Packet8d>(const Packet8d& a) {
return _mm512_castpd_si512(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8d>(const Packet8d& a) {
return _mm512_castpd_ps(a);
}
template <>
EIGEN_STRONG_INLINE Packet8f preinterpret<Packet8f, Packet16f>(const Packet16f& a) {
return _mm512_castps512_ps256(a);
}
template <>
EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f, Packet16f>(const Packet16f& a) {
return _mm512_castps512_ps128(a);
}
template <>
EIGEN_STRONG_INLINE Packet4d preinterpret<Packet4d, Packet8d>(const Packet8d& a) {
return _mm512_castpd512_pd256(a);
}
template <>
EIGEN_STRONG_INLINE Packet2d preinterpret<Packet2d, Packet8d>(const Packet8d& a) {
return _mm512_castpd512_pd128(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet8f>(const Packet8f& a) {
return _mm512_castps256_ps512(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f preinterpret<Packet16f, Packet4f>(const Packet4f& a) {
return _mm512_castps128_ps512(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet4d>(const Packet4d& a) {
return _mm512_castpd256_pd512(a);
}
template <>
EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet2d>(const Packet2d& a) {
return _mm512_castpd128_pd512(a);
}
template <>
EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i, Packet16i>(const Packet16i& a) {
return _mm512_castsi512_si256(a);
}
template <>
EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet16i>(const Packet16i& a) {
return _mm512_castsi512_si128(a);
}
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet16h>(const Packet16h& a) {
return _mm256_castsi256_si128(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
return half2float(a);
}
template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
return float2half(a);
}
#endif
template <>
EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
return _mm256_castsi256_si128(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
return Bf16ToF32(a);
}
template <>
EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Packet16f& a) {
return F32ToBf16(a);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_TYPE_CASTING_AVX512_H
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 The Eigen Authors.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_TYPE_CASTING_FP16_AVX512_H
#define EIGEN_TYPE_CASTING_FP16_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
template <>
EIGEN_STRONG_INLINE Packet32s preinterpret<Packet32s, Packet32h>(const Packet32h& a) {
return _mm512_castph_si512(a);
}
template <>
EIGEN_STRONG_INLINE Packet16s preinterpret<Packet16s, Packet16h>(const Packet16h& a) {
return _mm256_castph_si256(a);
}
template <>
EIGEN_STRONG_INLINE Packet8s preinterpret<Packet8s, Packet8h>(const Packet8h& a) {
return _mm_castph_si128(a);
}
template <>
EIGEN_STRONG_INLINE Packet32h preinterpret<Packet32h, Packet32s>(const Packet32s& a) {
return _mm512_castsi512_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet16h preinterpret<Packet16h, Packet16s>(const Packet16s& a) {
return _mm256_castsi256_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet8s>(const Packet8s& a) {
return _mm_castsi128_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
return half2float(a);
}
template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
return half2float(a);
}
template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
return float2half(a);
}
template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
return float2half(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
// Discard second-half of input.
Packet16h low = _mm256_castpd_ph(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
return _mm512_cvtxph_ps(low);
}
template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet16h, Packet8f>(const Packet16h& a) {
// Discard second-half of input.
Packet8h low = _mm_castps_ph(_mm256_extractf32x4_ps(_mm256_castph_ps(a), 0));
return _mm256_cvtxph_ps(low);
}
template <>
EIGEN_STRONG_INLINE Packet4f pcast<Packet8h, Packet4f>(const Packet8h& a) {
Packet8f full = _mm256_cvtxph_ps(a);
// Discard second-half of input.
return _mm256_extractf32x4_ps(full, 0);
}
template <>
EIGEN_STRONG_INLINE Packet32h pcast<Packet16f, Packet32h>(const Packet16f& a, const Packet16f& b) {
__m512 result = _mm512_castsi512_ps(_mm512_castsi256_si512(_mm256_castph_si256(_mm512_cvtxps_ph(a))));
result = _mm512_insertf32x8(result, _mm256_castph_ps(_mm512_cvtxps_ph(b)), 1);
return _mm512_castps_ph(result);
}
template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet8f, Packet16h>(const Packet8f& a, const Packet8f& b) {
__m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castph_si128(_mm256_cvtxps_ph(a))));
result = _mm256_insertf32x4(result, _mm_castph_ps(_mm256_cvtxps_ph(b)), 1);
return _mm256_castps_ph(result);
}
template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet4f, Packet8h>(const Packet4f& a, const Packet4f& b) {
__m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castps_si128(a)));
result = _mm256_insertf128_ps(result, b, 1);
return _mm256_cvtxps_ph(result);
}
template <>
EIGEN_STRONG_INLINE Packet32s pcast<Packet32h, Packet32s>(const Packet32h& a) {
return _mm512_cvtph_epi16(a);
}
template <>
EIGEN_STRONG_INLINE Packet16s pcast<Packet16h, Packet16s>(const Packet16h& a) {
return _mm256_cvtph_epi16(a);
}
template <>
EIGEN_STRONG_INLINE Packet8s pcast<Packet8h, Packet8s>(const Packet8h& a) {
return _mm_cvtph_epi16(a);
}
template <>
EIGEN_STRONG_INLINE Packet32h pcast<Packet32s, Packet32h>(const Packet32s& a) {
return _mm512_cvtepi16_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet16s, Packet16h>(const Packet16s& a) {
return _mm256_cvtepi16_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet8s, Packet8h>(const Packet8s& a) {
return _mm_cvtepi16_ph(a);
}
} // namespace internal
} // namespace Eigen
#endif // EIGEN_TYPE_CASTING_FP16_AVX512_H
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment