Commit f10bfbf9 authored by carlushuang

add avx2 intrinsic

parent 3cc7ac0a
......@@ -163,6 +163,10 @@
#define CK_WORKAROUND_GITHUB_135 1
#endif
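// Defaults to 1: use the hand-written inline-asm micro-kernels. Setting this to 0
// selects the AVX2 intrinsic implementations added in this commit instead.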
#ifndef CK_USE_X86_INLINE_ASM
#define CK_USE_X86_INLINE_ASM 1
#endif
namespace ck {
enum struct InMemoryDataOperationEnum_t
......
#ifndef CK_THREADWISE_GEMM_AVX2_HPP
#define CK_THREADWISE_GEMM_AVX2_HPP
#if CK_USE_X86_INLINE_ASM == 0
#include <immintrin.h>
#endif
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "math.hpp"
......@@ -51,7 +54,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
* lda/ldb/ldc all in unit of byte
*
*/
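// A minimal parameter sketch (hedged: the param struct definition and the call site
// are not part of this diff; field names follow their usage in the intrinsic path
// below). lda/ldb/ldc are passed in bytes and divided by sizeof(element) internally:
//   param.p_a   = a;                   // A micro-panel
//   param.p_b   = b;                   // B micro-panel
//   param.p_c   = c;                   // C tile of Mr x Nr floats
//   param.Kr    = K;                   // K extent of this panel
//   param.lda   = K * sizeof(FloatA);  // e.g. a row-major A panel with row length K
//   param.ldb   = ldb_bytes;           // B leading dimension in bytes (meaning depends on BLayout)
//   param.ldc   = N * sizeof(float);   // e.g. a row-major C with row length N
//   param.alpha = 1.0f;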
#if CK_USE_X86_INLINE_ASM
// clang-format off
__asm__ __volatile__ (
"L_GemmAvx2_MxN_6x16_Entry%=:\n"
......@@ -326,6 +329,197 @@ struct ThreadwiseGemmAvx2_MxN_6x16
"ymm14","ymm15"
);
// clang-format on
#else
__m256 ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, ymm11, ymm12,
ymm13, ymm14, ymm15;
const FloatA* p_a = reinterpret_cast<const FloatA*>(param->p_a);
const FloatB* p_b = reinterpret_cast<const FloatB*>(param->p_b);
float* p_c = reinterpret_cast<float*>(param->p_c);
uint64_t Kr = param->Kr;
uint64_t lda = param->lda / sizeof(FloatA);
uint64_t ldb = param->ldb / sizeof(FloatB);
uint64_t ldc = param->ldc / sizeof(float);
// float alpha = param->alpha;
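// Register usage in the intrinsic path: ymm0..ymm11 hold the Mr x Nr tile of C
// accumulators (two 8-float ymm registers per row of C), ymm12/ymm13 hold the
// current B vectors, and ymm14/ymm15 hold broadcast A values.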
auto broadcast_a = [&](const int i_k, const int i_m, __m256& ymm) {
if constexpr(std::is_same<FloatA, float>::value)
{
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
{
ymm = _mm256_broadcast_ss(p_a + i_m * lda + i_k);
}
else
{
ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m);
}
}
else
{
// TODO: a static_assert() could go here.
// fp16 A input is not really supported for now. For the intrinsic path it may be
// better to never take fp16 input and convert/broadcast it to a ymm register
// (there are not enough registers to spare).
// Note: the code below appears to produce incorrect results.
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
{
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_m * lda + i_k)));
}
else
{
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m)));
}
}
};
auto load_b = [&](const int i_k, const int i_n, __m256& ymm) {
if constexpr(std::is_same<FloatB, float>::value)
{
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{
ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8);
}
else
{
ymm = _mm256_loadu_ps(p_b + i_k * 8 + i_n * ldb);
}
}
else
{
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{
ymm = _mm256_cvtph_ps(_mm_loadu_si128(
reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8)));
}
else
{
ymm = _mm256_cvtph_ps(_mm_loadu_si128(
reinterpret_cast<__m128i const*>(p_b + i_k * 8 + i_n * ldb)));
}
}
};
// clang-format off
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr (Mr > 1 ) ymm2 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 2 ) ymm4 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 3 ) ymm6 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 4 ) ymm8 = _mm256_loadu_ps(p_c + 4 * ldc + 0 * 8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_loadu_ps(p_c + 4 * ldc + 1 * 8);
if constexpr (Mr > 5 ) ymm10 = _mm256_loadu_ps(p_c + 5 * ldc + 0 * 8);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_loadu_ps(p_c + 5 * ldc + 1 * 8);
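// Main loop: unroll the K dimension by 4. Each i_k step loads one Nr-wide slice of B
// into ymm12/ymm13, broadcasts A elements row by row into ymm14/ymm15, and
// accumulates into the C registers with FMA.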
while (Kr > 4){
#pragma unroll
for(int i_k = 0; i_k < 4; i_k++){
load_b(i_k, 0, ymm12);
if constexpr ( Nr > 8) load_b(i_k, 1, ymm13);
broadcast_a(i_k, 0, ymm14);
if constexpr (Mr > 1 ) broadcast_a(i_k, 1, ymm15);
ymm0 = _mm256_fmadd_ps(ymm12, ymm14, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_fmadd_ps(ymm13, ymm14, ymm1);
if constexpr (Mr > 1 ) ymm2 = _mm256_fmadd_ps(ymm12, ymm15, ymm2);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_fmadd_ps(ymm13, ymm15, ymm3);
if constexpr (Mr > 2 ) broadcast_a(i_k, 2, ymm14);
if constexpr (Mr > 3 ) broadcast_a(i_k, 3, ymm15);
if constexpr (Mr > 2 ) ymm4 = _mm256_fmadd_ps(ymm12, ymm14, ymm4);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_fmadd_ps(ymm13, ymm14, ymm5);
if constexpr (Mr > 3 ) ymm6 = _mm256_fmadd_ps(ymm12, ymm15, ymm6);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_fmadd_ps(ymm13, ymm15, ymm7);
if constexpr (Mr > 4 ) broadcast_a(i_k, 4, ymm14);
if constexpr (Mr > 5 ) broadcast_a(i_k, 5, ymm15);
if constexpr (Mr > 4 ) ymm8 = _mm256_fmadd_ps(ymm12, ymm14, ymm8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_fmadd_ps(ymm13, ymm14, ymm9);
if constexpr (Mr > 5 ) ymm10 = _mm256_fmadd_ps(ymm12, ymm15, ymm10);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_fmadd_ps(ymm13, ymm15, ymm11);
}
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 4;
} else{
p_a += Mr * 4;
}
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 4;
}else{
p_b += 4 * 8;
}
Kr -= 4;
}
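// Tail loop: the remaining K iterations (including a final full chunk of 4, since the
// main loop only runs while Kr > 4) are handled one k at a time.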
while (Kr != 0){
load_b(0, 0, ymm12);
if constexpr ( Nr > 8) load_b(0, 1, ymm13);
broadcast_a(0, 0, ymm14);
if constexpr (Mr > 1 ) broadcast_a(0, 1, ymm15);
ymm0 = _mm256_fmadd_ps(ymm12, ymm14, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_fmadd_ps(ymm13, ymm14, ymm1);
if constexpr (Mr > 1 ) ymm2 = _mm256_fmadd_ps(ymm12, ymm15, ymm2);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_fmadd_ps(ymm13, ymm15, ymm3);
if constexpr (Mr > 2 ) broadcast_a(0, 2, ymm14);
if constexpr (Mr > 3 ) broadcast_a(0, 3, ymm15);
if constexpr (Mr > 2 ) ymm4 = _mm256_fmadd_ps(ymm12, ymm14, ymm4);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_fmadd_ps(ymm13, ymm14, ymm5);
if constexpr (Mr > 3 ) ymm6 = _mm256_fmadd_ps(ymm12, ymm15, ymm6);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_fmadd_ps(ymm13, ymm15, ymm7);
if constexpr (Mr > 4 ) broadcast_a(0, 4, ymm14);
if constexpr (Mr > 5 ) broadcast_a(0, 5, ymm15);
if constexpr (Mr > 4 ) ymm8 = _mm256_fmadd_ps(ymm12, ymm14, ymm8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_fmadd_ps(ymm13, ymm14, ymm9);
if constexpr (Mr > 5 ) ymm10 = _mm256_fmadd_ps(ymm12, ymm15, ymm10);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_fmadd_ps(ymm13, ymm15, ymm11);
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 1;
} else{
p_a += Mr * 1;
}
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 1;
}else{
p_b += 1 * 8;
}
Kr--;
}
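// Scale the accumulated C tile by alpha when alpha != 1, then store it back to memory.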
if(param->alpha != 1.0f){
ymm12 = _mm256_broadcast_ss(reinterpret_cast<float const*>(&param->alpha));
ymm0 = _mm256_mul_ps(ymm12, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_mul_ps(ymm12, ymm1);
if constexpr (Mr > 1 ) ymm2 = _mm256_mul_ps(ymm12, ymm2);
if constexpr (Mr > 1 && Nr > 8) ymm3 = _mm256_mul_ps(ymm12, ymm3);
if constexpr (Mr > 2 ) ymm4 = _mm256_mul_ps(ymm12, ymm4);
if constexpr (Mr > 2 && Nr > 8) ymm5 = _mm256_mul_ps(ymm12, ymm5);
if constexpr (Mr > 3 ) ymm6 = _mm256_mul_ps(ymm12, ymm6);
if constexpr (Mr > 3 && Nr > 8) ymm7 = _mm256_mul_ps(ymm12, ymm7);
if constexpr (Mr > 4 ) ymm8 = _mm256_mul_ps(ymm12, ymm8);
if constexpr (Mr > 4 && Nr > 8) ymm9 = _mm256_mul_ps(ymm12, ymm9);
if constexpr (Mr > 5 ) ymm10 = _mm256_mul_ps(ymm12, ymm10);
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_mul_ps(ymm12, ymm11);
}
_mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr (Mr > 1 ) _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm2);
if constexpr (Mr > 1 && Nr > 8) _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm3);
if constexpr (Mr > 2 ) _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm4);
if constexpr (Mr > 2 && Nr > 8) _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm5);
if constexpr (Mr > 3 ) _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm6);
if constexpr (Mr > 3 && Nr > 8) _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm7);
if constexpr (Mr > 4 ) _mm256_storeu_ps(p_c + 4 * ldc + 0 * 8, ymm8);
if constexpr (Mr > 4 && Nr > 8) _mm256_storeu_ps(p_c + 4 * ldc + 1 * 8, ymm9);
if constexpr (Mr > 5 ) _mm256_storeu_ps(p_c + 5 * ldc + 0 * 8, ymm10);
if constexpr (Mr > 5 && Nr > 8) _mm256_storeu_ps(p_c + 5 * ldc + 1 * 8, ymm11);
// clang-format on
#endif
}
};
......@@ -370,7 +564,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
* lda/ldb/ldc all in unit of byte
*
*/
#if CK_USE_X86_INLINE_ASM
// clang-format off
__asm__ __volatile__ (
"L_GemmAvx2_MxN_4x24_Entry%=:\n"
......@@ -641,6 +835,197 @@ struct ThreadwiseGemmAvx2_MxN_4x24
"ymm14","ymm15"
);
// clang-format on
#else
__m256 ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, ymm11, ymm12,
ymm13, ymm14, ymm15;
const FloatA* p_a = reinterpret_cast<const FloatA*>(param->p_a);
const FloatB* p_b = reinterpret_cast<const FloatB*>(param->p_b);
float* p_c = reinterpret_cast<float*>(param->p_c);
uint64_t Kr = param->Kr;
uint64_t lda = param->lda / sizeof(FloatA);
uint64_t ldb = param->ldb / sizeof(FloatB);
uint64_t ldc = param->ldc / sizeof(float);
// float alpha = param->alpha;
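// Register usage in the 4x24 intrinsic path: ymm0..ymm11 hold the Mr x Nr tile of C
// accumulators (three 8-float ymm registers per row of C), ymm12..ymm14 hold the
// current B vectors, and ymm15 holds the broadcast A value.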
auto broadcast_a = [&](const int i_k, const int i_m, __m256& ymm) {
if constexpr(std::is_same<FloatA, float>::value)
{
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
{
ymm = _mm256_broadcast_ss(p_a + i_m * lda + i_k);
}
else
{
ymm = _mm256_broadcast_ss(p_a + i_k * Mr + i_m);
}
}
else
{
// TODO: a static_assert() could go here.
// fp16 A input is not really supported for now. For the intrinsic path it may be
// better to never take fp16 input and convert/broadcast it to a ymm register
// (there are not enough registers to spare).
// Note: the code below appears to produce incorrect results.
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value)
{
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_m * lda + i_k)));
}
else
{
ymm = _mm256_cvtph_ps(_mm_set1_epi16(*(p_a + i_k * Mr + i_m)));
}
}
};
auto load_b = [&](const int i_k, const int i_n, __m256& ymm) {
if constexpr(std::is_same<FloatB, float>::value)
{
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{
ymm = _mm256_loadu_ps(p_b + i_k * Nr + i_n * 8);
}
else
{
ymm = _mm256_loadu_ps(p_b + i_k * 8 + i_n * ldb);
}
}
else
{
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value)
{
ymm = _mm256_cvtph_ps(_mm_loadu_si128(
reinterpret_cast<__m128i const*>(p_b + i_k * Nr + i_n * 8)));
}
else
{
ymm = _mm256_cvtph_ps(_mm_loadu_si128(
reinterpret_cast<__m128i const*>(p_b + i_k * 8 + i_n * ldb)));
}
}
};
// clang-format off
ymm0 = _mm256_loadu_ps(p_c + 0 * ldc + 0 * 8);
if constexpr ( Nr > 8) ymm1 = _mm256_loadu_ps(p_c + 0 * ldc + 1 * 8);
if constexpr ( Nr >16) ymm2 = _mm256_loadu_ps(p_c + 0 * ldc + 2 * 8);
if constexpr (Mr > 1 ) ymm3 = _mm256_loadu_ps(p_c + 1 * ldc + 0 * 8);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_loadu_ps(p_c + 1 * ldc + 1 * 8);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_loadu_ps(p_c + 1 * ldc + 2 * 8);
if constexpr (Mr > 2 ) ymm6 = _mm256_loadu_ps(p_c + 2 * ldc + 0 * 8);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_loadu_ps(p_c + 2 * ldc + 1 * 8);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_loadu_ps(p_c + 2 * ldc + 2 * 8);
if constexpr (Mr > 3 ) ymm9 = _mm256_loadu_ps(p_c + 3 * ldc + 0 * 8);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_loadu_ps(p_c + 3 * ldc + 1 * 8);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_loadu_ps(p_c + 3 * ldc + 2 * 8);
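// Main loop over K, unrolled by 4: same structure as the 6x16 kernel above, but with
// three B vectors per k step and a single A broadcast register.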
while (Kr > 4){
#pragma unroll
for(int i_k = 0; i_k < 4; i_k++){
load_b(i_k, 0, ymm12);
if constexpr ( Nr > 8) load_b(i_k, 1, ymm13);
if constexpr ( Nr >16) load_b(i_k, 2, ymm14);
broadcast_a(i_k, 0, ymm15);
ymm0 = _mm256_fmadd_ps(ymm12, ymm15, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_fmadd_ps(ymm13, ymm15, ymm1);
if constexpr ( Nr >16) ymm2 = _mm256_fmadd_ps(ymm14, ymm15, ymm2);
if constexpr (Mr > 1 ) broadcast_a(i_k, 1, ymm15);
if constexpr (Mr > 1 ) ymm3 = _mm256_fmadd_ps(ymm12, ymm15, ymm3);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_fmadd_ps(ymm13, ymm15, ymm4);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_fmadd_ps(ymm14, ymm15, ymm5);
if constexpr (Mr > 2 ) broadcast_a(i_k, 2, ymm15);
if constexpr (Mr > 2 ) ymm6 = _mm256_fmadd_ps(ymm12, ymm15, ymm6);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_fmadd_ps(ymm13, ymm15, ymm7);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_fmadd_ps(ymm14, ymm15, ymm8);
if constexpr (Mr > 3 ) broadcast_a(i_k, 3, ymm15);
if constexpr (Mr > 3 ) ymm9 = _mm256_fmadd_ps(ymm12, ymm15, ymm9);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_fmadd_ps(ymm13, ymm15, ymm10);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_fmadd_ps(ymm14, ymm15, ymm11);
}
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 4;
} else{
p_a += Mr * 4;
}
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 4;
}else{
p_b += 4 * 8;
}
Kr -= 4;
}
while (Kr != 0){
load_b(0, 0, ymm12);
if constexpr ( Nr > 8) load_b(0, 1, ymm13);
if constexpr ( Nr >16) load_b(0, 2, ymm14);
broadcast_a(0, 0, ymm15);
ymm0 = _mm256_fmadd_ps(ymm12, ymm15, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_fmadd_ps(ymm13, ymm15, ymm1);
if constexpr ( Nr >16) ymm2 = _mm256_fmadd_ps(ymm14, ymm15, ymm2);
if constexpr (Mr > 1 ) broadcast_a(0, 1, ymm15);
if constexpr (Mr > 1 ) ymm3 = _mm256_fmadd_ps(ymm12, ymm15, ymm3);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_fmadd_ps(ymm13, ymm15, ymm4);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_fmadd_ps(ymm14, ymm15, ymm5);
if constexpr (Mr > 2 ) broadcast_a(0, 2, ymm15);
if constexpr (Mr > 2 ) ymm6 = _mm256_fmadd_ps(ymm12, ymm15, ymm6);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_fmadd_ps(ymm13, ymm15, ymm7);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_fmadd_ps(ymm14, ymm15, ymm8);
if constexpr (Mr > 3 ) broadcast_a(0, 3, ymm15);
if constexpr (Mr > 3 ) ymm9 = _mm256_fmadd_ps(ymm12, ymm15, ymm9);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_fmadd_ps(ymm13, ymm15, ymm10);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_fmadd_ps(ymm14, ymm15, ymm11);
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value){
p_a += 1;
} else{
p_a += Mr * 1;
}
if constexpr(std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value){
p_b += Nr * 1;
}else{
p_b += 1 * 8;
}
Kr--;
}
if(param->alpha != 1.0f){
ymm12 = _mm256_broadcast_ss(reinterpret_cast<float const*>(&param->alpha));
ymm0 = _mm256_mul_ps(ymm12, ymm0);
if constexpr ( Nr > 8) ymm1 = _mm256_mul_ps(ymm12, ymm1);
if constexpr ( Nr >16) ymm2 = _mm256_mul_ps(ymm12, ymm2);
if constexpr (Mr > 1 ) ymm3 = _mm256_mul_ps(ymm12, ymm3);
if constexpr (Mr > 1 && Nr > 8) ymm4 = _mm256_mul_ps(ymm12, ymm4);
if constexpr (Mr > 1 && Nr >16) ymm5 = _mm256_mul_ps(ymm12, ymm5);
if constexpr (Mr > 2 ) ymm6 = _mm256_mul_ps(ymm12, ymm6);
if constexpr (Mr > 2 && Nr > 8) ymm7 = _mm256_mul_ps(ymm12, ymm7);
if constexpr (Mr > 2 && Nr >16) ymm8 = _mm256_mul_ps(ymm12, ymm8);
if constexpr (Mr > 3 ) ymm9 = _mm256_mul_ps(ymm12, ymm9);
if constexpr (Mr > 3 && Nr > 8) ymm10 = _mm256_mul_ps(ymm12, ymm10);
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_mul_ps(ymm12, ymm11);
}
_mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr ( Nr >16) _mm256_storeu_ps(p_c + 0 * ldc + 2 * 8, ymm2);
if constexpr (Mr > 1 ) _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm3);
if constexpr (Mr > 1 && Nr > 8) _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm4);
if constexpr (Mr > 1 && Nr >16) _mm256_storeu_ps(p_c + 1 * ldc + 2 * 8, ymm5);
if constexpr (Mr > 2 ) _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm6);
if constexpr (Mr > 2 && Nr > 8) _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm7);
if constexpr (Mr > 2 && Nr >16) _mm256_storeu_ps(p_c + 2 * ldc + 2 * 8, ymm8);
if constexpr (Mr > 3 ) _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm9);
if constexpr (Mr > 3 && Nr > 8) _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm10);
if constexpr (Mr > 3 && Nr >16) _mm256_storeu_ps(p_c + 3 * ldc + 2 * 8, ymm11);
// clang-format on
#endif
}
};
......