Commit 35f95fe9 authored by carlushuang

movaps->movups, and support loop over L1

parent e72c0c43
#ifndef CK_THREADWISE_GEMM_AVX2_HPP
#define CK_THREADWISE_GEMM_AVX2_HPP
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "math.hpp"
#include "threadwise_param.hpp"
namespace ck {
namespace cpu {
template <typename FloatA,
typename FloatB,
typename FloatC,
index_t Mr,
index_t Nr,
typename ALayout, // default is k*m, trans->m*k
typename BLayout, // default is n/8*k*n8, trans->k*n
bool NonTemporalStore>
struct ThreadwiseGemmAvx2_MxN_6x16
{
using ALayout_ = ALayout;
using BLayout_ = BLayout;
static constexpr auto Mr_ = Mr;
static constexpr auto Nr_ = Nr;
static constexpr auto NonTemporalStore_ = NonTemporalStore;
__host__ constexpr ThreadwiseGemmAvx2_MxN_6x16()
{
static_assert(Mr <= 6 && Mr >= 1 && (Nr == 8 || Nr == 16), "wrong! Mr x Nr not valid");
}
__host__ static void Run(ThreadwiseGemmParam* param)
{
/* 6x16 ukernel
*
* Mat_B
* |ymm12 |ymm13 |
* Mat_A +--------+--------+
* ymm14 |ymm0 |ymm1 | cycle 0
* ymm15 |ymm2 |ymm3 | cycle 1
* ymm14 |ymm4 |ymm5 | cycle 2
* ymm15 |ymm6 |ymm7 | cycle 3
* ymm14 |ymm8 |ymm9 | cycle 4
* ymm15 |ymm10 |ymm11 | Mat_C cycle 5
*
* ALayout:ColumnMajor (k*m), lda not needed
* ALayout:RowMajor (m*k), lda = k
* BLayout:ColumnMajor (n/8*k*n8), ldb = k*n8. At least 8 contiguous n elements are packed
* together, so one ymm register can be filled per load.
* BLayout:RowMajor (k*n), ldb not needed
*
* lda/ldb/ldc all in unit of byte
*
*/
// clang-format off
__asm__ __volatile__ (
"L_GemmAvx2_MxN_6x16_Entry%=:\n"
".set m_Mr, %c[m_Mr]\n"
".set m_Nr, %c[m_Nr]\n"
".set m_TransA, %c[m_TransA]\n"
".set m_TransB, %c[m_TransB]\n"
".set m_NTStore, %c[m_NTStore]\n"
".set m_ABytes, %c[m_ABytes]\n"
".set m_BBytes, %c[m_BBytes]\n"
".set m_CBytes, %c[m_CBytes]\n"
"movq (%[m_param]), %%rax\n" // p_a
"movq 8(%[m_param]), %%rbx\n" // p_b
"movq 24(%[m_param]), %%rsi\n" // Kr
".if m_TransA != 0\n"
"movq 32(%[m_param]), %%rcx\n" // lda
".endif\n"
".if m_TransB == 0\n"
"movq 40(%[m_param]), %%rdx\n" // ldb
".endif\n"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n"
"vbroadcastss \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
".else\n"
"vbroadcastss \\i_offset(\\r_base), \\ymm\n"
".endif\n"
".endm\n"
".macro vmovaps_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n"
"vmovaps \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
".else\n"
"vmovaps \\i_offset(\\r_base), \\ymm\n"
".endif\n"
".endm\n"
".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8, r9), lda in rcx
".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * 4, \\ymm\n"
".else\n"
".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * 4, \\ymm\n"
".else\n"
"vbroadcastss_%= %%r8, %%rcx, \\i_m-3, \\i_k * 4, \\ymm\n"
".endif\n"
".endif\n"
".endm\n"
".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, ldb in rdx, i_n should be 0 or 1
".if m_BBytes == 4\n"
".if m_TransB == 0\n"
"vmovaps_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
".else\n"
"vmovaps_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
".endif\n"
".else\n"
".if m_TransB == 0\n"
"vmovaps_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
".else\n"
"vmovaps_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
".endif\n"
".endif\n"
".endm\n"
" vxorps %%ymm0, %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vxorps %%ymm1, %%ymm1, %%ymm1 \n .endif\n"
".if (m_Mr > 1) \n vxorps %%ymm2, %%ymm2, %%ymm2 \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vxorps %%ymm3, %%ymm3, %%ymm3 \n .endif\n"
".if (m_Mr > 2) \n vxorps %%ymm4, %%ymm4, %%ymm4 \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vxorps %%ymm5, %%ymm5, %%ymm5 \n .endif\n"
".if (m_Mr > 3) \n vxorps %%ymm6, %%ymm6, %%ymm6 \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vxorps %%ymm7, %%ymm7, %%ymm7 \n .endif\n"
".if (m_Mr > 4) \n vxorps %%ymm8, %%ymm8, %%ymm8 \n .endif\n"
".if (m_Mr > 4) && (m_Nr > 8)\n vxorps %%ymm9, %%ymm9, %%ymm9 \n .endif\n"
".if (m_Mr > 5) \n vxorps %%ymm10, %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vxorps %%ymm11, %%ymm11, %%ymm11\n .endif\n"
".if m_TransA != 0\n"
".if m_Mr > 3\n"
"lea (%%rcx, %%rcx, 2), %%r9\n"
"lea (%%rax, %%r9), %%r8\n"
".endif\n"
".endif\n"
"cmp $4, %%rsi\n"
"jl L_GemmAvx2_MxN_6x16_K_Loop_Remain%=\n"
"L_GemmAvx2_MxN_6x16_K_Loop_Start%=:\n"
".irp i_k, 0, 1, 2, 3\n"
" vload_b%= \\i_k, 0, %%ymm12\n" // B
".if (m_Nr > 8)\n vload_b%= \\i_k, 1, %%ymm13\n .endif\n" // B
" vbroadcast_a%= \\i_k, 0, %%ymm14\n" // A broadcast 0
".if (m_Mr > 1) \n vbroadcast_a%= \\i_k, 1, %%ymm15\n .endif\n" // A broadcast 1
" vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" // 0x0
".if (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n .endif\n" // 0x1
".if (m_Mr > 1) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n .endif\n" // 1x0
".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n .endif\n" // 1x1
".if (m_Mr > 2) \n vbroadcast_a%= \\i_k, 2, %%ymm14\n .endif\n" // A broadcast 2
".if (m_Mr > 3) \n vbroadcast_a%= \\i_k, 3, %%ymm15\n .endif\n" // A broadcast 3
".if (m_Mr > 2) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n .endif\n" // 2x0
".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n .endif\n" // 2x1
".if (m_Mr > 3) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n" // 3x0
".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n" // 3x1
".if (m_Mr > 4) \n vbroadcast_a%= \\i_k, 4, %%ymm14\n .endif\n" // A broadcast 4
".if (m_Mr > 5) \n vbroadcast_a%= \\i_k, 5, %%ymm15\n .endif\n" // A broadcast 5
".if (m_Mr > 4) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n .endif\n" // 4x0
".if (m_Mr > 4) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n .endif\n" // 4x1
".if (m_Mr > 5) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n .endif\n" // 5x0
".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1
".endr\n"
".if m_TransA != 0\n"
" lea 4*4(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4*4(%%r8), %%r8\n .endif\n"
".else\n"
" lea m_Mr * 4 * 4(%%rax), %%rax\n"
".endif\n"
".if m_TransB != 0\n"
" lea m_Nr * 4 * 4(%%rbx), %%rbx\n"
".else\n"
" lea 8 * 4 * 4(%%rbx), %%rbx\n"
".endif\n"
"sub $4, %%rsi\n"
"cmp $4, %%rsi\n"
"jge L_GemmAvx2_MxN_6x16_K_Loop_Start%=\n"
"testq %%rsi, %%rsi\n"
"je L_GemmAvx2_MxN_6x16_K_Loop_End%=\n"
"L_GemmAvx2_MxN_6x16_K_Loop_Remain%=:\n"
" vload_b%= 0, 0, %%ymm12\n" // B
".if (m_Nr > 8)\n vload_b%= 0, 1, %%ymm13\n .endif\n" // B
" vbroadcast_a%= 0, 0, %%ymm14\n" // A broadcast 0
".if (m_Mr > 1) \n vbroadcast_a%= 0, 1, %%ymm15\n .endif\n" // A broadcast 1
" vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" // 0x0
".if (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n .endif\n" // 0x1
".if (m_Mr > 1) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n .endif\n" // 1x0
".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n .endif\n" // 1x1
".if (m_Mr > 2) \n vbroadcast_a%= 0, 2, %%ymm14\n .endif\n" // A broadcast 2
".if (m_Mr > 3) \n vbroadcast_a%= 0, 3, %%ymm15\n .endif\n" // A broadcast 3
".if (m_Mr > 2) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n .endif\n" // 2x0
".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n .endif\n" // 2x1
".if (m_Mr > 3) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n" // 3x0
".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n" // 3x1
".if (m_Mr > 4) \n vbroadcast_a%= 0, 4, %%ymm14\n .endif\n" // A broadcast 4
".if (m_Mr > 5) \n vbroadcast_a%= 0, 5, %%ymm15\n .endif\n" // A broadcast 5
".if (m_Mr > 4) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n .endif\n" // 4x0
".if (m_Mr > 4) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n .endif\n" // 4x1
".if (m_Mr > 5) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n .endif\n" // 5x0
".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1
".if m_TransA != 0\n"
" lea 4(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4(%%r8), %%r8\n .endif\n"
".else\n"
" lea m_Mr * 4(%%rax), %%rax\n"
".endif\n"
".if m_TransB != 0\n"
" lea m_Nr * 4(%%rbx), %%rbx\n"
".else\n"
" lea 8*4(%%rbx), %%rbx\n"
".endif\n"
"sub $1, %%rsi\n"
"jne L_GemmAvx2_MxN_6x16_K_Loop_Remain%=\n"
"L_GemmAvx2_MxN_6x16_K_Loop_End%=:\n"
"mov 56(%[m_param]), %%eax\n" // alpha
"cmp $0x3f800000, %%eax\n"
"je L_GemmAvx2_MxN_6x16_Update_C%=\n"
"vbroadcastss 56(%[m_param]), %%ymm12\n"
" vmulps %%ymm12, %%ymm0, %%ymm0 \n" // 0x0
".if (m_Nr > 8)\n vmulps %%ymm12, %%ymm1, %%ymm1 \n .endif\n" // 0x1
".if (m_Mr > 1) \n vmulps %%ymm12, %%ymm2, %%ymm2 \n .endif\n" // 1x0
".if (m_Mr > 1) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm3, %%ymm3 \n .endif\n" // 1x1
".if (m_Mr > 2) \n vmulps %%ymm12, %%ymm4, %%ymm4 \n .endif\n" // 2x0
".if (m_Mr > 2) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm5, %%ymm5 \n .endif\n" // 2x1
".if (m_Mr > 3) \n vmulps %%ymm12, %%ymm6, %%ymm6 \n .endif\n" // 3x0
".if (m_Mr > 3) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm7, %%ymm7 \n .endif\n" // 3x1
".if (m_Mr > 4) \n vmulps %%ymm12, %%ymm8, %%ymm8 \n .endif\n" // 4x0
".if (m_Mr > 4) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm9, %%ymm9 \n .endif\n" // 4x1
".if (m_Mr > 5) \n vmulps %%ymm12, %%ymm10, %%ymm10\n .endif\n" // 5x0
".if (m_Mr > 5) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm11, %%ymm11\n .endif\n" // 5x1
"L_GemmAvx2_MxN_6x16_Update_C%=:\n"
"movq 16(%[m_param]), %%rax\n" // p_c
"movq 48(%[m_param]), %%rdi\n" // ldc
".if (m_Mr > 1)\n lea (%%rax, %%rdi, 1), %%rbx\n .endif\n"
".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"
".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8 \n .endif\n"
".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9 \n .endif\n"
" vaddps (%%rax), %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm2, %%ymm2 \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm3, %%ymm3 \n .endif\n"
".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm4, %%ymm4 \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm5, %%ymm5 \n .endif\n"
".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm6, %%ymm6 \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm7, %%ymm7 \n .endif\n"
".if (m_Mr > 4) \n vaddps (%%r8), %%ymm8, %%ymm8 \n .endif\n"
".if (m_Mr > 4) && (m_Nr > 8)\n vaddps 32(%%r8), %%ymm9, %%ymm9 \n .endif\n"
".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
" vmovaps %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovaps %%ymm1, 32(%%rax)\n .endif\n"
".if (m_Mr > 1) \n vmovaps %%ymm2, (%%rbx) \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vmovaps %%ymm3, 32(%%rbx)\n .endif\n"
".if (m_Mr > 2) \n vmovaps %%ymm4, (%%rcx) \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vmovaps %%ymm5, 32(%%rcx)\n .endif\n"
".if (m_Mr > 3) \n vmovaps %%ymm6, (%%rdx) \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vmovaps %%ymm7, 32(%%rdx)\n .endif\n"
".if (m_Mr > 4) \n vmovaps %%ymm8, (%%r8) \n .endif\n"
".if (m_Mr > 4) && (m_Nr > 8)\n vmovaps %%ymm9, 32(%%r8) \n .endif\n"
".if (m_Mr > 5) \n vmovaps %%ymm10, (%%r9) \n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vmovaps %%ymm11, 32(%%r9) \n .endif\n"
"L_GemmAvx2_MxN_6x16_Exit%=:\n"
:
:
[m_Mr] "i" (Mr),
[m_Nr] "i" (Nr),
[m_TransA] "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? 1 : 0),
[m_TransB] "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? 1 : 0),
[m_NTStore] "i" (NonTemporalStore),
[m_ABytes] "i" (sizeof(FloatA)),
[m_BBytes] "i" (sizeof(FloatB)),
[m_CBytes] "i" (sizeof(FloatC)),
[m_param] "r" (param)
:
"cc",
"rax","rbx","rcx","rdx","rsi","rdi","r8","r9",
"ymm0","ymm1","ymm2","ymm3","ymm4","ymm5","ymm6",
"ymm7","ymm8","ymm9","ymm10","ymm11","ymm12","ymm13",
"ymm14","ymm15"
);
// clang-format on
}
};
} // namespace cpu
} // namespace ck
#endif
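For the BLayout:ColumnMajor case described in the comment above, B must be pre-packed into the n/8*k*n8 layout so each ymm load reads 8 contiguous n values. A minimal packing sketch (pack_b_n8 is a hypothetical helper, not part of this commit; it matches the b_offset() indexing used in the test driver below):

#include <cstdint>
// Hypothetical pre-packing helper: repack a row-major k x n fp32 matrix
// into the n/8 * k * n8 layout the ukernel expects for BLayout=ColumnMajor.
inline void pack_b_n8(const float* b, float* b_packed, uint32_t k, uint32_t n)
{
    for(uint32_t in = 0; in < n; in++)
        for(uint32_t ik = 0; ik < k; ik++)
            b_packed[(in / 8) * k * 8 + ik * 8 + in % 8] = b[ik * n + in];
}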
#ifndef CK_THREADWISE_GEMM_AVX2_HPP
#define CK_THREADWISE_GEMM_AVX2_HPP
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "math.hpp"
#include "threadwise_param.hpp"
namespace ck {
namespace cpu {
template <typename FloatA,
typename FloatB,
typename FloatC,
index_t Mr,
index_t Nr,
typename ALayout, // default is k*m, trans->m*k
typename BLayout, // default is n/8*k*n8, trans->k*n
bool NonTemporalStore>
struct ThreadwiseGemmAvx2_MxN_6x16
{
using ALayout_ = ALayout;
using BLayout_ = BLayout;
static constexpr auto Mr_ = Mr;
static constexpr auto Nr_ = Nr;
static constexpr auto NonTemporalStore_ = NonTemporalStore;
__host__ constexpr ThreadwiseGemmAvx2_MxN_6x16()
{
static_assert(Mr <= 6 && Mr >= 1 && (Nr == 8 || Nr == 16), "wrong! Mr x Nr not valid");
}
__host__ static void Run(ThreadwiseGemmParam* param)
{
/* 6x16 ukernel
*
* Mat_B
* |ymm12 |ymm13 |
* Mat_A +--------+--------+
* ymm14 |ymm0 |ymm1 | cycle 0
* ymm15 |ymm2 |ymm3 | cycle 1
* ymm14 |ymm4 |ymm5 | cycle 2
* ymm15 |ymm6 |ymm7 | cycle 3
* ymm14 |ymm8 |ymm9 | cycle 4
* ymm15 |ymm10 |ymm11 | Mat_C cycle 5
*
* ALayout:ColumnMajor (k*m), lda not needed
* ALayout:RowMajor (m*k), lda = k
* BLayout:ColumnMajor (n/8*k*n8), ldb = k*n8. At least 8 contiguous n elements are packed
* together, so one ymm register can be filled per load.
* BLayout:RowMajor (k*n), ldb not needed
*
* lda/ldb/ldc all in unit of byte
*
*/
// clang-format off
__asm__ __volatile__ (
"L_GemmAvx2_MxN_6x16_Entry%=:\n"
".set m_Mr, %c[m_Mr]\n"
".set m_Nr, %c[m_Nr]\n"
".set m_TransA, %c[m_TransA]\n"
".set m_TransB, %c[m_TransB]\n"
".set m_NTStore, %c[m_NTStore]\n"
".set m_ABytes, %c[m_ABytes]\n"
".set m_BBytes, %c[m_BBytes]\n"
".set m_CBytes, %c[m_CBytes]\n"
"movq (%[m_param]), %%rax\n" // p_a
"movq 8(%[m_param]), %%rbx\n" // p_b
"movq 24(%[m_param]), %%rsi\n" // Kr
".if m_TransA != 0\n"
"movq 32(%[m_param]), %%rcx\n" // lda
".endif\n"
".if m_TransB == 0\n"
"movq 40(%[m_param]), %%rdx\n" // ldb
".endif\n"
".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n"
"vbroadcastss \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
".else\n"
"vbroadcastss \\i_offset(\\r_base), \\ymm\n"
".endif\n"
".endm\n"
".macro vmovups_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n"
"vmovups \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
".else\n"
"vmovups \\i_offset(\\r_base), \\ymm\n"
".endif\n"
".endm\n"
".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8, r9), lda in rcx
".if m_TransA == 0\n"
"vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * 4, \\ymm\n"
".else\n"
".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
"vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * 4, \\ymm\n"
".else\n"
"vbroadcastss_%= %%r8, %%rcx, \\i_m-3, \\i_k * 4, \\ymm\n"
".endif\n"
".endif\n"
".endm\n"
".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, ldb in rdx, i_n should be 0 or 1
".if m_BBytes == 4\n"
".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
".else\n"
"vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
".endif\n"
".else\n"
".if m_TransB == 0\n"
"vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
".else\n"
"vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
".endif\n"
".endif\n"
".endm\n"
" vxorps %%ymm0, %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vxorps %%ymm1, %%ymm1, %%ymm1 \n .endif\n"
".if (m_Mr > 1) \n vxorps %%ymm2, %%ymm2, %%ymm2 \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vxorps %%ymm3, %%ymm3, %%ymm3 \n .endif\n"
".if (m_Mr > 2) \n vxorps %%ymm4, %%ymm4, %%ymm4 \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vxorps %%ymm5, %%ymm5, %%ymm5 \n .endif\n"
".if (m_Mr > 3) \n vxorps %%ymm6, %%ymm6, %%ymm6 \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vxorps %%ymm7, %%ymm7, %%ymm7 \n .endif\n"
".if (m_Mr > 4) \n vxorps %%ymm8, %%ymm8, %%ymm8 \n .endif\n"
".if (m_Mr > 4) && (m_Nr > 8)\n vxorps %%ymm9, %%ymm9, %%ymm9 \n .endif\n"
".if (m_Mr > 5) \n vxorps %%ymm10, %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vxorps %%ymm11, %%ymm11, %%ymm11\n .endif\n"
".if m_TransA != 0\n"
".if m_Mr > 3\n"
"lea (%%rcx, %%rcx, 2), %%r9\n"
"lea (%%rax, %%r9), %%r8\n"
".endif\n"
".endif\n"
"cmp $4, %%rsi\n"
"jl L_GemmAvx2_MxN_6x16_K_Loop_Remain%=\n"
"L_GemmAvx2_MxN_6x16_K_Loop_Start%=:\n"
".irp i_k, 0, 1, 2, 3\n"
" vload_b%= \\i_k, 0, %%ymm12\n" // B
".if (m_Nr > 8)\n vload_b%= \\i_k, 1, %%ymm13\n .endif\n" // B
" vbroadcast_a%= \\i_k, 0, %%ymm14\n" // A broadcast 0
".if (m_Mr > 1) \n vbroadcast_a%= \\i_k, 1, %%ymm15\n .endif\n" // A broadcast 1
" vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" // 0x0
".if (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n .endif\n" // 0x1
".if (m_Mr > 1) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n .endif\n" // 1x0
".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n .endif\n" // 1x1
".if (m_Mr > 2) \n vbroadcast_a%= \\i_k, 2, %%ymm14\n .endif\n" // A broadcast 2
".if (m_Mr > 3) \n vbroadcast_a%= \\i_k, 3, %%ymm15\n .endif\n" // A broadcast 3
".if (m_Mr > 2) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n .endif\n" // 2x0
".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n .endif\n" // 2x1
".if (m_Mr > 3) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n" // 3x0
".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n" // 3x1
".if (m_Mr > 4) \n vbroadcast_a%= \\i_k, 4, %%ymm14\n .endif\n" // A broadcast 4
".if (m_Mr > 5) \n vbroadcast_a%= \\i_k, 5, %%ymm15\n .endif\n" // A broadcast 5
".if (m_Mr > 4) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n .endif\n" // 4x0
".if (m_Mr > 4) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n .endif\n" // 4x1
".if (m_Mr > 5) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n .endif\n" // 5x0
".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1
".endr\n"
".if m_TransA != 0\n"
" lea 4*4(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4*4(%%r8), %%r8\n .endif\n"
".else\n"
" lea m_Mr * 4 * 4(%%rax), %%rax\n"
".endif\n"
".if m_TransB != 0\n"
" lea m_Nr * 4 * 4(%%rbx), %%rbx\n"
".else\n"
" lea 8 * 4 * 4(%%rbx), %%rbx\n"
".endif\n"
"sub $4, %%rsi\n"
"cmp $4, %%rsi\n"
"jge L_GemmAvx2_MxN_6x16_K_Loop_Start%=\n"
"testq %%rsi, %%rsi\n"
"je L_GemmAvx2_MxN_6x16_K_Loop_End%=\n"
"L_GemmAvx2_MxN_6x16_K_Loop_Remain%=:\n"
" vload_b%= 0, 0, %%ymm12\n" // B
".if (m_Nr > 8)\n vload_b%= 0, 1, %%ymm13\n .endif\n" // B
" vbroadcast_a%= 0, 0, %%ymm14\n" // A broadcast 0
".if (m_Mr > 1) \n vbroadcast_a%= 0, 1, %%ymm15\n .endif\n" // A broadcast 1
" vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" // 0x0
".if (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n .endif\n" // 0x1
".if (m_Mr > 1) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n .endif\n" // 1x0
".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n .endif\n" // 1x1
".if (m_Mr > 2) \n vbroadcast_a%= 0, 2, %%ymm14\n .endif\n" // A broadcast 2
".if (m_Mr > 3) \n vbroadcast_a%= 0, 3, %%ymm15\n .endif\n" // A broadcast 3
".if (m_Mr > 2) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n .endif\n" // 2x0
".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n .endif\n" // 2x1
".if (m_Mr > 3) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n" // 3x0
".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n" // 3x1
".if (m_Mr > 4) \n vbroadcast_a%= 0, 4, %%ymm14\n .endif\n" // A broadcast 4
".if (m_Mr > 5) \n vbroadcast_a%= 0, 5, %%ymm15\n .endif\n" // A broadcast 5
".if (m_Mr > 4) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n .endif\n" // 4x0
".if (m_Mr > 4) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n .endif\n" // 4x1
".if (m_Mr > 5) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n .endif\n" // 5x0
".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1
".if m_TransA != 0\n"
" lea 4(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4(%%r8), %%r8\n .endif\n"
".else\n"
" lea m_Mr * 4(%%rax), %%rax\n"
".endif\n"
".if m_TransB != 0\n"
" lea m_Nr * 4(%%rbx), %%rbx\n"
".else\n"
" lea 8*4(%%rbx), %%rbx\n"
".endif\n"
"sub $1, %%rsi\n"
"jne L_GemmAvx2_MxN_6x16_K_Loop_Remain%=\n"
"L_GemmAvx2_MxN_6x16_K_Loop_End%=:\n"
"mov 56(%[m_param]), %%eax\n" // alpha
"cmp $0x3f800000, %%eax\n"
"je L_GemmAvx2_MxN_6x16_Update_C%=\n"
"vbroadcastss 56(%[m_param]), %%ymm12\n"
" vmulps %%ymm12, %%ymm0, %%ymm0 \n" // 0x0
".if (m_Nr > 8)\n vmulps %%ymm12, %%ymm1, %%ymm1 \n .endif\n" // 0x1
".if (m_Mr > 1) \n vmulps %%ymm12, %%ymm2, %%ymm2 \n .endif\n" // 1x0
".if (m_Mr > 1) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm3, %%ymm3 \n .endif\n" // 1x1
".if (m_Mr > 2) \n vmulps %%ymm12, %%ymm4, %%ymm4 \n .endif\n" // 2x0
".if (m_Mr > 2) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm5, %%ymm5 \n .endif\n" // 2x1
".if (m_Mr > 3) \n vmulps %%ymm12, %%ymm6, %%ymm6 \n .endif\n" // 3x0
".if (m_Mr > 3) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm7, %%ymm7 \n .endif\n" // 3x1
".if (m_Mr > 4) \n vmulps %%ymm12, %%ymm8, %%ymm8 \n .endif\n" // 4x0
".if (m_Mr > 4) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm9, %%ymm9 \n .endif\n" // 4x1
".if (m_Mr > 5) \n vmulps %%ymm12, %%ymm10, %%ymm10\n .endif\n" // 5x0
".if (m_Mr > 5) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm11, %%ymm11\n .endif\n" // 5x1
"L_GemmAvx2_MxN_6x16_Update_C%=:\n"
"movq 16(%[m_param]), %%rax\n" // p_c
"movq 48(%[m_param]), %%rdi\n" // ldc
".if (m_Mr > 1)\n lea (%%rax, %%rdi, 1), %%rbx\n .endif\n"
".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"
".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8 \n .endif\n"
".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9 \n .endif\n"
" vaddps (%%rax), %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm2, %%ymm2 \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm3, %%ymm3 \n .endif\n"
".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm4, %%ymm4 \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm5, %%ymm5 \n .endif\n"
".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm6, %%ymm6 \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm7, %%ymm7 \n .endif\n"
".if (m_Mr > 4) \n vaddps (%%r8), %%ymm8, %%ymm8 \n .endif\n"
".if (m_Mr > 4) && (m_Nr > 8)\n vaddps 32(%%r8), %%ymm9, %%ymm9 \n .endif\n"
".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
" vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
".if (m_Mr > 1) \n vmovups %%ymm2, (%%rbx) \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vmovups %%ymm3, 32(%%rbx)\n .endif\n"
".if (m_Mr > 2) \n vmovups %%ymm4, (%%rcx) \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vmovups %%ymm5, 32(%%rcx)\n .endif\n"
".if (m_Mr > 3) \n vmovups %%ymm6, (%%rdx) \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vmovups %%ymm7, 32(%%rdx)\n .endif\n"
".if (m_Mr > 4) \n vmovups %%ymm8, (%%r8) \n .endif\n"
".if (m_Mr > 4) && (m_Nr > 8)\n vmovups %%ymm9, 32(%%r8) \n .endif\n"
".if (m_Mr > 5) \n vmovups %%ymm10, (%%r9) \n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vmovups %%ymm11, 32(%%r9) \n .endif\n"
"L_GemmAvx2_MxN_6x16_Exit%=:\n"
:
:
[m_Mr] "i" (Mr),
[m_Nr] "i" (Nr),
[m_TransA] "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? 1 : 0),
[m_TransB] "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? 1 : 0),
[m_NTStore] "i" (NonTemporalStore),
[m_ABytes] "i" (sizeof(FloatA)),
[m_BBytes] "i" (sizeof(FloatB)),
[m_CBytes] "i" (sizeof(FloatC)),
[m_param] "r" (param)
:
"cc",
"rax","rbx","rcx","rdx","rsi","rdi","r8","r9",
"ymm0","ymm1","ymm2","ymm3","ymm4","ymm5","ymm6",
"ymm7","ymm8","ymm9","ymm10","ymm11","ymm12","ymm13",
"ymm14","ymm15"
);
// clang-format on
}
};
} // namespace cpu
} // namespace ck
#endif
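The movaps->movups change in this commit relaxes the alignment requirement: vmovaps faults (#GP) when its memory operand is not 32-byte aligned, while vmovups accepts any address and costs the same on aligned data on current AVX2 cores. The same distinction at the intrinsics level, as a small illustrative sketch:

#include <immintrin.h>
// _mm256_loadu_ps lowers to vmovups and is legal at any address;
// _mm256_load_ps lowers to vmovaps and would fault if p were not
// 32-byte aligned -- the reason this commit switches the B/C accesses.
static inline __m256 load8(const float* p) { return _mm256_loadu_ps(p); }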
#ifndef CK_THREADWISE_PARAM_HPP
#define CK_THREADWISE_PARAM_HPP
#include "common_header.hpp"
#include "math.hpp"
namespace ck {
namespace cpu {
struct ThreadwiseGemmParam
{
const void* p_a;
const void* p_b;
void* p_c;
uint64_t Kr;
uint64_t lda; // in unit of byte
uint64_t ldb; // in unit of byte
uint64_t ldc; // in unit of byte
float alpha;
uint32_t _pack0;
} __attribute__((packed));
} // namespace cpu
} // namespace ck
#endif
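The ukernel asm hard-codes byte offsets into this struct (p_a at 0, p_b at 8, p_c at 16, Kr at 24, lda at 32, ldb at 40, ldc at 48, alpha at 56), which is why it is declared packed. A sketch of compile-time checks that would pin those offsets down (an assumption on my part, not present in the commit):

#include <cstddef>
// Hypothetical sanity checks: keep the struct layout in sync with the asm.
static_assert(offsetof(ck::cpu::ThreadwiseGemmParam, p_c) == 16, "asm reads p_c at 16(%[m_param])");
static_assert(offsetof(ck::cpu::ThreadwiseGemmParam, Kr) == 24, "asm reads Kr at 24(%[m_param])");
static_assert(offsetof(ck::cpu::ThreadwiseGemmParam, alpha) == 56, "asm reads alpha at 56(%[m_param])");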
#ifndef CK_CPUID_HPP
#define CK_CPUID_HPP
namespace ck {
namespace cpu {
enum cpuid_vendor
{
cpuid_vendor_intel = 0,
cpuid_vendor_amd = 1,
cpuid_vendor_other = 2,
};
enum cpuid_cache_type
{
cpuid_cache_type_null = 0,
cpuid_cache_type_dcache = 1,
cpuid_cache_type_icache = 2,
cpuid_cache_type_unified = 3,
};
struct cpuid_raw
{
uint32_t eax{0};
uint32_t ebx{0};
uint32_t ecx{0};
uint32_t edx{0};
};
struct cpuid_cache_detail
{
uint32_t size{0};
uint32_t type{0};
uint32_t cache_line_size{0};
uint32_t associativity{0};
uint32_t sets{0};
uint32_t partitions{0};
uint32_t shared_by_procs{0}; // with hyper-threading, usually 2 logical procs share a core,
// hence for L1/L2 this is typically 2, unless HT is turned off
uint32_t cores_per_socket{0}; // hardware cores in a physical socket. there may be multiple
// sockets in the system. TODO: may not be needed?
uint32_t flags{0};
};
struct cpuid_cache_hierarchy
{
cpuid_cache_detail l1i;
cpuid_cache_detail l1d;
cpuid_cache_detail l2;
cpuid_cache_detail l3;
cpuid_cache_detail l4;
};
static inline cpuid_raw cpuid(uint32_t eax, uint32_t ecx)
{
// some leaves require a specific ecx value;
// for the others, ecx is ignored.
uint32_t ebx, edx;
asm __volatile__("mov %0, %%eax\n"
"mov %2, %%ecx\n"
"cpuid\n"
"mov %%eax, %0\n"
"mov %%ebx, %1\n"
"mov %%ecx, %2\n"
"mov %%edx, %3\n"
: "=r"(eax), "=r"(ebx), "=r"(ecx), "=r"(edx)
: "0"(eax), "2"(ecx));
return {eax, ebx, ecx, edx};
}
static inline cpuid_vendor cpuid_query_vendor()
{
cpuid_raw r = cpuid(0, 0);
if(r.ebx == 0x756E6547U /*Genu*/ && r.edx == 0x49656E69U /*ineI*/ &&
r.ecx == 0x6C65746EU /*ntel*/)
{
return cpuid_vendor_intel;
}
if(r.ebx == 0x68747541U /*Auth*/ && r.edx == 0x69746E65U /*enti*/ &&
r.ecx == 0x444D4163U /*cAMD*/)
{
return cpuid_vendor_amd;
}
if(r.ebx == 0x69444D41U /*AMDi*/ && r.edx == 0x74656273U /*sbet*/ &&
r.ecx == 0x21726574U /*ter!*/)
{
return cpuid_vendor_amd;
}
return cpuid_vendor_other;
}
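The magic constants above are just the vendor string split into 4-byte chunks stored little-endian in ebx/edx/ecx. A tiny sketch of the packing (fourcc is a hypothetical helper, not in the commit):

constexpr uint32_t fourcc(char a, char b, char c, char d)
{
    // pack 4 ASCII bytes little-endian, the way cpuid returns them
    return uint32_t(uint8_t(a)) | (uint32_t(uint8_t(b)) << 8) |
           (uint32_t(uint8_t(c)) << 16) | (uint32_t(uint8_t(d)) << 24);
}
static_assert(fourcc('G', 'e', 'n', 'u') == 0x756E6547U, "ebx chunk of \"GenuineIntel\"");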
static inline cpuid_cache_hierarchy cpuid_query_cache()
{
cpuid_cache_hierarchy cache_hierarchy;
cpuid_vendor vendor = cpuid_query_vendor();
uint32_t leaf_cache_id = vendor == cpuid_vendor_amd ? 0x8000001d : 0x4;
for(uint32_t ecx_idx = 0;; ecx_idx++)
{
cpuid_raw r = cpuid(leaf_cache_id, ecx_idx);
uint32_t cache_type = r.eax & 0x1f;
if(cache_type == cpuid_cache_type_null)
break; // Null, no more cache
uint32_t cache_level = (r.eax >> 5) & 0x7;
uint32_t cache_shared_by_cores = 1 + ((r.eax >> 14) & 0xfff);
uint32_t cache_lpp_cores = 1 + ((r.eax >> 26) & 0x3f);
uint32_t cache_line_size = 1 + (r.ebx & 0xfff);
uint32_t cache_partitions = 1 + ((r.ebx >> 12) & 0x3ff);
uint32_t cache_associativity = 1 + (r.ebx >> 22);
uint32_t cache_sets = 1 + r.ecx;
switch(cache_level)
{
case 1:
if(cache_type == cpuid_cache_type_dcache || cache_type == cpuid_cache_type_unified)
{
cache_hierarchy.l1d.size =
cache_partitions * cache_sets * cache_associativity * cache_line_size;
cache_hierarchy.l1d.type = cache_type;
cache_hierarchy.l1d.cache_line_size = cache_line_size;
cache_hierarchy.l1d.associativity = cache_associativity;
cache_hierarchy.l1d.sets = cache_sets;
cache_hierarchy.l1d.partitions = cache_partitions;
cache_hierarchy.l1d.shared_by_procs = cache_shared_by_cores;
cache_hierarchy.l1d.cores_per_socket = cache_lpp_cores;
}
else if(cache_type == cpuid_cache_type_icache)
{
cache_hierarchy.l1i.size =
cache_partitions * cache_sets * cache_associativity * cache_line_size;
cache_hierarchy.l1i.type = cache_type;
cache_hierarchy.l1i.cache_line_size = cache_line_size;
cache_hierarchy.l1i.associativity = cache_associativity;
cache_hierarchy.l1i.sets = cache_sets;
cache_hierarchy.l1i.partitions = cache_partitions;
cache_hierarchy.l1i.shared_by_procs = cache_shared_by_cores;
cache_hierarchy.l1i.cores_per_socket = cache_lpp_cores;
}
break;
case 2:
if(cache_type == cpuid_cache_type_dcache || cache_type == cpuid_cache_type_unified)
{
cache_hierarchy.l2.size =
cache_partitions * cache_sets * cache_associativity * cache_line_size;
cache_hierarchy.l2.type = cache_type;
cache_hierarchy.l2.cache_line_size = cache_line_size;
cache_hierarchy.l2.associativity = cache_associativity;
cache_hierarchy.l2.sets = cache_sets;
cache_hierarchy.l2.partitions = cache_partitions;
cache_hierarchy.l2.shared_by_procs = cache_shared_by_cores;
cache_hierarchy.l2.cores_per_socket = cache_lpp_cores;
}
break;
case 3:
if(cache_type == cpuid_cache_type_dcache || cache_type == cpuid_cache_type_unified)
{
cache_hierarchy.l3.size =
cache_partitions * cache_sets * cache_associativity * cache_line_size;
cache_hierarchy.l3.type = cache_type;
cache_hierarchy.l3.cache_line_size = cache_line_size;
cache_hierarchy.l3.associativity = cache_associativity;
cache_hierarchy.l3.sets = cache_sets;
cache_hierarchy.l3.partitions = cache_partitions;
cache_hierarchy.l3.shared_by_procs = cache_shared_by_cores;
cache_hierarchy.l3.cores_per_socket = cache_lpp_cores;
}
break;
case 4:
if(cache_type == cpuid_cache_type_dcache || cache_type == cpuid_cache_type_unified)
{
cache_hierarchy.l4.size =
cache_partitions * cache_sets * cache_associativity * cache_line_size;
cache_hierarchy.l4.type = cache_type;
cache_hierarchy.l4.cache_line_size = cache_line_size;
cache_hierarchy.l4.associativity = cache_associativity;
cache_hierarchy.l4.sets = cache_sets;
cache_hierarchy.l4.partitions = cache_partitions;
cache_hierarchy.l4.shared_by_procs = cache_shared_by_cores;
cache_hierarchy.l4.cores_per_socket = cache_lpp_cores;
}
break;
}
}
return cache_hierarchy;
}
} // namespace cpu
} // namespace ck
#endif
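The "loop over L1" part of this commit relies on the cache sizes reported here to keep a ukernel working set (an Mr x k panel of A, a k x Nr panel of B, plus the Mr x Nr tile of C) resident in the L1 data cache. A hedged sketch of how a k blocking factor could be derived from cpuid_query_cache() (pick_kc_for_l1 is a hypothetical helper, not in this commit):

// Hypothetical: choose a k-block so the fp32 A (Mr*kc) and B (kc*Nr)
// panels fit in roughly half of the L1d cache reported by cpuid.
inline uint32_t pick_kc_for_l1(uint32_t Mr, uint32_t Nr)
{
    auto h = ck::cpu::cpuid_query_cache();
    uint32_t l1_bytes = h.l1d.size != 0 ? h.l1d.size : 32 * 1024; // fall back to 32 KiB
    return (l1_bytes / 2) / ((Mr + Nr) * sizeof(float));
}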
#include <iostream>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <string>
#include <sstream>
#include <tuple>
#include <memory>
#include <chrono>
#include "config.hpp"
#include "print.hpp"
#include "cpuid.hpp"
#include "threadwise_gemm_avx2.hpp"
#define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 5, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 4, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 3, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 2, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 1, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 5, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 4, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 3, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 2, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 1, 8, TA, TB, NT>
// #define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
// ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
// clang-format off
// FloatA FloatB FloatC ALayout BLayout NTStore
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Row, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col, Row, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col, Col, false)
// ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false)
// clang-format on
>;
void dump_cache_hierarchy()
{
auto dump_cache_type = [&](const ck::cpu::cpuid_cache_type& type) {
if(type == ck::cpu::cpuid_cache_type_dcache)
printf("data cache");
else if(type == ck::cpu::cpuid_cache_type_icache)
printf("inst cache");
else if(type == ck::cpu::cpuid_cache_type_unified)
printf("unif cache");
};
auto dump_cache_detail = [&](const ck::cpu::cpuid_cache_detail& detail) {
dump_cache_type(static_cast<const ck::cpu::cpuid_cache_type>(detail.type));
printf(" size:%u, cache_line:%u, associativity:%u, sets:%u, partitions:%u, shared by "
"procs:%u(%u)\n",
detail.size,
detail.cache_line_size,
detail.associativity,
detail.sets,
detail.partitions,
detail.shared_by_procs,
detail.cores_per_socket);
};
ck::cpu::cpuid_cache_hierarchy cache = ck::cpu::cpuid_query_cache();
if(cache.l1d.size != 0)
{
printf("l1 ");
dump_cache_detail(cache.l1d);
}
if(cache.l1i.size != 0)
{
printf("l1 ");
dump_cache_detail(cache.l1i);
}
if(cache.l2.size != 0)
{
printf("l2 ");
dump_cache_detail(cache.l2);
}
if(cache.l3.size != 0)
{
printf("l3 ");
dump_cache_detail(cache.l3);
}
if(cache.l4.size != 0)
{
printf("l4 ");
dump_cache_detail(cache.l4);
}
}
void* __aligned_malloc(size_t required_bytes, size_t alignment)
{
if(alignment == 0 || (alignment & (alignment - 1))) // check pow of 2
return nullptr;
void* p1; // original block
void** p2; // aligned block
int offset = alignment - 1 + sizeof(void*);
if((p1 = malloc(required_bytes + offset)) == nullptr)
{
return nullptr;
}
p2 = reinterpret_cast<void**>((reinterpret_cast<size_t>(p1) + offset) & ~(alignment - 1));
p2[-1] = p1;
return p2;
}
void __aligned_free(void* p) { free((reinterpret_cast<void**>(p))[-1]); }
template <typename T>
void rand_vector(T* v, int elem)
{
int i;
static int flag = 0;
if(!flag)
{
srand(time(nullptr));
flag = 1;
}
for(i = 0; i < elem; i++)
{
v[i] = (static_cast<T>(rand() % 100)) / 100.0f;
}
}
bool valid_vector(const float* ref, const float* rhs, uint32_t elem)
{
float rtol = 1e-5;
float atol = 1e-8;
uint32_t err = 0;
for(uint32_t i = 0; i < elem; i++)
{
float diff = std::abs(ref[i] - rhs[i]);
if(diff > atol + rtol * std::abs(ref[i]))
{
printf("diff at %u, ref:%f, rhs:%f\n", i, ref[i], rhs[i]);
err++;
}
}
return err == 0;
}
template <typename data_type, typename ALayout, typename BLayout>
void ref_cpu_gemm_uk(const data_type* a,
const data_type* b,
float* c,
float alpha,
uint32_t m,
uint32_t n,
uint32_t k)
{
auto a_offset = [&](uint32_t im, uint32_t ik) {
if constexpr(std::is_same<Row, ALayout>::value)
{
return im * k + ik;
}
else
{
return ik * m + im;
}
};
auto b_offset = [&](uint32_t ik, uint32_t in) {
if constexpr(std::is_same<Row, BLayout>::value)
{
return ik * n + in;
}
else
{
// n*k*n8
return (in / 8) * k * 8 + ik * 8 + in % 8;
}
};
auto c_offset = [&](uint32_t im, uint32_t in) { return im * n + in; };
for(uint32_t im = 0; im < m; im++)
{
for(uint32_t in = 0; in < n; in++)
{
float acc = .0f;
for(uint32_t ik = 0; ik < k; ik++)
{
acc += a[a_offset(im, ik)] * b[b_offset(ik, in)];
}
acc *= alpha;
c[c_offset(im, in)] = acc;
}
}
}
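As a quick check of the ColumnMajor b_offset above: with k = 2, element (ik = 1, in = 9) lands at (9/8)*2*8 + 1*8 + 9%8 = 16 + 8 + 1 = 25, i.e. the second group of 8 n values, second k row, slot 1.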
template <typename data_type, typename ALayout, typename BLayout, typename ukernel_t>
void test_ukernel(ukernel_t uk,
data_type* mat_a,
data_type* mat_b,
float* mat_c,
float alpha,
uint32_t m,
uint32_t n,
uint32_t k)
{
ck::cpu::ThreadwiseGemmParam param;
param.p_a = mat_a;
param.p_b = mat_b;
param.p_c = mat_c;
param.Kr = k;
param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(data_type);
param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(data_type);
param.ldc = n * sizeof(float);
param.alpha = alpha;
printf("gemm_uk_%dx%d_%c%c: ", uk.Mr_, uk.Nr_, ALayout::name[0], BLayout::name[0]);
fflush(stdout);
// printf("%s: ", typeid(uk).name());fflush(stdout);
memset(mat_c, 0, m * n * sizeof(float));
int repeat = 7e10 / (2 * m * n * k);
for(int i = 0; i < (repeat / 5); i++)
{
uk.Run(&param);
}
auto t0 = std::chrono::high_resolution_clock::now();
for(int i = 0; i < repeat; i++)
{
uk.Run(&param);
}
auto t1 = std::chrono::high_resolution_clock::now();
double us = static_cast<double>(
std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count()) /
repeat;
double gflops = static_cast<double>(2 * m * n * k) * 1e-3 / us;
memset(mat_c, 0, m * n * sizeof(float));
uk.Run(&param);
printf("m:%u, n:%u, k:%u, alpha:%f, cost:%lfus, GFLOPS:%lf, ", m, n, k, alpha, us, gflops);
fflush(stdout);
}
// test small ukernels whose working set fits in L1
template <typename data_type, typename ALayout, typename BLayout>
void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
{
data_type* mat_a =
reinterpret_cast<data_type*>(__aligned_malloc(m * k * sizeof(data_type), 32));
data_type* mat_b =
reinterpret_cast<data_type*>(__aligned_malloc(k * n * sizeof(data_type), 32));
float* mat_c = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32));
float* mat_c_ref = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32));
memset(mat_c_ref, 0, m * n * sizeof(float));
rand_vector(mat_a, m * k);
rand_vector(mat_b, k * n);
ref_cpu_gemm_uk<data_type, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k);
ck::static_for<0, std::tuple_size_v<thread_gemm_avx2_mxn_6x16_instances>, 1>{}([&](auto i) {
using uk_type = std::tuple_element_t<i, thread_gemm_avx2_mxn_6x16_instances>;
if constexpr(!std::is_same<typename uk_type::ALayout_, ALayout>::value ||
!std::is_same<typename uk_type::BLayout_, BLayout>::value)
{
return;
}
if(uk_type::Mr_ != m || uk_type::Nr_ != n)
return;
test_ukernel<data_type, ALayout, BLayout>(uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k);
bool is_valid = valid_vector(mat_c_ref, mat_c, m * n);
printf("vald:%s\n", is_valid ? "y" : "n");
// return ;
});
__aligned_free(mat_a);
__aligned_free(mat_b);
__aligned_free(mat_c);
__aligned_free(mat_c_ref);
}
int main(int argc, char** argv)
{
int m = 6;
int n = 16;
int k = 64;
float alpha = 1.0f;
if(argc > 3)
{
m = std::atoi(argv[1]);
n = std::atoi(argv[2]);
k = std::atoi(argv[3]);
}
if(argc > 4)
{
alpha = std::atof(argv[4]);
}
dump_cache_hierarchy();
test_cpu_ukernel<float, Row, Row>(alpha, m, n, k);
test_cpu_ukernel<float, Row, Col>(alpha, m, n, k);
test_cpu_ukernel<float, Col, Row>(alpha, m, n, k);
test_cpu_ukernel<float, Col, Col>(alpha, m, n, k);
}
#include <iostream>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <string>
#include <sstream>
#include <tuple>
#include <memory>
#include <chrono>
#include "config.hpp"
#include "print.hpp"
#include "cpuid.hpp"
#include "threadwise_gemm_avx2.hpp"
#define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 5, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 4, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 3, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 2, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 1, 16, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 5, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 4, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 3, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 2, 8, TA, TB, NT>, \
ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 1, 8, TA, TB, NT>
// #define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
// ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
// clang-format off
// FloatA FloatB FloatC ALayout BLayout NTStore
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Row, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col, Row, false),
ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col, Col, false)
// ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false)
// clang-format on
>;
void dump_cache_hierarchy()
{
auto dump_cache_type = [&](const ck::cpu::cpuid_cache_type& type) {
if(type == ck::cpu::cpuid_cache_type_dcache)
printf("data cache");
else if(type == ck::cpu::cpuid_cache_type_icache)
printf("inst cache");
else if(type == ck::cpu::cpuid_cache_type_unified)
printf("unif cache");
};
auto dump_cache_detail = [&](const ck::cpu::cpuid_cache_detail& detail) {
dump_cache_type(static_cast<const ck::cpu::cpuid_cache_type>(detail.type));
printf(" size:%u, cache_line:%u, associativity:%u, sets:%u, partitions:%u, shared by "
"procs:%u(%u)\n",
detail.size,
detail.cache_line_size,
detail.associativity,
detail.sets,
detail.partitions,
detail.shared_by_procs,
detail.cores_per_socket);
};
ck::cpu::cpuid_cache_hierarchy cache = ck::cpu::cpuid_query_cache();
if(cache.l1d.size != 0)
{
printf("l1 ");
dump_cache_detail(cache.l1d);
}
if(cache.l1i.size != 0)
{
printf("l1 ");
dump_cache_detail(cache.l1i);
}
if(cache.l2.size != 0)
{
printf("l2 ");
dump_cache_detail(cache.l2);
}
if(cache.l3.size != 0)
{
printf("l3 ");
dump_cache_detail(cache.l3);
}
if(cache.l4.size != 0)
{
printf("l4 ");
dump_cache_detail(cache.l4);
}
}
void* __aligned_malloc(size_t required_bytes, size_t alignment)
{
if(alignment == 0 || (alignment & (alignment - 1))) // check pow of 2
return nullptr;
void* p1; // original block
void** p2; // aligned block
int offset = alignment - 1 + sizeof(void*);
if((p1 = malloc(required_bytes + offset)) == nullptr)
{
return nullptr;
}
p2 = reinterpret_cast<void**>((reinterpret_cast<size_t>(p1) + offset) & ~(alignment - 1));
p2[-1] = p1;
return p2;
}
void __aligned_free(void* p) { free((reinterpret_cast<void**>(p))[-1]); }
template <typename T>
void rand_vector(T* v, int elem)
{
int i;
static int flag = 0;
if(!flag)
{
srand(time(nullptr));
flag = 1;
}
for(i = 0; i < elem; i++)
{
v[i] = (static_cast<T>(rand() % 100)) / 100.0f;
}
}
bool valid_vector(const float* ref, const float* rhs, uint32_t elem)
{
float rtol = 1e-5;
float atol = 1e-8;
uint32_t err = 0;
for(uint32_t i = 0; i < elem; i++)
{
float diff = std::abs(ref[i] - rhs[i]);
if(diff > atol + rtol * std::abs(ref[i]))
{
printf("diff at %u, ref:%f, rhs:%f\n", i, ref[i], rhs[i]);
err++;
}
}
return err == 0;
}
template <typename data_type, typename ALayout, typename BLayout>
void ref_cpu_gemm_uk(const data_type* a,
const data_type* b,
float* c,
float alpha,
uint32_t m,
uint32_t n,
uint32_t k)
{
auto a_offset = [&](uint32_t im, uint32_t ik) {
if constexpr(std::is_same<Row, ALayout>::value)
{
return im * k + ik;
}
else
{
return ik * m + im;
}
};
auto b_offset = [&](uint32_t ik, uint32_t in) {
if constexpr(std::is_same<Row, BLayout>::value)
{
return ik * n + in;
}
else
{
// n*k*n8
return (in / 8) * k * 8 + ik * 8 + in % 8;
}
};
auto c_offset = [&](uint32_t im, uint32_t in) { return im * n + in; };
for(uint32_t im = 0; im < m; im++)
{
for(uint32_t in = 0; in < n; in++)
{
float acc = .0f;
for(uint32_t ik = 0; ik < k; ik++)
{
acc += a[a_offset(im, ik)] * b[b_offset(ik, in)];
}
acc *= alpha;
c[c_offset(im, in)] = acc;
}
}
}
template <typename data_type, typename ALayout, typename BLayout, typename ukernel_t>
void test_ukernel(ukernel_t uk,
data_type* mat_a,
data_type* mat_b,
float* mat_c,
float alpha,
uint32_t m,
uint32_t n,
uint32_t k)
{
ck::cpu::ThreadwiseGemmParam param;
param.p_a = mat_a;
param.p_b = mat_b;
param.p_c = mat_c;
param.Kr = k;
param.lda = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(data_type);
param.ldb = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(data_type);
param.ldc = n * sizeof(float);
param.alpha = alpha;
auto invoke_uk = [&]() {
if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Row, BLayout>::value)
{
assert(m % uk.Mr_ == 0 && n == uk.Nr_);
data_type* p_a = mat_a;
float* p_c = mat_c;
param.p_a = p_a;
param.p_c = p_c;
for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_)
{
uk.Run(&param);
p_a += uk.Mr_ * k;
p_c += uk.Mr_ * n;
param.p_a = p_a;
param.p_c = p_c;
}
}
else if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Col, BLayout>::value)
{
assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
data_type* p_a = mat_a;
// data_type* p_b = mat_b;
float* p_c = mat_c;
param.p_a = p_a;
param.p_b = mat_b;
param.p_c = p_c;
for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_)
{
float* p_c_n = p_c;
float* p_b_n = mat_b;
for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_)
{
uk.Run(&param);
p_b_n += uk.Nr_ * k; // Nr_/8*k*8
p_c_n += uk.Nr_;
param.p_b = p_b_n;
param.p_c = p_c_n;
}
p_a += uk.Mr_ * k;
p_c += uk.Mr_ * n;
param.p_a = p_a;
param.p_b = mat_b;
param.p_c = p_c;
}
}
else if constexpr(std::is_same<Col, ALayout>::value && std::is_same<Row, BLayout>::value)
{
assert(m == uk.Mr_ && n == uk.Nr_);
uk.Run(&param);
}
else
{
assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
data_type* p_b = mat_b;
float* p_c = mat_c;
param.p_b = p_b;
param.p_c = p_c;
for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_)
{
uk.Run(&param);
p_b += uk.Nr_ * k; // Nr_/8*k*8
p_c += uk.Nr_;
param.p_b = p_b;
param.p_c = p_c;
}
}
};
printf("gemm_uk_%dx%d_%c%c: ", uk.Mr_, uk.Nr_, ALayout::name[0], BLayout::name[0]);
fflush(stdout);
// printf("%s: ", typeid(uk).name());fflush(stdout);
memset(mat_c, 0, m * n * sizeof(float));
int repeat = 7e10 / (2 * m * n * k);
for(int i = 0; i < (repeat / 5); i++)
{
invoke_uk();
}
auto t0 = std::chrono::high_resolution_clock::now();
for(int i = 0; i < repeat; i++)
{
invoke_uk();
}
auto t1 = std::chrono::high_resolution_clock::now();
double us = static_cast<double>(
std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count()) /
repeat;
double gflops = static_cast<double>(2 * m * n * k) * 1e-3 / us;
memset(mat_c, 0, m * n * sizeof(float));
invoke_uk();
printf("m:%u, n:%u, k:%u, alpha:%f, cost:%lfus, GFLOPS:%lf, ", m, n, k, alpha, us, gflops);
fflush(stdout);
}
// test small ukernels whose working set fits in L1
template <typename data_type, typename ALayout, typename BLayout>
void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
{
data_type* mat_a =
reinterpret_cast<data_type*>(__aligned_malloc(m * k * sizeof(data_type), 32));
data_type* mat_b =
reinterpret_cast<data_type*>(__aligned_malloc(k * n * sizeof(data_type), 32));
float* mat_c = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32));
float* mat_c_ref = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32));
memset(mat_c_ref, 0, m * n * sizeof(float));
rand_vector(mat_a, m * k);
rand_vector(mat_b, k * n);
ref_cpu_gemm_uk<data_type, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k);
ck::static_for<0, std::tuple_size_v<thread_gemm_avx2_mxn_6x16_instances>, 1>{}([&](auto i) {
using uk_type = std::tuple_element_t<i, thread_gemm_avx2_mxn_6x16_instances>;
if constexpr(!std::is_same<typename uk_type::ALayout_, ALayout>::value ||
!std::is_same<typename uk_type::BLayout_, BLayout>::value)
{
return;
}
if(m % uk_type::Mr_ != 0 || n % uk_type::Nr_ != 0)
return;
if((m != uk_type::Mr_ && std::is_same<typename uk_type::ALayout_, Col>::value) ||
(n != uk_type::Nr_ && std::is_same<typename uk_type::BLayout_, Row>::value))
// only when k is the fastest-changing dim of A/B can we tile over multiple m/n blocks
return;
test_ukernel<data_type, ALayout, BLayout>(uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k);
bool is_valid = valid_vector(mat_c_ref, mat_c, m * n);
printf("vald:%s\n", is_valid ? "y" : "n");
// return ;
});
__aligned_free(mat_a);
__aligned_free(mat_b);
__aligned_free(mat_c);
__aligned_free(mat_c_ref);
}
int main(int argc, char** argv)
{
int m = 6;
int n = 16;
int k = 64;
float alpha = 1.0f;
if(argc > 3)
{
m = std::atoi(argv[1]);
n = std::atoi(argv[2]);
k = std::atoi(argv[3]);
}
if(argc > 4)
{
alpha = std::atof(argv[4]);
}
dump_cache_hierarchy();
test_cpu_ukernel<float, Row, Row>(alpha, m, n, k);
test_cpu_ukernel<float, Row, Col>(alpha, m, n, k);
test_cpu_ukernel<float, Col, Row>(alpha, m, n, k);
test_cpu_ukernel<float, Col, Col>(alpha, m, n, k);
}
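For the default m = 6, n = 16, k = 64 this benchmark does 2*m*n*k = 12,288 flops per invocation, so repeat = 7e10 / 12,288 is roughly 5.7e6 timed calls (plus a one-fifth warm-up); a kernel that, say, averaged 0.2 us per call would report 12,288 * 1e-3 / 0.2, about 61 GFLOPS.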