Commit 66fd7712 authored by carlushuang's avatar carlushuang
Browse files

add 4x24 ukernel

parent 3a4df3da
...@@ -300,6 +300,292 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -300,6 +300,292 @@ struct ThreadwiseGemmAvx2_MxN_6x16
} }
}; };
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          index_t Mr,
          index_t Nr,
          typename ALayout, // default is k*m, trans->m*k
          typename BLayout, // default is n/8*k*n8, trans->k*n
          bool NonTemporalStore>
struct ThreadwiseGemmAvx2_MxN_4x24
{
    using ALayout_ = ALayout;
    using BLayout_ = BLayout;

    static constexpr auto Mr_               = Mr;
    static constexpr auto Nr_               = Nr;
    static constexpr auto NonTemporalStore_ = NonTemporalStore;

    __host__ constexpr ThreadwiseGemmAvx2_MxN_4x24()
    {
        static_assert(Mr <= 4 && Mr >= 1 && (Nr == 8 || Nr == 16 || Nr == 24),
                      "wrong! Mr x Nr not valid");
    }

    // Compute C = alpha * A * B + C for one Mr x Nr register tile.
    //
    // param fields are read at fixed byte offsets inside the asm:
    //   0: p_a, 8: p_b, 16: p_c, 24: Kr (loop count),
    //   32: lda, 40: ldb, 48: ldc (all in bytes), 56: alpha (fp32 bits).
    //
    // The K loop is unrolled 4x with a scalar remainder loop. alpha == 1.0f
    // (bit pattern 0x3f800000) skips the vmulps scaling pass.
    //
    // NOTE(review): m_NTStore is passed into the asm but the store path below
    // always uses vmovups; a non-temporal (vmovntps) store path is not
    // implemented yet — confirm whether that is intentional.
    __host__ static void Run(ThreadwiseGemmParam* param)
    {
        /* 4x24 ukernel
         *
         *         Mat_B
         *        |ymm12  |ymm13  |ymm14  |
         * Mat_A  +-------+-------+-------+
         * ymm15  |ymm0   |ymm1   |ymm2   |
         *        |ymm3   |ymm4   |ymm5   |
         *        |ymm6   |ymm7   |ymm8   |
         *        |ymm9   |ymm10  |ymm11  |
         *
         * ALayout:ColumnMajor (k*m), lda not needed
         * ALayout:RowMajor    (m*k), lda = k
         * BLayout:ColumnMajor (n/8*k*n8), ldb = k*n8. At least this should be 8 continuous n for a
         * ymm register
         * BLayout:RowMajor    (k*n), ldb not needed
         *
         * lda/ldb/ldc all in unit of byte
         *
         */
        // clang-format off
        __asm__ __volatile__ (
        "L_GemmAvx2_MxN_4x24_Entry%=:\n"
        ".set   m_Mr,       %c[m_Mr]\n"
        ".set   m_Nr,       %c[m_Nr]\n"
        ".set   m_TransA,   %c[m_TransA]\n"
        ".set   m_TransB,   %c[m_TransB]\n"
        ".set   m_NTStore,  %c[m_NTStore]\n"
        ".set   m_ABytes,   %c[m_ABytes]\n"
        ".set   m_BBytes,   %c[m_BBytes]\n"
        ".set   m_CBytes,   %c[m_CBytes]\n"
        "movq   (%[m_param]),   %%rax\n"    // p_a
        "movq   8(%[m_param]),  %%rbx\n"    // p_b
        "movq   24(%[m_param]), %%rsi\n"    // Kr
        ".if m_TransA != 0\n"
        "movq   32(%[m_param]), %%rcx\n"    // lda
        ".endif\n"
        ".if m_TransB == 0\n"
        "movq   40(%[m_param]), %%rdx\n"    // ldb
        ".endif\n"
        // addressing helper: (base + stride*scale + offset) when scale != 0,
        // else plain (base + offset) — x86 has no scale-0 SIB form.
        ".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
        ".if \\i_scale != 0\n"
        "vbroadcastss \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
        ".else\n"
        "vbroadcastss \\i_offset(\\r_base), \\ymm\n"
        ".endif\n"
        ".endm\n"
        ".macro vmovups_%= r_base, r_stride, i_scale, i_offset, ymm\n"
        ".if \\i_scale != 0\n"
        "vmovups \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
        ".else\n"
        "vmovups \\i_offset(\\r_base), \\ymm\n"
        ".endif\n"
        ".endm\n"
        // A rows 0/1 address off rax, rows 2/3 off r8 (= rax + 2*lda), since
        // the SIB scale can only encode 0,1,2 row multiples per base.
        ".macro vbroadcast_a%= i_k, i_m, ymm\n"  // A in rax(r8), lda in rcx
        ".if m_TransA == 0\n"
        "vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * 4, \\ymm\n"
        ".else\n"
        ".if (\\i_m == 0) || (\\i_m == 1)\n"
        "vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * 4, \\ymm\n"
        ".else\n"
        "vbroadcastss_%= %%r8, %%rcx, \\i_m-2, \\i_k * 4, \\ymm\n"
        ".endif\n"
        ".endif\n"
        ".endm\n"
        ".macro vload_b%= i_k, i_n, ymm\n"  // B in rbx, ldb in rdx, i_n should be 0, 1, 2
        ".if m_BBytes == 4\n"
        ".if m_TransB == 0\n"
        "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
        ".else\n"
        "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
        ".endif\n"
        ".else\n"
        ".if m_TransB == 0\n"
        "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
        ".else\n"
        "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
        ".endif\n"
        ".endif\n"
        ".endm\n"
        // zero the Mr x Nr accumulator tile
        "                                vxorps %%ymm0,  %%ymm0,  %%ymm0 \n"
        ".if (m_Nr > 8)\n                vxorps %%ymm1,  %%ymm1,  %%ymm1 \n .endif\n"
        ".if (m_Nr >16)\n                vxorps %%ymm2,  %%ymm2,  %%ymm2 \n .endif\n"
        ".if (m_Mr > 1)               \n vxorps %%ymm3,  %%ymm3,  %%ymm3 \n .endif\n"
        ".if (m_Mr > 1) && (m_Nr > 8)\n  vxorps %%ymm4,  %%ymm4,  %%ymm4 \n .endif\n"
        ".if (m_Mr > 1) && (m_Nr >16)\n  vxorps %%ymm5,  %%ymm5,  %%ymm5 \n .endif\n"
        ".if (m_Mr > 2)               \n vxorps %%ymm6,  %%ymm6,  %%ymm6 \n .endif\n"
        ".if (m_Mr > 2) && (m_Nr > 8)\n  vxorps %%ymm7,  %%ymm7,  %%ymm7 \n .endif\n"
        ".if (m_Mr > 2) && (m_Nr >16)\n  vxorps %%ymm8,  %%ymm8,  %%ymm8 \n .endif\n"
        ".if (m_Mr > 3)               \n vxorps %%ymm9,  %%ymm9,  %%ymm9 \n .endif\n"
        ".if (m_Mr > 3) && (m_Nr > 8)\n  vxorps %%ymm10, %%ymm10, %%ymm10\n .endif\n"
        // FIX: ymm11 is the 3x2 accumulator (only live when Nr > 16); guard was (m_Nr > 8).
        ".if (m_Mr > 3) && (m_Nr >16)\n  vxorps %%ymm11, %%ymm11, %%ymm11\n .endif\n"
        ".if m_TransA != 0\n"
        ".if m_Mr > 2\n"
        "lea (%%rax, %%rcx, 2), %%r8\n"     // r8 -> A row 2 (and row 3 via rcx scale)
        ".endif\n"
        ".endif\n"
        "cmp $4, %%rsi\n"
        "jl L_GemmAvx2_MxN_4x24_K_Loop_Remain%=\n"
        "L_GemmAvx2_MxN_4x24_K_Loop_Start%=:\n"     // main K loop, unrolled 4x
        ".irp i_k, 0, 1, 2, 3\n"
        "                                vload_b%= \\i_k, 0, %%ymm12\n"             // B
        ".if (m_Nr > 8)\n                vload_b%= \\i_k, 1, %%ymm13\n .endif\n"    // B
        ".if (m_Nr >16)\n                vload_b%= \\i_k, 2, %%ymm14\n .endif\n"    // B
        "                                vbroadcast_a%= \\i_k, 0, %%ymm15\n"        // A broadcast 0
        "                                vfmadd231ps %%ymm12, %%ymm15, %%ymm0\n"                // 0x0
        ".if (m_Nr > 8)\n                vfmadd231ps %%ymm13, %%ymm15, %%ymm1\n .endif\n"       // 0x1
        ".if (m_Nr >16)\n                vfmadd231ps %%ymm14, %%ymm15, %%ymm2\n .endif\n"       // 0x2
        ".if (m_Mr > 1)               \n vbroadcast_a%= \\i_k, 1, %%ymm15\n .endif\n"           // A broadcast 1
        ".if (m_Mr > 1)               \n vfmadd231ps %%ymm12, %%ymm15, %%ymm3\n .endif\n"       // 1x0
        ".if (m_Mr > 1) && (m_Nr > 8)\n  vfmadd231ps %%ymm13, %%ymm15, %%ymm4\n .endif\n"       // 1x1
        ".if (m_Mr > 1) && (m_Nr >16)\n  vfmadd231ps %%ymm14, %%ymm15, %%ymm5\n .endif\n"       // 1x2
        ".if (m_Mr > 2)               \n vbroadcast_a%= \\i_k, 2, %%ymm15\n .endif\n"           // A broadcast 2
        ".if (m_Mr > 2)               \n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n"       // 2x0
        ".if (m_Mr > 2) && (m_Nr > 8)\n  vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n"       // 2x1
        ".if (m_Mr > 2) && (m_Nr >16)\n  vfmadd231ps %%ymm14, %%ymm15, %%ymm8\n .endif\n"       // 2x2
        ".if (m_Mr > 3)               \n vbroadcast_a%= \\i_k, 3, %%ymm15\n .endif\n"           // A broadcast 3
        ".if (m_Mr > 3)               \n vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n .endif\n"       // 3x0
        ".if (m_Mr > 3) && (m_Nr > 8)\n  vfmadd231ps %%ymm13, %%ymm15, %%ymm10\n .endif\n"      // 3x1
        ".if (m_Mr > 3) && (m_Nr >16)\n  vfmadd231ps %%ymm14, %%ymm15, %%ymm11\n .endif\n"      // 3x2
        ".endr\n"
        // advance A/B pointers by 4 K steps
        ".if m_TransA != 0\n"
        "   lea 4*4(%%rax), %%rax\n"
        ".if m_Mr > 2\n lea 4*4(%%r8), %%r8\n .endif\n"
        ".else\n"
        "   lea m_Mr * 4 * 4(%%rax), %%rax\n"
        ".endif\n"
        ".if m_TransB != 0\n"
        "   lea m_Nr * 4 * 4(%%rbx), %%rbx\n"
        ".else\n"
        "   lea 8 * 4 * 4(%%rbx), %%rbx\n"
        ".endif\n"
        "sub $4, %%rsi\n"
        "cmp $4, %%rsi\n"
        "jge L_GemmAvx2_MxN_4x24_K_Loop_Start%=\n"
        "testq %%rsi, %%rsi\n"
        "je L_GemmAvx2_MxN_4x24_K_Loop_End%=\n"
        "L_GemmAvx2_MxN_4x24_K_Loop_Remain%=:\n"    // remainder K loop, 1 step per iteration
        "                                vload_b%= 0, 0, %%ymm12\n"             // B
        ".if (m_Nr > 8)\n                vload_b%= 0, 1, %%ymm13\n .endif\n"    // B
        ".if (m_Nr >16)\n                vload_b%= 0, 2, %%ymm14\n .endif\n"    // B
        "                                vbroadcast_a%= 0, 0, %%ymm15\n"        // A broadcast 0
        "                                vfmadd231ps %%ymm12, %%ymm15, %%ymm0\n"                // 0x0
        ".if (m_Nr > 8)\n                vfmadd231ps %%ymm13, %%ymm15, %%ymm1\n .endif\n"       // 0x1
        ".if (m_Nr >16)\n                vfmadd231ps %%ymm14, %%ymm15, %%ymm2\n .endif\n"       // 0x2
        ".if (m_Mr > 1)               \n vbroadcast_a%= 0, 1, %%ymm15\n .endif\n"               // A broadcast 1
        ".if (m_Mr > 1)               \n vfmadd231ps %%ymm12, %%ymm15, %%ymm3\n .endif\n"       // 1x0
        ".if (m_Mr > 1) && (m_Nr > 8)\n  vfmadd231ps %%ymm13, %%ymm15, %%ymm4\n .endif\n"       // 1x1
        ".if (m_Mr > 1) && (m_Nr >16)\n  vfmadd231ps %%ymm14, %%ymm15, %%ymm5\n .endif\n"       // 1x2
        ".if (m_Mr > 2)               \n vbroadcast_a%= 0, 2, %%ymm15\n .endif\n"               // A broadcast 2
        ".if (m_Mr > 2)               \n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n"       // 2x0
        ".if (m_Mr > 2) && (m_Nr > 8)\n  vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n"       // 2x1
        ".if (m_Mr > 2) && (m_Nr >16)\n  vfmadd231ps %%ymm14, %%ymm15, %%ymm8\n .endif\n"       // 2x2
        ".if (m_Mr > 3)               \n vbroadcast_a%= 0, 3, %%ymm15\n .endif\n"               // A broadcast 3
        ".if (m_Mr > 3)               \n vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n .endif\n"       // 3x0
        ".if (m_Mr > 3) && (m_Nr > 8)\n  vfmadd231ps %%ymm13, %%ymm15, %%ymm10\n .endif\n"      // 3x1
        ".if (m_Mr > 3) && (m_Nr >16)\n  vfmadd231ps %%ymm14, %%ymm15, %%ymm11\n .endif\n"      // 3x2
        ".if m_TransA != 0\n"
        "   lea 4(%%rax), %%rax\n"
        // FIX: r8 is initialized and consumed whenever Mr > 2 (rows 2/3 of A);
        // the guard here was `m_Mr > 3`, leaving r8 stale in the remainder loop for Mr == 3.
        ".if m_Mr > 2\n lea 4(%%r8), %%r8\n .endif\n"
        ".else\n"
        "   lea m_Mr * 4(%%rax), %%rax\n"
        ".endif\n"
        ".if m_TransB != 0\n"
        "   lea m_Nr * 4(%%rbx), %%rbx\n"
        ".else\n"
        "   lea 8*4(%%rbx), %%rbx\n"
        ".endif\n"
        "sub $1, %%rsi\n"
        "jne L_GemmAvx2_MxN_4x24_K_Loop_Remain%=\n"
        "L_GemmAvx2_MxN_4x24_K_Loop_End%=:\n"
        // skip the scaling pass when alpha bit pattern is exactly 1.0f
        "mov 56(%[m_param]), %%eax\n"   // alpha
        "cmp $0x3f800000, %%eax\n"
        "je L_GemmAvx2_MxN_4x24_Update_C%=\n"
        "vbroadcastss 56(%[m_param]), %%ymm12\n"
        "                                vmulps %%ymm12, %%ymm0,  %%ymm0 \n"            // 0x0
        ".if (m_Nr > 8)\n                vmulps %%ymm12, %%ymm1,  %%ymm1 \n .endif\n"   // 0x1
        ".if (m_Nr >16)\n                vmulps %%ymm12, %%ymm2,  %%ymm2 \n .endif\n"   // 0x2
        ".if (m_Mr > 1)               \n vmulps %%ymm12, %%ymm3,  %%ymm3 \n .endif\n"   // 1x0
        ".if (m_Mr > 1) && (m_Nr > 8)\n  vmulps %%ymm12, %%ymm4,  %%ymm4 \n .endif\n"   // 1x1
        ".if (m_Mr > 1) && (m_Nr >16)\n  vmulps %%ymm12, %%ymm5,  %%ymm5 \n .endif\n"   // 1x2
        ".if (m_Mr > 2)               \n vmulps %%ymm12, %%ymm6,  %%ymm6 \n .endif\n"   // 2x0
        ".if (m_Mr > 2) && (m_Nr > 8)\n  vmulps %%ymm12, %%ymm7,  %%ymm7 \n .endif\n"   // 2x1
        ".if (m_Mr > 2) && (m_Nr >16)\n  vmulps %%ymm12, %%ymm8,  %%ymm8 \n .endif\n"   // 2x2
        ".if (m_Mr > 3)               \n vmulps %%ymm12, %%ymm9,  %%ymm9 \n .endif\n"   // 3x0
        ".if (m_Mr > 3) && (m_Nr > 8)\n  vmulps %%ymm12, %%ymm10, %%ymm10\n .endif\n"   // 3x1
        ".if (m_Mr > 3) && (m_Nr >16)\n  vmulps %%ymm12, %%ymm11, %%ymm11\n .endif\n"   // 3x2
        "L_GemmAvx2_MxN_4x24_Update_C%=:\n"
        // accumulate into C: row pointers rax/rbx/rcx/rdx, 32-byte column strides
        "movq 16(%[m_param]), %%rax\n"  // p_c
        "movq 48(%[m_param]), %%rdi\n"  // ldc
        ".if (m_Mr > 1)\n lea (%%rax, %%rdi, 1), %%rbx\n .endif\n"
        ".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
        ".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"
        "                                vaddps   (%%rax), %%ymm0,  %%ymm0 \n"
        ".if (m_Nr > 8)\n                vaddps 32(%%rax), %%ymm1,  %%ymm1 \n .endif\n"
        ".if (m_Nr >16)\n                vaddps 64(%%rax), %%ymm2,  %%ymm2 \n .endif\n"
        ".if (m_Mr > 1)               \n vaddps   (%%rbx), %%ymm3,  %%ymm3 \n .endif\n"
        ".if (m_Mr > 1) && (m_Nr > 8)\n  vaddps 32(%%rbx), %%ymm4,  %%ymm4 \n .endif\n"
        ".if (m_Mr > 1) && (m_Nr >16)\n  vaddps 64(%%rbx), %%ymm5,  %%ymm5 \n .endif\n"
        ".if (m_Mr > 2)               \n vaddps   (%%rcx), %%ymm6,  %%ymm6 \n .endif\n"
        ".if (m_Mr > 2) && (m_Nr > 8)\n  vaddps 32(%%rcx), %%ymm7,  %%ymm7 \n .endif\n"
        ".if (m_Mr > 2) && (m_Nr >16)\n  vaddps 64(%%rcx), %%ymm8,  %%ymm8 \n .endif\n"
        ".if (m_Mr > 3)               \n vaddps   (%%rdx), %%ymm9,  %%ymm9 \n .endif\n"
        ".if (m_Mr > 3) && (m_Nr > 8)\n  vaddps 32(%%rdx), %%ymm10, %%ymm10\n .endif\n"
        ".if (m_Mr > 3) && (m_Nr >16)\n  vaddps 64(%%rdx), %%ymm11, %%ymm11\n .endif\n"
        "                                vmovups %%ymm0,    (%%rax) \n"
        ".if (m_Nr > 8)\n                vmovups %%ymm1,  32(%%rax)\n .endif\n"
        ".if (m_Nr >16)\n                vmovups %%ymm2,  64(%%rax)\n .endif\n"
        ".if (m_Mr > 1)               \n vmovups %%ymm3,    (%%rbx) \n .endif\n"
        ".if (m_Mr > 1) && (m_Nr > 8)\n  vmovups %%ymm4,  32(%%rbx)\n .endif\n"
        ".if (m_Mr > 1) && (m_Nr >16)\n  vmovups %%ymm5,  64(%%rbx)\n .endif\n"
        ".if (m_Mr > 2)               \n vmovups %%ymm6,    (%%rcx) \n .endif\n"
        ".if (m_Mr > 2) && (m_Nr > 8)\n  vmovups %%ymm7,  32(%%rcx)\n .endif\n"
        ".if (m_Mr > 2) && (m_Nr >16)\n  vmovups %%ymm8,  64(%%rcx)\n .endif\n"
        ".if (m_Mr > 3)               \n vmovups %%ymm9,    (%%rdx) \n .endif\n"
        ".if (m_Mr > 3) && (m_Nr > 8)\n  vmovups %%ymm10, 32(%%rdx)\n .endif\n"
        ".if (m_Mr > 3) && (m_Nr >16)\n  vmovups %%ymm11, 64(%%rdx)\n .endif\n"
        "L_GemmAvx2_MxN_4x24_Exit%=:\n"
        :
        :
            [m_Mr]      "i" (Mr),
            [m_Nr]      "i" (Nr),
            [m_TransA]  "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? 1 : 0),
            [m_TransB]  "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? 1 : 0),
            [m_NTStore] "i" (NonTemporalStore),
            [m_ABytes]  "i" (sizeof(FloatA)),
            [m_BBytes]  "i" (sizeof(FloatB)),
            [m_CBytes]  "i" (sizeof(FloatC)),
            [m_param]   "r" (param)
        :
            "cc",
            "rax","rbx","rcx","rdx","rsi","rdi","r8",
            "ymm0","ymm1","ymm2","ymm3","ymm4","ymm5","ymm6",
            "ymm7","ymm8","ymm9","ymm10","ymm11","ymm12","ymm13",
            "ymm14","ymm15"
        );
        // clang-format on
    }
};
} // namespace cpu } // namespace cpu
} // namespace ck } // namespace ck
#endif #endif
...@@ -29,6 +29,20 @@ ...@@ -29,6 +29,20 @@
// #define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \ // #define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
// ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT> // ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>
// Expands to a comma-separated list of every Mr x Nr specialization of the
// 4x24 ukernel family (Mr in 4..1, Nr in {24, 16, 8}) for the given element
// types (FA/FB/FC), layouts (TA/TB) and non-temporal-store flag (NT).
// Intended for use inside a std::tuple<...> type list; note it ends without
// a trailing comma. (No comments inside the macro body: a `//` before the
// line-continuation backslash would swallow the continuation.)
#define ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(FA, FB, FC, TA, TB, NT) \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 4, 24, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 3, 24, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 2, 24, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 1, 24, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 4, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 3, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 2, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 1, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 4, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 3, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 2, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_4x24<FA, FB, FC, 1, 8, TA, TB, NT>
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -45,6 +59,17 @@ using thread_gemm_avx2_mxn_6x16_instances = std::tuple< ...@@ -45,6 +59,17 @@ using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
// clang-format on // clang-format on
>; >;
// Type list of 4x24-family ukernel instantiations for fp32 with the given
// A/B layouts, consumed via ck::static_for over std::tuple_size.
//
// NOTE(review): the four ITERATE_... expansions below are byte-identical
// (same types, same layouts, NTStore=false every time), so each of the 12
// Mr x Nr instances appears 4x in the tuple. Presumably some lines were
// meant to vary NTStore or the layouts — verify against the 6x16 instance
// list before relying on tuple_size.
template <typename ALayout, typename BLayout>
using thread_gemm_avx2_mxn_4x24_instances = std::tuple<
    // clang-format off
    //                                          FloatA FloatB FloatC  ALayout  BLayout  NTStore
    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false),
    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false),
    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false),
    ITERATE_THREAD_GEMM_AVX2_MXN_4X24_INSTANCE(float, float, float, ALayout, BLayout, false)
    // clang-format on
