Commit e6ee6594 authored by carlushuang

non-temporal store support

parent a6e310af
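Note: this commit threads a compile-time NTStore / NonTemporalStore flag through the AVX2 GEMM microkernels. When the flag is set, the C-tile writeback uses streaming stores (vmovntps / _mm256_stream_ps), which bypass the cache, instead of ordinary stores (vmovups / _mm256_storeu_ps). A minimal sketch of the selection, assuming a template flag like the one in the diff below (this is not the commit's code):

#include <immintrin.h>

// Sketch only: pick the store flavor at compile time.
template <bool NonTemporalStore>
void store_c_chunk(float* p_c, __m256 v)
{
    if constexpr (NonTemporalStore)
        _mm256_stream_ps(p_c, v);  // vmovntps: cache-bypassing; p_c must be 32-byte aligned
    else
        _mm256_storeu_ps(p_c, v);  // vmovups: normal cached store, any alignment
}

Streaming stores help when the output will not be re-read soon, but they are weakly ordered, so publishing results to other threads requires _mm_sfence().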
......@@ -297,6 +297,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
".if (m_Mr > 1) \n vmovups %%ymm2, (%%rbx) \n .endif\n"
......@@ -309,6 +310,20 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 4) && (m_Nr > 8)\n vmovups %%ymm9, 32(%%r8) \n .endif\n"
".if (m_Mr > 5) \n vmovups %%ymm10, (%%r9) \n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vmovups %%ymm11, 32(%%r9) \n .endif\n"
".else\n"
" vmovntps %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovntps %%ymm1, 32(%%rax)\n .endif\n"
".if (m_Mr > 1) \n vmovntps %%ymm2, (%%rbx) \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vmovntps %%ymm3, 32(%%rbx)\n .endif\n"
".if (m_Mr > 2) \n vmovntps %%ymm4, (%%rcx) \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vmovntps %%ymm5, 32(%%rcx)\n .endif\n"
".if (m_Mr > 3) \n vmovntps %%ymm6, (%%rdx) \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vmovntps %%ymm7, 32(%%rdx)\n .endif\n"
".if (m_Mr > 4) \n vmovntps %%ymm8, (%%r8) \n .endif\n"
".if (m_Mr > 4) && (m_Nr > 8)\n vmovntps %%ymm9, 32(%%r8) \n .endif\n"
".if (m_Mr > 5) \n vmovntps %%ymm10, (%%r9) \n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vmovntps %%ymm11, 32(%%r9) \n .endif\n"
".endif\n"
"L_GemmAvx2_MxN_6x16_Exit%=:\n"
:
:
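In the assembly path above, the choice is made with GAS assemble-time conditionals: ".if m_NTStore == 0" emits the vmovups block and ".else" the vmovntps block, so only one store sequence exists in the assembled kernel. A stripped-down, hypothetical reproduction of the pattern (the operand names here are assumptions, not the diff's):

#include <immintrin.h>

// Hypothetical sketch of the ".if/.else/.endif" trick: a template
// parameter is passed as an immediate operand and tested when the
// inline asm is assembled, so the branch costs nothing at run time.
template <bool NTStore>
void store8(const float* src, float* dst)
{
    asm volatile(
        "vmovups (%[src]), %%ymm0      \n"
        ".if %c[nt] == 0               \n"
        "vmovups %%ymm0, (%[dst])      \n" // cached store
        ".else                         \n"
        "vmovntps %%ymm0, (%[dst])     \n" // streaming store; dst must be 32-byte aligned
        ".endif                        \n"
        :
        : [src] "r"(src), [dst] "r"(dst), [nt] "i"((int)NTStore)
        : "ymm0", "memory");
}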
......@@ -506,19 +521,34 @@ struct ThreadwiseGemmAvx2_MxN_6x16
if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_mul_ps(ymm12, ymm11);
}
_mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr (Mr > 1 ) _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm2);
if constexpr (Mr > 1 && Nr > 8) _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm3);
if constexpr (Mr > 2 ) _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm4);
if constexpr (Mr > 2 && Nr > 8) _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm5);
if constexpr (Mr > 3 ) _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm6);
if constexpr (Mr > 3 && Nr > 8) _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm7);
if constexpr (Mr > 4 ) _mm256_storeu_ps(p_c + 4 * ldc + 0 * 8, ymm8);
if constexpr (Mr > 4 && Nr > 8) _mm256_storeu_ps(p_c + 4 * ldc + 1 * 8, ymm9);
if constexpr (Mr > 5 ) _mm256_storeu_ps(p_c + 5 * ldc + 0 * 8, ymm10);
if constexpr (Mr > 5 && Nr > 8) _mm256_storeu_ps(p_c + 5 * ldc + 1 * 8, ymm11);
// clang-format on
if constexpr (NonTemporalStore) {
_mm256_stream_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr (Mr > 1 ) _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm2);
if constexpr (Mr > 1 && Nr > 8) _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm3);
if constexpr (Mr > 2 ) _mm256_stream_ps(p_c + 2 * ldc + 0 * 8, ymm4);
if constexpr (Mr > 2 && Nr > 8) _mm256_stream_ps(p_c + 2 * ldc + 1 * 8, ymm5);
if constexpr (Mr > 3 ) _mm256_stream_ps(p_c + 3 * ldc + 0 * 8, ymm6);
if constexpr (Mr > 3 && Nr > 8) _mm256_stream_ps(p_c + 3 * ldc + 1 * 8, ymm7);
if constexpr (Mr > 4 ) _mm256_stream_ps(p_c + 4 * ldc + 0 * 8, ymm8);
if constexpr (Mr > 4 && Nr > 8) _mm256_stream_ps(p_c + 4 * ldc + 1 * 8, ymm9);
if constexpr (Mr > 5 ) _mm256_stream_ps(p_c + 5 * ldc + 0 * 8, ymm10);
if constexpr (Mr > 5 && Nr > 8) _mm256_stream_ps(p_c + 5 * ldc + 1 * 8, ymm11);
}
else {
_mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr (Mr > 1 ) _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm2);
if constexpr (Mr > 1 && Nr > 8) _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm3);
if constexpr (Mr > 2 ) _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm4);
if constexpr (Mr > 2 && Nr > 8) _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm5);
if constexpr (Mr > 3 ) _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm6);
if constexpr (Mr > 3 && Nr > 8) _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm7);
if constexpr (Mr > 4 ) _mm256_storeu_ps(p_c + 4 * ldc + 0 * 8, ymm8);
if constexpr (Mr > 4 && Nr > 8) _mm256_storeu_ps(p_c + 4 * ldc + 1 * 8, ymm9);
if constexpr (Mr > 5 ) _mm256_storeu_ps(p_c + 5 * ldc + 0 * 8, ymm10);
if constexpr (Mr > 5 && Nr > 8) _mm256_storeu_ps(p_c + 5 * ldc + 1 * 8, ymm11);
}
// clang-format on
#endif
}
};
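The intrinsics fallback mirrors the asm path with "if constexpr (NonTemporalStore)". Whether to enable it is a bandwidth/cache trade-off: streaming pays off roughly when the C matrix will not fit in (and be reused from) the last-level cache anyway. One possible heuristic, purely illustrative (the LLC size is an assumption, not something the library queries here):

#include <cstddef>

// Illustrative heuristic only: prefer streaming stores once the output
// clearly exceeds an assumed last-level-cache size.
constexpr std::size_t kAssumedLlcBytes = 32u * 1024u * 1024u;

inline bool prefer_nt_store(std::size_t m, std::size_t n)
{
    return m * n * sizeof(float) > 2 * kAssumedLlcBytes;
}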
......@@ -803,6 +833,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
".if m_NTStore == 0\n"
" vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
".if (m_Nr >16)\n vmovups %%ymm2, 64(%%rax)\n .endif\n"
......@@ -815,6 +846,20 @@ struct ThreadwiseGemmAvx2_MxN_4x24
".if (m_Mr > 3) \n vmovups %%ymm9, (%%rdx) \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vmovups %%ymm10, 32(%%rdx)\n .endif\n"
".if (m_Mr > 3) && (m_Nr >16)\n vmovups %%ymm11, 64(%%rdx)\n .endif\n"
".else\n"
" vmovntps %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovntps %%ymm1, 32(%%rax)\n .endif\n"
".if (m_Nr >16)\n vmovntps %%ymm2, 64(%%rax)\n .endif\n"
".if (m_Mr > 1) \n vmovntps %%ymm3, (%%rbx) \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vmovntps %%ymm4, 32(%%rbx)\n .endif\n"
".if (m_Mr > 1) && (m_Nr >16)\n vmovntps %%ymm5, 64(%%rbx)\n .endif\n"
".if (m_Mr > 2) \n vmovntps %%ymm6, (%%rcx) \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vmovntps %%ymm7, 32(%%rcx)\n .endif\n"
".if (m_Mr > 2) && (m_Nr >16)\n vmovntps %%ymm8, 64(%%rcx)\n .endif\n"
".if (m_Mr > 3) \n vmovntps %%ymm9, (%%rdx) \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vmovntps %%ymm10, 32(%%rdx)\n .endif\n"
".if (m_Mr > 3) && (m_Nr >16)\n vmovntps %%ymm11, 64(%%rdx)\n .endif\n"
".endif\n"
"L_GemmAvx2_MxN_4x24_Exit%=:\n"
:
:
......@@ -1012,19 +1057,35 @@ struct ThreadwiseGemmAvx2_MxN_4x24
if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_mul_ps(ymm12, ymm11);
}
_mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr ( Nr >16) _mm256_storeu_ps(p_c + 0 * ldc + 2 * 8, ymm2);
if constexpr (Mr > 1 ) _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm3);
if constexpr (Mr > 1 && Nr > 8) _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm4);
if constexpr (Mr > 1 && Nr >16) _mm256_storeu_ps(p_c + 1 * ldc + 2 * 8, ymm5);
if constexpr (Mr > 2 ) _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm6);
if constexpr (Mr > 2 && Nr > 8) _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm7);
if constexpr (Mr > 2 && Nr >16) _mm256_storeu_ps(p_c + 2 * ldc + 2 * 8, ymm8);
if constexpr (Mr > 3 ) _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm9);
if constexpr (Mr > 3 && Nr > 8) _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm10);
if constexpr (Mr > 3 && Nr >16) _mm256_storeu_ps(p_c + 3 * ldc + 2 * 8, ymm11);
// clang-format on
if constexpr (NonTemporalStore) {
_mm256_stream_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr ( Nr >16) _mm256_stream_ps(p_c + 0 * ldc + 2 * 8, ymm2);
if constexpr (Mr > 1 ) _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm3);
if constexpr (Mr > 1 && Nr > 8) _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm4);
if constexpr (Mr > 1 && Nr >16) _mm256_stream_ps(p_c + 1 * ldc + 2 * 8, ymm5);
if constexpr (Mr > 2 ) _mm256_stream_ps(p_c + 2 * ldc + 0 * 8, ymm6);
if constexpr (Mr > 2 && Nr > 8) _mm256_stream_ps(p_c + 2 * ldc + 1 * 8, ymm7);
if constexpr (Mr > 2 && Nr >16) _mm256_stream_ps(p_c + 2 * ldc + 2 * 8, ymm8);
if constexpr (Mr > 3 ) _mm256_stream_ps(p_c + 3 * ldc + 0 * 8, ymm9);
if constexpr (Mr > 3 && Nr > 8) _mm256_stream_ps(p_c + 3 * ldc + 1 * 8, ymm10);
if constexpr (Mr > 3 && Nr >16) _mm256_stream_ps(p_c + 3 * ldc + 2 * 8, ymm11);
}
else {
_mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
if constexpr ( Nr > 8) _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
if constexpr ( Nr >16) _mm256_storeu_ps(p_c + 0 * ldc + 2 * 8, ymm2);
if constexpr (Mr > 1 ) _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm3);
if constexpr (Mr > 1 && Nr > 8) _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm4);
if constexpr (Mr > 1 && Nr >16) _mm256_storeu_ps(p_c + 1 * ldc + 2 * 8, ymm5);
if constexpr (Mr > 2 ) _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm6);
if constexpr (Mr > 2 && Nr > 8) _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm7);
if constexpr (Mr > 2 && Nr >16) _mm256_storeu_ps(p_c + 2 * ldc + 2 * 8, ymm8);
if constexpr (Mr > 3 ) _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm9);
if constexpr (Mr > 3 && Nr > 8) _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm10);
if constexpr (Mr > 3 && Nr >16) _mm256_storeu_ps(p_c + 3 * ldc + 2 * 8, ymm11);
}
// clang-format on
#endif
}
};
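A detail worth noting about the store layout above: non-temporal stores drain through write-combining buffers, and a buffer flushes most efficiently when a full 64-byte cache line is written. The Nr > 8 variants do exactly that, emitting two adjacent 32-byte ymm stores per row. A sketch of the full-line pattern:

#include <immintrin.h>

// Sketch: two adjacent 32-byte streaming stores cover one 64-byte cache
// line, letting the write-combining buffer flush whole lines at once.
inline void stream_full_line(float* dst, __m256 lo, __m256 hi)
{
    _mm256_stream_ps(dst + 0, lo); // bytes 0..31 of the line
    _mm256_stream_ps(dst + 8, hi); // bytes 32..63
}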
......
......@@ -54,17 +54,18 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using AType = float;
using BType = float;
using CType = float;
#define NTStore false
template <typename ALayout, typename BLayout>
using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
// clang-format off
// FloatA FloatB FloatC ALayout BLayout NTStore
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, false)
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, NTStore),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, NTStore),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, NTStore),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType, BType, CType, ALayout, BLayout, NTStore)
// ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType, CType, ALayout, BLayout, false)
// ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType, CType, ALayout, BLayout, NTStore)
// clang-format on
>;
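With the instance lists keyed to a single NTStore switch, flipping one definition rebuilds every instance with streaming stores. Since NTStore is a plain #define, a scoped alternative (an assumption, not what this commit does) would be:

// Hypothetical alternative to "#define NTStore false": a typed constant
// respects scope and shows up in diagnostics.
constexpr bool NTStore = false;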
......@@ -72,10 +73,10 @@ template <typename ALayout, typename BLayout>
using thread_gemm_avx2_mxn_4x24_instances = std::tuple<
// clang-format off
// FloatA FloatB FloatC ALayout BLayout NTStore
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, false),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, false)
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, NTStore),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, NTStore),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, NTStore),
ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType, BType, CType, ALayout, BLayout, NTStore)
// clang-format on
>;
......@@ -306,8 +307,10 @@ void test_ukernel(ukenrel_t uk,
#pragma omp parallel reduction(+ : us)
{
int tid = omp_get_thread_num();
float* private_c = reinterpret_cast<float*>(malloc(m * n * sizeof(float)));
int tid = omp_get_thread_num();
DeviceAlignedMemCPU private_c_mem(m * n * sizeof(float), 32);
float* private_c = reinterpret_cast<float*>(private_c_mem.mpDeviceBuf);
// float * private_c = mat_c + tid * m * n;
ck::cpu::ThreadwiseGemmParam param;
param.p_a = mat_a;
......@@ -343,7 +346,6 @@ void test_ukernel(ukenrel_t uk,
invoke_uk(param, private_c);
memcpy(mat_c + tid * m * n, private_c, m * n * sizeof(float));
free(private_c);
}
us = us / max_threads;
......
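The test-harness change above is required for correctness, not just performance: _mm256_stream_ps (vmovntps) faults on a destination that is not 32-byte aligned, which plain malloc does not guarantee, hence the 32-byte-aligned DeviceAlignedMemCPU buffer (each row start p_c + i * ldc must stay aligned too, so ldc should be a multiple of 8 floats). The same-thread memcpy afterwards is safe, since a thread observes its own stores in program order, but handing streamed results to other threads needs a fence. A hedged sketch, assuming a 32-byte-aligned dst:

#include <cstddef>
#include <immintrin.h>

// Sketch: stream a tile, then fence before another thread may read it.
void stream_and_publish(float* dst, const float* src, std::size_t n_floats)
{
    for (std::size_t i = 0; i + 8 <= n_floats; i += 8)
        _mm256_stream_ps(dst + i, _mm256_loadu_ps(src + i));
    _mm_sfence(); // make the non-temporal stores globally visible
}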