Commit e6ee6594 authored by carlushuang

non-temporal store support

parent a6e310af
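Background for the change: a non-temporal (streaming) store such as `vmovntps` writes the C tile around the cache through write-combining buffers instead of through it, so a large GEMM output does not evict the hot A/B panels. A minimal standalone sketch of the two store flavors toggled here (illustrative, not the kernel code itself):

```cpp
#include <immintrin.h>

// Temporal store: goes through the cache hierarchy, no alignment requirement.
static inline void store_row(float* dst, __m256 v) { _mm256_storeu_ps(dst, v); }

// Non-temporal store: bypasses the cache. dst must be 32-byte aligned, and
// because streaming stores are weakly ordered, fence once after the whole
// tile is written before other threads may read it (not per store).
static inline void stream_row(float* dst, __m256 v) { _mm256_stream_ps(dst, v); }
static inline void publish_tile() { _mm_sfence(); }
```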
@@ -297,6 +297,7 @@ struct ThreadwiseGemmAvx2_MxN_6x16
 ".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
 ".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
+".if m_NTStore == 0\n"
 " vmovups %%ymm0, (%%rax) \n"
 ".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
 ".if (m_Mr > 1) \n vmovups %%ymm2, (%%rbx) \n .endif\n"
@@ -309,6 +310,20 @@ struct ThreadwiseGemmAvx2_MxN_6x16
 ".if (m_Mr > 4) && (m_Nr > 8)\n vmovups %%ymm9, 32(%%r8) \n .endif\n"
 ".if (m_Mr > 5) \n vmovups %%ymm10, (%%r9) \n .endif\n"
 ".if (m_Mr > 5) && (m_Nr > 8)\n vmovups %%ymm11, 32(%%r9) \n .endif\n"
+".else\n"
+" vmovntps %%ymm0, (%%rax) \n"
+".if (m_Nr > 8)\n vmovntps %%ymm1, 32(%%rax)\n .endif\n"
+".if (m_Mr > 1) \n vmovntps %%ymm2, (%%rbx) \n .endif\n"
+".if (m_Mr > 1) && (m_Nr > 8)\n vmovntps %%ymm3, 32(%%rbx)\n .endif\n"
+".if (m_Mr > 2) \n vmovntps %%ymm4, (%%rcx) \n .endif\n"
+".if (m_Mr > 2) && (m_Nr > 8)\n vmovntps %%ymm5, 32(%%rcx)\n .endif\n"
+".if (m_Mr > 3) \n vmovntps %%ymm6, (%%rdx) \n .endif\n"
+".if (m_Mr > 3) && (m_Nr > 8)\n vmovntps %%ymm7, 32(%%rdx)\n .endif\n"
+".if (m_Mr > 4) \n vmovntps %%ymm8, (%%r8) \n .endif\n"
+".if (m_Mr > 4) && (m_Nr > 8)\n vmovntps %%ymm9, 32(%%r8) \n .endif\n"
+".if (m_Mr > 5) \n vmovntps %%ymm10, (%%r9) \n .endif\n"
+".if (m_Mr > 5) && (m_Nr > 8)\n vmovntps %%ymm11, 32(%%r9) \n .endif\n"
+".endif\n"
 "L_GemmAvx2_MxN_6x16_Exit%=:\n"
 :
 :
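The `.if m_NTStore == 0` / `.else` / `.endif` blocks above are GNU assembler conditionals: `m_NTStore` is a constant known when the string is assembled, so each instantiation contains exactly one store flavor and no runtime branch. A hedged sketch of the mechanism (the operand plumbing is illustrative; the kernel binds `m_NTStore` through its own macro machinery):

```cpp
#include <immintrin.h>

template <bool NT>
void store8(float* p, __m256 v)
{
    // %c[nt] prints the bare constant, so gas sees ".if 0" or ".if 1"
    // and assembles only the selected instruction.
    asm volatile(".if %c[nt]\n\t"
                 "vmovntps %[v], (%[p])\n\t"
                 ".else\n\t"
                 "vmovups %[v], (%[p])\n\t"
                 ".endif"
                 :
                 : [p] "r"(p), [v] "x"(v), [nt] "i"(NT)
                 : "memory");
}
```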
@@ -506,19 +521,35 @@ struct ThreadwiseGemmAvx2_MxN_6x16
     if constexpr (Mr > 5 && Nr > 8) ymm11 = _mm256_mul_ps(ymm12, ymm11);
 }
 
-_mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
-if constexpr (          Nr > 8) _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
-if constexpr (Mr > 1          ) _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm2);
-if constexpr (Mr > 1 && Nr > 8) _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm3);
-if constexpr (Mr > 2          ) _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm4);
-if constexpr (Mr > 2 && Nr > 8) _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm5);
-if constexpr (Mr > 3          ) _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm6);
-if constexpr (Mr > 3 && Nr > 8) _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm7);
-if constexpr (Mr > 4          ) _mm256_storeu_ps(p_c + 4 * ldc + 0 * 8, ymm8);
-if constexpr (Mr > 4 && Nr > 8) _mm256_storeu_ps(p_c + 4 * ldc + 1 * 8, ymm9);
-if constexpr (Mr > 5          ) _mm256_storeu_ps(p_c + 5 * ldc + 0 * 8, ymm10);
-if constexpr (Mr > 5 && Nr > 8) _mm256_storeu_ps(p_c + 5 * ldc + 1 * 8, ymm11);
+if constexpr (NonTemporalStore) {
+    _mm256_stream_ps(p_c + 0 * ldc + 0 * 8, ymm0);
+    if constexpr (          Nr > 8) _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1);
+    if constexpr (Mr > 1          ) _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm2);
+    if constexpr (Mr > 1 && Nr > 8) _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm3);
+    if constexpr (Mr > 2          ) _mm256_stream_ps(p_c + 2 * ldc + 0 * 8, ymm4);
+    if constexpr (Mr > 2 && Nr > 8) _mm256_stream_ps(p_c + 2 * ldc + 1 * 8, ymm5);
+    if constexpr (Mr > 3          ) _mm256_stream_ps(p_c + 3 * ldc + 0 * 8, ymm6);
+    if constexpr (Mr > 3 && Nr > 8) _mm256_stream_ps(p_c + 3 * ldc + 1 * 8, ymm7);
+    if constexpr (Mr > 4          ) _mm256_stream_ps(p_c + 4 * ldc + 0 * 8, ymm8);
+    if constexpr (Mr > 4 && Nr > 8) _mm256_stream_ps(p_c + 4 * ldc + 1 * 8, ymm9);
+    if constexpr (Mr > 5          ) _mm256_stream_ps(p_c + 5 * ldc + 0 * 8, ymm10);
+    if constexpr (Mr > 5 && Nr > 8) _mm256_stream_ps(p_c + 5 * ldc + 1 * 8, ymm11);
+}
+else {
+    _mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
+    if constexpr (          Nr > 8) _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
+    if constexpr (Mr > 1          ) _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm2);
+    if constexpr (Mr > 1 && Nr > 8) _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm3);
+    if constexpr (Mr > 2          ) _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm4);
+    if constexpr (Mr > 2 && Nr > 8) _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm5);
+    if constexpr (Mr > 3          ) _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm6);
+    if constexpr (Mr > 3 && Nr > 8) _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm7);
+    if constexpr (Mr > 4          ) _mm256_storeu_ps(p_c + 4 * ldc + 0 * 8, ymm8);
+    if constexpr (Mr > 4 && Nr > 8) _mm256_storeu_ps(p_c + 4 * ldc + 1 * 8, ymm9);
+    if constexpr (Mr > 5          ) _mm256_storeu_ps(p_c + 5 * ldc + 0 * 8, ymm10);
+    if constexpr (Mr > 5 && Nr > 8) _mm256_storeu_ps(p_c + 5 * ldc + 1 * 8, ymm11);
+}
 // clang-format on
 #endif
 }
};
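The intrinsics path mirrors the assembly path: `if constexpr (NonTemporalStore)` discards the untaken branch at compile time, so each template instance again carries a single store flavor. In sketch form:

```cpp
#include <immintrin.h>

// Sketch of the compile-time dispatch used by the epilogue above.
template <bool NonTemporalStore>
void store_tile_row(float* dst, __m256 v)
{
    if constexpr (NonTemporalStore)
        _mm256_stream_ps(dst, v); // vmovntps: needs 32-byte-aligned dst
    else
        _mm256_storeu_ps(dst, v); // vmovups: any alignment
}
```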
@@ -803,6 +833,7 @@ struct ThreadwiseGemmAvx2_MxN_4x24
 ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
 ".if (m_Mr > 3) && (m_Nr >16)\n vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
+".if m_NTStore == 0\n"
 " vmovups %%ymm0, (%%rax) \n"
 ".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
 ".if (m_Nr >16)\n vmovups %%ymm2, 64(%%rax)\n .endif\n"
@@ -815,6 +846,20 @@ struct ThreadwiseGemmAvx2_MxN_4x24
 ".if (m_Mr > 3) \n vmovups %%ymm9, (%%rdx) \n .endif\n"
 ".if (m_Mr > 3) && (m_Nr > 8)\n vmovups %%ymm10, 32(%%rdx)\n .endif\n"
 ".if (m_Mr > 3) && (m_Nr >16)\n vmovups %%ymm11, 64(%%rdx)\n .endif\n"
+".else\n"
+" vmovntps %%ymm0, (%%rax) \n"
+".if (m_Nr > 8)\n vmovntps %%ymm1, 32(%%rax)\n .endif\n"
+".if (m_Nr >16)\n vmovntps %%ymm2, 64(%%rax)\n .endif\n"
+".if (m_Mr > 1) \n vmovntps %%ymm3, (%%rbx) \n .endif\n"
+".if (m_Mr > 1) && (m_Nr > 8)\n vmovntps %%ymm4, 32(%%rbx)\n .endif\n"
+".if (m_Mr > 1) && (m_Nr >16)\n vmovntps %%ymm5, 64(%%rbx)\n .endif\n"
+".if (m_Mr > 2) \n vmovntps %%ymm6, (%%rcx) \n .endif\n"
+".if (m_Mr > 2) && (m_Nr > 8)\n vmovntps %%ymm7, 32(%%rcx)\n .endif\n"
+".if (m_Mr > 2) && (m_Nr >16)\n vmovntps %%ymm8, 64(%%rcx)\n .endif\n"
+".if (m_Mr > 3) \n vmovntps %%ymm9, (%%rdx) \n .endif\n"
+".if (m_Mr > 3) && (m_Nr > 8)\n vmovntps %%ymm10, 32(%%rdx)\n .endif\n"
+".if (m_Mr > 3) && (m_Nr >16)\n vmovntps %%ymm11, 64(%%rdx)\n .endif\n"
+".endif\n"
 "L_GemmAvx2_MxN_4x24_Exit%=:\n"
 :
 :
@@ -1012,19 +1057,35 @@ struct ThreadwiseGemmAvx2_MxN_4x24
     if constexpr (Mr > 3 && Nr >16) ymm11 = _mm256_mul_ps(ymm12, ymm11);
 }
 
-_mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
-if constexpr (          Nr > 8) _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
-if constexpr (          Nr >16) _mm256_storeu_ps(p_c + 0 * ldc + 2 * 8, ymm2);
-if constexpr (Mr > 1          ) _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm3);
-if constexpr (Mr > 1 && Nr > 8) _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm4);
-if constexpr (Mr > 1 && Nr >16) _mm256_storeu_ps(p_c + 1 * ldc + 2 * 8, ymm5);
-if constexpr (Mr > 2          ) _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm6);
-if constexpr (Mr > 2 && Nr > 8) _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm7);
-if constexpr (Mr > 2 && Nr >16) _mm256_storeu_ps(p_c + 2 * ldc + 2 * 8, ymm8);
-if constexpr (Mr > 3          ) _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm9);
-if constexpr (Mr > 3 && Nr > 8) _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm10);
-if constexpr (Mr > 3 && Nr >16) _mm256_storeu_ps(p_c + 3 * ldc + 2 * 8, ymm11);
+if constexpr (NonTemporalStore) {
+    _mm256_stream_ps(p_c + 0 * ldc + 0 * 8, ymm0);
+    if constexpr (          Nr > 8) _mm256_stream_ps(p_c + 0 * ldc + 1 * 8, ymm1);
+    if constexpr (          Nr >16) _mm256_stream_ps(p_c + 0 * ldc + 2 * 8, ymm2);
+    if constexpr (Mr > 1          ) _mm256_stream_ps(p_c + 1 * ldc + 0 * 8, ymm3);
+    if constexpr (Mr > 1 && Nr > 8) _mm256_stream_ps(p_c + 1 * ldc + 1 * 8, ymm4);
+    if constexpr (Mr > 1 && Nr >16) _mm256_stream_ps(p_c + 1 * ldc + 2 * 8, ymm5);
+    if constexpr (Mr > 2          ) _mm256_stream_ps(p_c + 2 * ldc + 0 * 8, ymm6);
+    if constexpr (Mr > 2 && Nr > 8) _mm256_stream_ps(p_c + 2 * ldc + 1 * 8, ymm7);
+    if constexpr (Mr > 2 && Nr >16) _mm256_stream_ps(p_c + 2 * ldc + 2 * 8, ymm8);
+    if constexpr (Mr > 3          ) _mm256_stream_ps(p_c + 3 * ldc + 0 * 8, ymm9);
+    if constexpr (Mr > 3 && Nr > 8) _mm256_stream_ps(p_c + 3 * ldc + 1 * 8, ymm10);
+    if constexpr (Mr > 3 && Nr >16) _mm256_stream_ps(p_c + 3 * ldc + 2 * 8, ymm11);
+}
+else {
+    _mm256_storeu_ps(p_c + 0 * ldc + 0 * 8, ymm0);
+    if constexpr (          Nr > 8) _mm256_storeu_ps(p_c + 0 * ldc + 1 * 8, ymm1);
+    if constexpr (          Nr >16) _mm256_storeu_ps(p_c + 0 * ldc + 2 * 8, ymm2);
+    if constexpr (Mr > 1          ) _mm256_storeu_ps(p_c + 1 * ldc + 0 * 8, ymm3);
+    if constexpr (Mr > 1 && Nr > 8) _mm256_storeu_ps(p_c + 1 * ldc + 1 * 8, ymm4);
+    if constexpr (Mr > 1 && Nr >16) _mm256_storeu_ps(p_c + 1 * ldc + 2 * 8, ymm5);
+    if constexpr (Mr > 2          ) _mm256_storeu_ps(p_c + 2 * ldc + 0 * 8, ymm6);
+    if constexpr (Mr > 2 && Nr > 8) _mm256_storeu_ps(p_c + 2 * ldc + 1 * 8, ymm7);
+    if constexpr (Mr > 2 && Nr >16) _mm256_storeu_ps(p_c + 2 * ldc + 2 * 8, ymm8);
+    if constexpr (Mr > 3          ) _mm256_storeu_ps(p_c + 3 * ldc + 0 * 8, ymm9);
+    if constexpr (Mr > 3 && Nr > 8) _mm256_storeu_ps(p_c + 3 * ldc + 1 * 8, ymm10);
+    if constexpr (Mr > 3 && Nr >16) _mm256_storeu_ps(p_c + 3 * ldc + 2 * 8, ymm11);
+}
 // clang-format on
 #endif
 }
};
@@ -54,17 +54,18 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using AType = float;
 using BType = float;
 using CType = float;
+#define NTStore false
 
 template <typename ALayout, typename BLayout>
 using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
     // clang-format off
     //                                          FloatA  FloatB  FloatC  ALayout  BLayout  NTStore
-    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, false)
-    // ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType,  CType,  ALayout, BLayout, false)
+    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, NTStore),
+    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, NTStore),
+    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, NTStore),
+    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, NTStore)
+    // ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(AType, BType,  CType,  ALayout, BLayout, NTStore)
     // clang-format on
 >;
@@ -72,10 +73,10 @@ template <typename ALayout, typename BLayout>
 using thread_gemm_avx2_mxn_4x24_instances = std::tuple<
     // clang-format off
     //                                          FloatA  FloatB  FloatC  ALayout  BLayout  NTStore
-    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, false),
-    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, false)
+    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, NTStore),
+    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, NTStore),
+    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, NTStore),
+    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE( AType,  BType,  CType,  ALayout, BLayout, NTStore)
     // clang-format on
 >;
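Replacing the literal `false` with the `NTStore` macro makes the store flavor a one-line toggle for every 6x16 and 4x24 instance in this file:

```cpp
#define NTStore true // rebuild all instances with streaming (vmovntps) epilogues
```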
@@ -306,8 +307,10 @@ void test_ukernel(ukenrel_t uk,
 #pragma omp parallel reduction(+ : us)
 {
     int tid = omp_get_thread_num();
-    float* private_c = reinterpret_cast<float*>(malloc(m * n * sizeof(float)));
+    DeviceAlignedMemCPU private_c_mem(m * n * sizeof(float), 32);
+    float* private_c = reinterpret_cast<float*>(private_c_mem.mpDeviceBuf);
+    // float* private_c = mat_c + tid * m * n;
 
     ck::cpu::ThreadwiseGemmParam param;
     param.p_a = mat_a;
@@ -343,7 +346,6 @@ void test_ukernel(ukenrel_t uk,
     invoke_uk(param, private_c);
 
     memcpy(mat_c + tid * m * n, private_c, m * n * sizeof(float));
-    free(private_c);
 }
 us = us / max_threads;
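The harness change is a correctness fix, not just cleanup: `_mm256_stream_ps` (vmovntps) faults on a destination that is not 32-byte aligned, while plain `malloc` only guarantees `alignof(max_align_t)`. `DeviceAlignedMemCPU` also owns the buffer, which is why the explicit `free` goes away. An equivalent standard-C++ sketch of what the aligned allocation provides (hypothetical helper, not the repo's API):

```cpp
#include <cstddef>
#include <cstdlib>

float* alloc_c_tile(std::size_t m, std::size_t n)
{
    // std::aligned_alloc requires size to be a multiple of the alignment,
    // so round the buffer up to the next 32-byte boundary.
    std::size_t bytes = (m * n * sizeof(float) + 31) / 32 * 32;
    return static_cast<float*>(std::aligned_alloc(32, bytes));
}
// usage: float* private_c = alloc_c_tile(m, n);  ...  std::free(private_c);
```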