>;
void dump_cache_hierarchy() void dump_cache_hierarchy()
{ {
auto dump_cache_type = [&](const ck::cpu::cpuid_cache_type& type) { auto dump_cache_type = [&](const ck::cpu::cpuid_cache_type& type) {
...@@ -336,15 +361,11 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k) ...@@ -336,15 +361,11 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
ref_cpu_gemm_uk<data_type, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k); ref_cpu_gemm_uk<data_type, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k);
using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>; using thread_gemm_instance = thread_gemm_avx2_mxn_6x16_instances<ALayout, BLayout>;
// using thread_gemm_instance = thread_gemm_avx2_mxn_4x24_instances<ALayout, BLayout>;
bool found = false; bool found = false;
ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) { ck::static_for<0, std::tuple_size_v<thread_gemm_instance>, 1>{}([&](auto i) {
using uk_type = std::tuple_element_t<i, thread_gemm_instance>; using uk_type = std::tuple_element_t<i, thread_gemm_instance>;
// if constexpr(!std::is_same<typename uk_type::ALayout_, ALayout>::value ||
// !std::is_same<typename uk_type::BLayout_, BLayout>::value)
// {
// return;
// }
if(m % uk_type::Mr_ != 0 || n % uk_type::Nr_ != 0) if(m % uk_type::Mr_ != 0 || n % uk_type::Nr_ != 0)
return; return;
if((m != uk_type::Mr_ && std::is_same<typename uk_type::ALayout_, Col>::value) || if((m != uk_type::Mr_ && std::is_same<typename uk_type::ALayout_, Col>::value) ||
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment