"examples/community/run_tensorrt_controlnet.py" did not exist on "7447f75b9f8badb073636ed163417b0947c59e9f"
Commit 35f95fe9 authored by carlushuang's avatar carlushuang
Browse files

movaps->movups, and support loop over L1

parent e72c0c43
#ifndef CK_THREADWISE_GEMM_AVX2_HPP
#define CK_THREADWISE_GEMM_AVX2_HPP

#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "math.hpp"
#include "threadwise_param.hpp"

namespace ck {
namespace cpu {

template <typename FloatA,
          typename FloatB,
          typename FloatC,
          index_t Mr,
          index_t Nr,
          typename ALayout, // default is k*m, trans->m*k
          typename BLayout, // default is n/8*k*n8, trans->k*n
          bool NonTemporalStore>
struct ThreadwiseGemmAvx2_MxN_6x16
{
    using ALayout_ = ALayout;
    using BLayout_ = BLayout;

    static constexpr auto Mr_               = Mr;
    static constexpr auto Nr_               = Nr;
    static constexpr auto NonTemporalStore_ = NonTemporalStore;

    __host__ constexpr ThreadwiseGemmAvx2_MxN_6x16()
    {
        static_assert(Mr <= 6 && Mr >= 1 && (Nr == 8 || Nr == 16), "wrong! Mr x Nr not valid");
    }

    __host__ static void Run(ThreadwiseGemmParam* param)
    {
        /* 6x16 ukernel
         *
         *         Mat_B
         *        |ymm12  |ymm13  |
         * Mat_A  +-------+-------+
         * ymm14  |ymm0   |ymm1   |  cycle 0
         * ymm15  |ymm2   |ymm3   |  cycle 1
         * ymm14  |ymm4   |ymm5   |  cycle 2
         * ymm15  |ymm6   |ymm7   |  cycle 3
         * ymm14  |ymm8   |ymm9   |  cycle 4
         * ymm15  |ymm10  |ymm11  |  Mat_C, cycle 5
         *
         * ALayout:ColumnMajor (k*m), lda not needed
         * ALayout:RowMajor    (m*k), lda = k
         * BLayout:ColumnMajor (n/8*k*n8), ldb = k*n8; at least 8 contiguous n
         *                     elements are needed to fill one ymm register
         * BLayout:RowMajor    (k*n), ldb not needed
         *
         * lda/ldb/ldc are all in units of bytes
         */
        // clang-format off
        __asm__ __volatile__ (
        "L_GemmAvx2_MxN_6x16_Entry%=:\n"
        ".set m_Mr,      %c[m_Mr]\n"
        ".set m_Nr,      %c[m_Nr]\n"
        ".set m_TransA,  %c[m_TransA]\n"
        ".set m_TransB,  %c[m_TransB]\n"
        ".set m_NTStore, %c[m_NTStore]\n"
        ".set m_ABytes,  %c[m_ABytes]\n"
        ".set m_BBytes,  %c[m_BBytes]\n"
        ".set m_CBytes,  %c[m_CBytes]\n"

        "movq (%[m_param]), %%rax\n"   // p_a
        "movq 8(%[m_param]), %%rbx\n"  // p_b
        "movq 24(%[m_param]), %%rsi\n" // Kr
        ".if m_TransA != 0\n"
        "movq 32(%[m_param]), %%rcx\n" // lda
        ".endif\n"
        ".if m_TransB == 0\n"
        "movq 40(%[m_param]), %%rdx\n" // ldb
        ".endif\n"

        ".macro vbroadcastss_%= r_base, r_stride, i_scale, i_offset, ymm\n"
        ".if \\i_scale != 0\n"
        "vbroadcastss \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
        ".else\n"
        "vbroadcastss \\i_offset(\\r_base), \\ymm\n"
        ".endif\n"
        ".endm\n"

        ".macro vmovups_%= r_base, r_stride, i_scale, i_offset, ymm\n"
        ".if \\i_scale != 0\n"
        "vmovups \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
        ".else\n"
        "vmovups \\i_offset(\\r_base), \\ymm\n"
        ".endif\n"
        ".endm\n"

        ".macro vbroadcast_a%= i_k, i_m, ymm\n" // A in rax(r8, r9), lda in rcx
        ".if m_TransA == 0\n"
        "vbroadcastss_%= %%rax, 0, 0, (\\i_m + \\i_k * m_Mr) * 4, \\ymm\n"
        ".else\n"
        ".if (\\i_m == 0) || (\\i_m == 1) || (\\i_m == 2)\n"
        "vbroadcastss_%= %%rax, %%rcx, \\i_m, \\i_k * 4, \\ymm\n"
        ".else\n"
        "vbroadcastss_%= %%r8, %%rcx, \\i_m-3, \\i_k * 4, \\ymm\n"
        ".endif\n"
        ".endif\n"
        ".endm\n"

        ".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, ldb in rdx, i_n should be 0, 1
        ".if m_BBytes == 4\n"
        ".if m_TransB == 0\n"
        "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
        ".else\n"
        "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
        ".endif\n"
        ".else\n"
        ".if m_TransB == 0\n"
        "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
        ".else\n"
        "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
        ".endif\n"
        ".endif\n"
        ".endm\n"
" vxorps %%ymm0, %%ymm0, %%ymm0 \n" " vxorps %%ymm0, %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vxorps %%ymm1, %%ymm1, %%ymm1 \n .endif\n" ".if (m_Nr > 8)\n vxorps %%ymm1, %%ymm1, %%ymm1 \n .endif\n"
".if (m_Mr > 1) \n vxorps %%ymm2, %%ymm2, %%ymm2 \n .endif\n" ".if (m_Mr > 1) \n vxorps %%ymm2, %%ymm2, %%ymm2 \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vxorps %%ymm3, %%ymm3, %%ymm3 \n .endif\n" ".if (m_Mr > 1) && (m_Nr > 8)\n vxorps %%ymm3, %%ymm3, %%ymm3 \n .endif\n"
".if (m_Mr > 2) \n vxorps %%ymm4, %%ymm4, %%ymm4 \n .endif\n" ".if (m_Mr > 2) \n vxorps %%ymm4, %%ymm4, %%ymm4 \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vxorps %%ymm5, %%ymm5, %%ymm5 \n .endif\n" ".if (m_Mr > 2) && (m_Nr > 8)\n vxorps %%ymm5, %%ymm5, %%ymm5 \n .endif\n"
".if (m_Mr > 3) \n vxorps %%ymm6, %%ymm6, %%ymm6 \n .endif\n" ".if (m_Mr > 3) \n vxorps %%ymm6, %%ymm6, %%ymm6 \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vxorps %%ymm7, %%ymm7, %%ymm7 \n .endif\n" ".if (m_Mr > 3) && (m_Nr > 8)\n vxorps %%ymm7, %%ymm7, %%ymm7 \n .endif\n"
".if (m_Mr > 4) \n vxorps %%ymm8, %%ymm8, %%ymm8 \n .endif\n" ".if (m_Mr > 4) \n vxorps %%ymm8, %%ymm8, %%ymm8 \n .endif\n"
".if (m_Mr > 4) && (m_Nr > 8)\n vxorps %%ymm9, %%ymm9, %%ymm9 \n .endif\n" ".if (m_Mr > 4) && (m_Nr > 8)\n vxorps %%ymm9, %%ymm9, %%ymm9 \n .endif\n"
".if (m_Mr > 5) \n vxorps %%ymm10, %%ymm10, %%ymm10\n .endif\n" ".if (m_Mr > 5) \n vxorps %%ymm10, %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vxorps %%ymm11, %%ymm11, %%ymm11\n .endif\n" ".if (m_Mr > 5) && (m_Nr > 8)\n vxorps %%ymm11, %%ymm11, %%ymm11\n .endif\n"
".if m_TransA != 0\n" ".if m_TransA != 0\n"
".if m_Mr > 3\n" ".if m_Mr > 3\n"
"lea (%%rcx, %%rcx, 2), %%r9\n" "lea (%%rcx, %%rcx, 2), %%r9\n"
"lea (%%rax, %%r9), %%r8\n" "lea (%%rax, %%r9), %%r8\n"
".endif\n" ".endif\n"
".endif\n" ".endif\n"
"cmp $4, %%rsi\n" "cmp $4, %%rsi\n"
"jl L_GemmAvx2_MxN_6x16_K_Loop_Remain%=\n" "jl L_GemmAvx2_MxN_6x16_K_Loop_Remain%=\n"
"L_GemmAvx2_MxN_6x16_K_Loop_Start%=:\n" "L_GemmAvx2_MxN_6x16_K_Loop_Start%=:\n"
".irp i_k, 0, 1, 2, 3\n" ".irp i_k, 0, 1, 2, 3\n"
" vload_b%= \\i_k, 0, %%ymm12\n" // B " vload_b%= \\i_k, 0, %%ymm12\n" // B
".if (m_Nr > 8)\n vload_b%= \\i_k, 1, %%ymm13\n .endif\n" // B ".if (m_Nr > 8)\n vload_b%= \\i_k, 1, %%ymm13\n .endif\n" // B
" vbroadcast_a%= \\i_k, 0, %%ymm14\n" // A broadcast 0 " vbroadcast_a%= \\i_k, 0, %%ymm14\n" // A broadcast 0
".if (m_Mr > 1) \n vbroadcast_a%= \\i_k, 1, %%ymm15\n .endif\n" // A broadcast 1 ".if (m_Mr > 1) \n vbroadcast_a%= \\i_k, 1, %%ymm15\n .endif\n" // A broadcast 1
" vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" // 0x0 " vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" // 0x0
".if (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n .endif\n" // 0x1 ".if (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n .endif\n" // 0x1
".if (m_Mr > 1) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n .endif\n" // 1x0 ".if (m_Mr > 1) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n .endif\n" // 1x0
".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n .endif\n" // 1x1 ".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n .endif\n" // 1x1
".if (m_Mr > 2) \n vbroadcast_a%= \\i_k, 2, %%ymm14\n .endif\n" // A broadcast 2 ".if (m_Mr > 2) \n vbroadcast_a%= \\i_k, 2, %%ymm14\n .endif\n" // A broadcast 2
".if (m_Mr > 3) \n vbroadcast_a%= \\i_k, 3, %%ymm15\n .endif\n" // A broadcast 3 ".if (m_Mr > 3) \n vbroadcast_a%= \\i_k, 3, %%ymm15\n .endif\n" // A broadcast 3
".if (m_Mr > 2) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n .endif\n" // 2x0 ".if (m_Mr > 2) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n .endif\n" // 2x0
".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n .endif\n" // 2x1 ".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n .endif\n" // 2x1
".if (m_Mr > 3) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n" // 3x0 ".if (m_Mr > 3) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n" // 3x0
".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n" // 3x1 ".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n" // 3x1
".if (m_Mr > 4) \n vbroadcast_a%= \\i_k, 4, %%ymm14\n .endif\n" // A broadcast 4 ".if (m_Mr > 4) \n vbroadcast_a%= \\i_k, 4, %%ymm14\n .endif\n" // A broadcast 4
".if (m_Mr > 5) \n vbroadcast_a%= \\i_k, 5, %%ymm15\n .endif\n" // A broadcast 5 ".if (m_Mr > 5) \n vbroadcast_a%= \\i_k, 5, %%ymm15\n .endif\n" // A broadcast 5
".if (m_Mr > 4) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n .endif\n" // 4x0 ".if (m_Mr > 4) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n .endif\n" // 4x0
".if (m_Mr > 4) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n .endif\n" // 4x1 ".if (m_Mr > 4) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n .endif\n" // 4x1
".if (m_Mr > 5) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n .endif\n" // 5x0 ".if (m_Mr > 5) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n .endif\n" // 5x0
".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1 ".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1
".endr\n" ".endr\n"
".if m_TransA != 0\n" ".if m_TransA != 0\n"
" lea 4*4(%%rax), %%rax\n" " lea 4*4(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4*4(%%r8), %%r8\n .endif\n" ".if m_Mr > 3\n lea 4*4(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * 4 * 4(%%rax), %%rax\n" " lea m_Mr * 4 * 4(%%rax), %%rax\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * 4 * 4(%%rbx), %%rbx\n" " lea m_Nr * 4 * 4(%%rbx), %%rbx\n"
".else\n" ".else\n"
" lea 8 * 4 * 4(%%rbx), %%rbx\n" " lea 8 * 4 * 4(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
"sub $4, %%rsi\n" "sub $4, %%rsi\n"
"cmp $4, %%rsi\n" "cmp $4, %%rsi\n"
"jge L_GemmAvx2_MxN_6x16_K_Loop_Start%=\n" "jge L_GemmAvx2_MxN_6x16_K_Loop_Start%=\n"
"testq %%rsi, %%rsi\n" "testq %%rsi, %%rsi\n"
"je L_GemmAvx2_MxN_6x16_K_Loop_End%=\n" "je L_GemmAvx2_MxN_6x16_K_Loop_End%=\n"
"L_GemmAvx2_MxN_6x16_K_Loop_Remain%=:\n" "L_GemmAvx2_MxN_6x16_K_Loop_Remain%=:\n"
" vload_b%= 0, 0, %%ymm12\n" // B " vload_b%= 0, 0, %%ymm12\n" // B
".if (m_Nr > 8)\n vload_b%= 0, 1, %%ymm13\n .endif\n" // B ".if (m_Nr > 8)\n vload_b%= 0, 1, %%ymm13\n .endif\n" // B
" vbroadcast_a%= 0, 0, %%ymm14\n" // A broadcast 0 " vbroadcast_a%= 0, 0, %%ymm14\n" // A broadcast 0
".if (m_Mr > 1) \n vbroadcast_a%= 0, 1, %%ymm15\n .endif\n" // A broadcast 1 ".if (m_Mr > 1) \n vbroadcast_a%= 0, 1, %%ymm15\n .endif\n" // A broadcast 1
" vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" // 0x0 " vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" // 0x0
".if (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n .endif\n" // 0x1 ".if (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n .endif\n" // 0x1
".if (m_Mr > 1) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n .endif\n" // 1x0 ".if (m_Mr > 1) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n .endif\n" // 1x0
".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n .endif\n" // 1x1 ".if (m_Mr > 1) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n .endif\n" // 1x1
".if (m_Mr > 2) \n vbroadcast_a%= 0, 2, %%ymm14\n .endif\n" // A broadcast 2 ".if (m_Mr > 2) \n vbroadcast_a%= 0, 2, %%ymm14\n .endif\n" // A broadcast 2
".if (m_Mr > 3) \n vbroadcast_a%= 0, 3, %%ymm15\n .endif\n" // A broadcast 3 ".if (m_Mr > 3) \n vbroadcast_a%= 0, 3, %%ymm15\n .endif\n" // A broadcast 3
".if (m_Mr > 2) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n .endif\n" // 2x0 ".if (m_Mr > 2) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n .endif\n" // 2x0
".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n .endif\n" // 2x1 ".if (m_Mr > 2) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n .endif\n" // 2x1
".if (m_Mr > 3) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n" // 3x0 ".if (m_Mr > 3) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n .endif\n" // 3x0
".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n" // 3x1 ".if (m_Mr > 3) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n .endif\n" // 3x1
".if (m_Mr > 4) \n vbroadcast_a%= 0, 4, %%ymm14\n .endif\n" // A broadcast 4 ".if (m_Mr > 4) \n vbroadcast_a%= 0, 4, %%ymm14\n .endif\n" // A broadcast 4
".if (m_Mr > 5) \n vbroadcast_a%= 0, 5, %%ymm15\n .endif\n" // A broadcast 5 ".if (m_Mr > 5) \n vbroadcast_a%= 0, 5, %%ymm15\n .endif\n" // A broadcast 5
".if (m_Mr > 4) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n .endif\n" // 4x0 ".if (m_Mr > 4) \n vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n .endif\n" // 4x0
".if (m_Mr > 4) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n .endif\n" // 4x1 ".if (m_Mr > 4) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n .endif\n" // 4x1
".if (m_Mr > 5) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n .endif\n" // 5x0 ".if (m_Mr > 5) \n vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n .endif\n" // 5x0
".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1 ".if (m_Mr > 5) && (m_Nr > 8)\n vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n .endif\n" // 5x1
".if m_TransA != 0\n" ".if m_TransA != 0\n"
" lea 4(%%rax), %%rax\n" " lea 4(%%rax), %%rax\n"
".if m_Mr > 3\n lea 4(%%r8), %%r8\n .endif\n" ".if m_Mr > 3\n lea 4(%%r8), %%r8\n .endif\n"
".else\n" ".else\n"
" lea m_Mr * 4(%%rax), %%rax\n" " lea m_Mr * 4(%%rax), %%rax\n"
".endif\n" ".endif\n"
".if m_TransB != 0\n" ".if m_TransB != 0\n"
" lea m_Nr * 4(%%rbx), %%rbx\n" " lea m_Nr * 4(%%rbx), %%rbx\n"
".else\n" ".else\n"
" lea 8*4(%%rbx), %%rbx\n" " lea 8*4(%%rbx), %%rbx\n"
".endif\n" ".endif\n"
"sub $1, %%rsi\n" "sub $1, %%rsi\n"
"jne L_GemmAvx2_MxN_6x16_K_Loop_Remain%=\n" "jne L_GemmAvx2_MxN_6x16_K_Loop_Remain%=\n"
"L_GemmAvx2_MxN_6x16_K_Loop_End%=:\n" "L_GemmAvx2_MxN_6x16_K_Loop_End%=:\n"
"mov 56(%[m_param]), %%eax\n" // alpha "mov 56(%[m_param]), %%eax\n" // alpha
"cmp $0x3f800000, %%eax\n" "cmp $0x3f800000, %%eax\n"
"je L_GemmAvx2_MxN_6x16_Update_C%=\n" "je L_GemmAvx2_MxN_6x16_Update_C%=\n"
"vbroadcastss 56(%[m_param]), %%ymm12\n" "vbroadcastss 56(%[m_param]), %%ymm12\n"
" vmulps %%ymm12, %%ymm0, %%ymm0 \n" // 0x0 " vmulps %%ymm12, %%ymm0, %%ymm0 \n" // 0x0
".if (m_Nr > 8)\n vmulps %%ymm12, %%ymm1, %%ymm1 \n .endif\n" // 0x1 ".if (m_Nr > 8)\n vmulps %%ymm12, %%ymm1, %%ymm1 \n .endif\n" // 0x1
".if (m_Mr > 1) \n vmulps %%ymm12, %%ymm2, %%ymm2 \n .endif\n" // 1x0 ".if (m_Mr > 1) \n vmulps %%ymm12, %%ymm2, %%ymm2 \n .endif\n" // 1x0
".if (m_Mr > 1) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm3, %%ymm3 \n .endif\n" // 1x1 ".if (m_Mr > 1) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm3, %%ymm3 \n .endif\n" // 1x1
".if (m_Mr > 2) \n vmulps %%ymm12, %%ymm4, %%ymm4 \n .endif\n" // 2x0 ".if (m_Mr > 2) \n vmulps %%ymm12, %%ymm4, %%ymm4 \n .endif\n" // 2x0
".if (m_Mr > 2) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm5, %%ymm5 \n .endif\n" // 2x1 ".if (m_Mr > 2) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm5, %%ymm5 \n .endif\n" // 2x1
".if (m_Mr > 3) \n vmulps %%ymm12, %%ymm6, %%ymm6 \n .endif\n" // 3x0 ".if (m_Mr > 3) \n vmulps %%ymm12, %%ymm6, %%ymm6 \n .endif\n" // 3x0
".if (m_Mr > 3) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm7, %%ymm7 \n .endif\n" // 3x1 ".if (m_Mr > 3) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm7, %%ymm7 \n .endif\n" // 3x1
".if (m_Mr > 4) \n vmulps %%ymm12, %%ymm8, %%ymm8 \n .endif\n" // 4x0 ".if (m_Mr > 4) \n vmulps %%ymm12, %%ymm8, %%ymm8 \n .endif\n" // 4x0
".if (m_Mr > 4) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm9, %%ymm9 \n .endif\n" // 4x1 ".if (m_Mr > 4) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm9, %%ymm9 \n .endif\n" // 4x1
".if (m_Mr > 5) \n vmulps %%ymm12, %%ymm10, %%ymm10\n .endif\n" // 5x0 ".if (m_Mr > 5) \n vmulps %%ymm12, %%ymm10, %%ymm10\n .endif\n" // 5x0
".if (m_Mr > 5) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm11, %%ymm11\n .endif\n" // 5x1 ".if (m_Mr > 5) && (m_Nr > 8)\n vmulps %%ymm12, %%ymm11, %%ymm11\n .endif\n" // 5x1
"L_GemmAvx2_MxN_6x16_Update_C%=:\n" "L_GemmAvx2_MxN_6x16_Update_C%=:\n"
"movq 16(%[m_param]), %%rax\n" // p_c "movq 16(%[m_param]), %%rax\n" // p_c
"movq 48(%[m_param]), %%rdi\n" // ldc "movq 48(%[m_param]), %%rdi\n" // ldc
".if (m_Mr > 1)\n lea (%%rax, %%rdi, 1), %%rbx\n .endif\n" ".if (m_Mr > 1)\n lea (%%rax, %%rdi, 1), %%rbx\n .endif\n"
".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n" ".if (m_Mr > 2)\n lea (%%rbx, %%rdi, 1), %%rcx\n .endif\n"
".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n" ".if (m_Mr > 3)\n lea (%%rcx, %%rdi, 1), %%rdx\n .endif\n"
".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8 \n .endif\n" ".if (m_Mr > 4)\n lea (%%rdx, %%rdi, 1), %%r8 \n .endif\n"
".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9 \n .endif\n" ".if (m_Mr > 5)\n lea (%%r8, %%rdi, 1), %%r9 \n .endif\n"
" vaddps (%%rax), %%ymm0, %%ymm0 \n" " vaddps (%%rax), %%ymm0, %%ymm0 \n"
".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n" ".if (m_Nr > 8)\n vaddps 32(%%rax), %%ymm1, %%ymm1 \n .endif\n"
".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm2, %%ymm2 \n .endif\n" ".if (m_Mr > 1) \n vaddps (%%rbx), %%ymm2, %%ymm2 \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm3, %%ymm3 \n .endif\n" ".if (m_Mr > 1) && (m_Nr > 8)\n vaddps 32(%%rbx), %%ymm3, %%ymm3 \n .endif\n"
".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm4, %%ymm4 \n .endif\n" ".if (m_Mr > 2) \n vaddps (%%rcx), %%ymm4, %%ymm4 \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm5, %%ymm5 \n .endif\n" ".if (m_Mr > 2) && (m_Nr > 8)\n vaddps 32(%%rcx), %%ymm5, %%ymm5 \n .endif\n"
".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm6, %%ymm6 \n .endif\n" ".if (m_Mr > 3) \n vaddps (%%rdx), %%ymm6, %%ymm6 \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm7, %%ymm7 \n .endif\n" ".if (m_Mr > 3) && (m_Nr > 8)\n vaddps 32(%%rdx), %%ymm7, %%ymm7 \n .endif\n"
".if (m_Mr > 4) \n vaddps (%%r8), %%ymm8, %%ymm8 \n .endif\n" ".if (m_Mr > 4) \n vaddps (%%r8), %%ymm8, %%ymm8 \n .endif\n"
".if (m_Mr > 4) && (m_Nr > 8)\n vaddps 32(%%r8), %%ymm9, %%ymm9 \n .endif\n" ".if (m_Mr > 4) && (m_Nr > 8)\n vaddps 32(%%r8), %%ymm9, %%ymm9 \n .endif\n"
".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n" ".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n" ".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
" vmovaps %%ymm0, (%%rax) \n" " vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovaps %%ymm1, 32(%%rax)\n .endif\n" ".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
".if (m_Mr > 1) \n vmovaps %%ymm2, (%%rbx) \n .endif\n" ".if (m_Mr > 1) \n vmovups %%ymm2, (%%rbx) \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vmovaps %%ymm3, 32(%%rbx)\n .endif\n" ".if (m_Mr > 1) && (m_Nr > 8)\n vmovups %%ymm3, 32(%%rbx)\n .endif\n"
".if (m_Mr > 2) \n vmovaps %%ymm4, (%%rcx) \n .endif\n" ".if (m_Mr > 2) \n vmovups %%ymm4, (%%rcx) \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vmovaps %%ymm5, 32(%%rcx)\n .endif\n" ".if (m_Mr > 2) && (m_Nr > 8)\n vmovups %%ymm5, 32(%%rcx)\n .endif\n"
".if (m_Mr > 3) \n vmovaps %%ymm6, (%%rdx) \n .endif\n" ".if (m_Mr > 3) \n vmovups %%ymm6, (%%rdx) \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vmovaps %%ymm7, 32(%%rdx)\n .endif\n" ".if (m_Mr > 3) && (m_Nr > 8)\n vmovups %%ymm7, 32(%%rdx)\n .endif\n"
".if (m_Mr > 4) \n vmovaps %%ymm8, (%%r8) \n .endif\n" ".if (m_Mr > 4) \n vmovups %%ymm8, (%%r8) \n .endif\n"
".if (m_Mr > 4) && (m_Nr > 8)\n vmovaps %%ymm9, 32(%%r8) \n .endif\n" ".if (m_Mr > 4) && (m_Nr > 8)\n vmovups %%ymm9, 32(%%r8) \n .endif\n"
".if (m_Mr > 5) \n vmovaps %%ymm10, (%%r9) \n .endif\n" ".if (m_Mr > 5) \n vmovups %%ymm10, (%%r9) \n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vmovaps %%ymm11, 32(%%r9) \n .endif\n" ".if (m_Mr > 5) && (m_Nr > 8)\n vmovups %%ymm11, 32(%%r9) \n .endif\n"
"L_GemmAvx2_MxN_6x16_Exit%=:\n" "L_GemmAvx2_MxN_6x16_Exit%=:\n"
: :
: :
[m_Mr] "i" (Mr), [m_Mr] "i" (Mr),
[m_Nr] "i" (Nr), [m_Nr] "i" (Nr),
[m_TransA] "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? 1 : 0), [m_TransA] "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? 1 : 0),
[m_TransB] "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? 1 : 0), [m_TransB] "i" (std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? 1 : 0),
[m_NTStore] "i" (NonTemporalStore), [m_NTStore] "i" (NonTemporalStore),
[m_ABytes] "i" (sizeof(FloatA)), [m_ABytes] "i" (sizeof(FloatA)),
[m_BBytes] "i" (sizeof(FloatB)), [m_BBytes] "i" (sizeof(FloatB)),
[m_CBytes] "i" (sizeof(FloatC)), [m_CBytes] "i" (sizeof(FloatC)),
[m_param] "r" (param) [m_param] "r" (param)
: :
"cc", "cc",
"rax","rbx","rcx","rdx","rsi","rdi","r8","r9", "rax","rbx","rcx","rdx","rsi","rdi","r8","r9",
"ymm0","ymm1","ymm2","ymm3","ymm4","ymm5","ymm6", "ymm0","ymm1","ymm2","ymm3","ymm4","ymm5","ymm6",
"ymm7","ymm8","ymm9","ymm10","ymm11","ymm12","ymm13", "ymm7","ymm8","ymm9","ymm10","ymm11","ymm12","ymm13",
"ymm14","ymm15" "ymm14","ymm15"
); );
// clang-format on // clang-format on
} }
}; };
} // namespace cpu } // namespace cpu
} // namespace ck } // namespace ck
#endif #endif
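// Not part of this commit: a minimal sketch of how a row-major k*n source matrix could be
// packed into the n/8*k*n8 layout that the ColumnMajor BLayout above expects (8 contiguous
// n elements per ymm load, k as the fast-changing dim inside a block). pack_b_n8 is a
// hypothetical helper, shown only to illustrate the layout; the offset formula matches
// b_offset in the reference GEMM further down.
inline void pack_b_n8(const float* b, float* b_packed, uint32_t k, uint32_t n)
{
    // assumes n % 8 == 0
    for(uint32_t in = 0; in < n; in += 8)      // n/8 blocks
        for(uint32_t ik = 0; ik < k; ik++)     // k rows within a block
            for(uint32_t i8 = 0; i8 < 8; i8++) // 8 contiguous n elements
                b_packed[(in / 8) * k * 8 + ik * 8 + i8] = b[ik * n + in + i8];
}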
#ifndef CK_THREADWISE_PARAM_HPP
#define CK_THREADWISE_PARAM_HPP

#include "common_header.hpp"
#include "math.hpp"

namespace ck {
namespace cpu {

struct ThreadwiseGemmParam
{
    const void* p_a;
    const void* p_b;
    void* p_c;
    uint64_t Kr;
    uint64_t lda; // in unit of byte
    uint64_t ldb; // in unit of byte
    uint64_t ldc; // in unit of byte
    float alpha;
    uint32_t _pack0;
} __attribute__((packed));

} // namespace cpu
} // namespace ck
#endif
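// Not part of this commit: a compile-time sketch pinning the packed layout above to the
// byte offsets hard-coded in the ukernel's inline asm (p_a@0, p_b@8, p_c@16, Kr@24,
// lda@32, ldb@40, ldc@48, alpha@56).
#include <cstddef> // offsetof
static_assert(offsetof(ck::cpu::ThreadwiseGemmParam, p_b) == 8, "asm: movq 8(m_param)");
static_assert(offsetof(ck::cpu::ThreadwiseGemmParam, p_c) == 16, "asm: movq 16(m_param)");
static_assert(offsetof(ck::cpu::ThreadwiseGemmParam, Kr) == 24, "asm: movq 24(m_param)");
static_assert(offsetof(ck::cpu::ThreadwiseGemmParam, lda) == 32, "asm: movq 32(m_param)");
static_assert(offsetof(ck::cpu::ThreadwiseGemmParam, ldb) == 40, "asm: movq 40(m_param)");
static_assert(offsetof(ck::cpu::ThreadwiseGemmParam, ldc) == 48, "asm: movq 48(m_param)");
static_assert(offsetof(ck::cpu::ThreadwiseGemmParam, alpha) == 56, "asm: mov 56(m_param)");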
#ifndef CK_CPUID_HPP
#define CK_CPUID_HPP

#include <cstdint>

namespace ck {
namespace cpu {

enum cpuid_vendor
{
    cpuid_vendor_intel = 0,
    cpuid_vendor_amd   = 1,
    cpuid_vendor_other = 2,
};

enum cpuid_cache_type
{
    cpuid_cache_type_null    = 0,
    cpuid_cache_type_dcache  = 1,
    cpuid_cache_type_icache  = 2,
    cpuid_cache_type_unified = 3,
};

struct cpuid_raw
{
    uint32_t eax{0};
    uint32_t ebx{0};
    uint32_t ecx{0};
    uint32_t edx{0};
};

struct cpuid_cache_detail
{
    uint32_t size{0};
    uint32_t type{0};
    uint32_t cache_line_size{0};
    uint32_t associativity{0};
    uint32_t sets{0};
    uint32_t partitions{0};
    uint32_t shared_by_procs{0};  // with HT there are usually 2 threads per core, so for L1/L2
                                  // this is usually 2, unless HT is turned off
    uint32_t cores_per_socket{0}; // hardware cores in a physical socket; there may be multiple
                                  // sockets in the system. TODO: may not be needed?
    uint32_t flags{0};
};
struct cpuid_cache_hierarchy
{
    cpuid_cache_detail l1i;
    cpuid_cache_detail l1d;
    cpuid_cache_detail l2;
    cpuid_cache_detail l3;
    cpuid_cache_detail l4;
};

static inline cpuid_raw cpuid(uint32_t eax, uint32_t ecx)
{
    // some leaves require a subleaf index in ecx; for the others ecx is ignored.
    uint32_t ebx, edx;
    asm __volatile__("mov %0, %%eax\n"
                     "mov %2, %%ecx\n"
                     "cpuid\n"
                     "mov %%eax, %0\n"
                     "mov %%ebx, %1\n"
                     "mov %%ecx, %2\n"
                     "mov %%edx, %3\n"
                     : "=r"(eax), "=r"(ebx), "=r"(ecx), "=r"(edx)
                     : "0"(eax), "2"(ecx));
    return {eax, ebx, ecx, edx};
}
static inline cpuid_vendor cpuid_query_vendor()
{
    cpuid_raw r = cpuid(0, 0);
    // "GenuineIntel"
    if(r.ebx == 0x756E6547U /*Genu*/ && r.edx == 0x49656E69U /*ineI*/ &&
       r.ecx == 0x6C65746EU /*ntel*/)
    {
        return cpuid_vendor_intel;
    }
    // "AuthenticAMD"
    if(r.ebx == 0x68747541U /*Auth*/ && r.edx == 0x69746E65U /*enti*/ &&
       r.ecx == 0x444D4163U /*cAMD*/)
    {
        return cpuid_vendor_amd;
    }
    // "AMDisbetter!" (early AMD K5 engineering samples)
    if(r.ebx == 0x69444D41U /*AMDi*/ && r.edx == 0x74656273U /*sbet*/ &&
       r.ecx == 0x21726574U /*ter!*/)
    {
        return cpuid_vendor_amd;
    }
    return cpuid_vendor_other;
}
static inline cpuid_cache_hierarchy cpuid_query_cache()
{
    cpuid_cache_hierarchy cache_hierarchy;
    cpuid_vendor vendor = cpuid_query_vendor();
    // AMD reports cache topology through extended leaf 0x8000001d; Intel uses leaf 0x4
    uint32_t leaf_cache_id = vendor == cpuid_vendor_amd ? 0x8000001d : 0x4;
    for(uint32_t ecx_idx = 0;; ecx_idx++)
    {
        cpuid_raw r = cpuid(leaf_cache_id, ecx_idx);

        uint32_t cache_type = r.eax & 0x1f; // eax[4:0]
        if(cache_type == cpuid_cache_type_null)
            break; // null, no more caches
        uint32_t cache_level           = (r.eax >> 5) & 0x7;          // eax[7:5]
        uint32_t cache_shared_by_cores = 1 + ((r.eax >> 14) & 0xfff); // eax[25:14]
        uint32_t cache_lpp_cores       = 1 + ((r.eax >> 26) & 0x3f);  // eax[31:26]
        uint32_t cache_line_size       = 1 + (r.ebx & 0xfff);         // ebx[11:0]
        uint32_t cache_partitions      = 1 + ((r.ebx >> 12) & 0x3ff); // ebx[21:12]
        uint32_t cache_associativity   = 1 + (r.ebx >> 22);           // ebx[31:22]
        uint32_t cache_sets            = 1 + r.ecx;

        switch(cache_level)
        {
        case 1:
            if(cache_type == cpuid_cache_type_dcache || cache_type == cpuid_cache_type_unified)
            {
                cache_hierarchy.l1d.size =
                    cache_partitions * cache_sets * cache_associativity * cache_line_size;
                cache_hierarchy.l1d.type             = cache_type;
                cache_hierarchy.l1d.cache_line_size  = cache_line_size;
                cache_hierarchy.l1d.associativity    = cache_associativity;
                cache_hierarchy.l1d.sets             = cache_sets;
                cache_hierarchy.l1d.partitions       = cache_partitions;
                cache_hierarchy.l1d.shared_by_procs  = cache_shared_by_cores;
                cache_hierarchy.l1d.cores_per_socket = cache_lpp_cores;
            }
            else if(cache_type == cpuid_cache_type_icache)
            {
                cache_hierarchy.l1i.size =
                    cache_partitions * cache_sets * cache_associativity * cache_line_size;
                cache_hierarchy.l1i.type             = cache_type;
                cache_hierarchy.l1i.cache_line_size  = cache_line_size;
                cache_hierarchy.l1i.associativity    = cache_associativity;
                cache_hierarchy.l1i.sets             = cache_sets;
                cache_hierarchy.l1i.partitions       = cache_partitions;
                cache_hierarchy.l1i.shared_by_procs  = cache_shared_by_cores;
                cache_hierarchy.l1i.cores_per_socket = cache_lpp_cores;
            }
            break;
        case 2:
            if(cache_type == cpuid_cache_type_dcache || cache_type == cpuid_cache_type_unified)
            {
                cache_hierarchy.l2.size =
                    cache_partitions * cache_sets * cache_associativity * cache_line_size;
                cache_hierarchy.l2.type             = cache_type;
                cache_hierarchy.l2.cache_line_size  = cache_line_size;
                cache_hierarchy.l2.associativity    = cache_associativity;
                cache_hierarchy.l2.sets             = cache_sets;
                cache_hierarchy.l2.partitions       = cache_partitions;
                cache_hierarchy.l2.shared_by_procs  = cache_shared_by_cores;
                cache_hierarchy.l2.cores_per_socket = cache_lpp_cores;
            }
            break;
        case 3:
            if(cache_type == cpuid_cache_type_dcache || cache_type == cpuid_cache_type_unified)
            {
                cache_hierarchy.l3.size =
                    cache_partitions * cache_sets * cache_associativity * cache_line_size;
                cache_hierarchy.l3.type             = cache_type;
                cache_hierarchy.l3.cache_line_size  = cache_line_size;
                cache_hierarchy.l3.associativity    = cache_associativity;
                cache_hierarchy.l3.sets             = cache_sets;
                cache_hierarchy.l3.partitions       = cache_partitions;
                cache_hierarchy.l3.shared_by_procs  = cache_shared_by_cores;
                cache_hierarchy.l3.cores_per_socket = cache_lpp_cores;
            }
            break;
        case 4:
            if(cache_type == cpuid_cache_type_dcache || cache_type == cpuid_cache_type_unified)
            {
                cache_hierarchy.l4.size =
                    cache_partitions * cache_sets * cache_associativity * cache_line_size;
                cache_hierarchy.l4.type             = cache_type;
                cache_hierarchy.l4.cache_line_size  = cache_line_size;
                cache_hierarchy.l4.associativity    = cache_associativity;
                cache_hierarchy.l4.sets             = cache_sets;
                cache_hierarchy.l4.partitions       = cache_partitions;
                cache_hierarchy.l4.shared_by_procs  = cache_shared_by_cores;
                cache_hierarchy.l4.cores_per_socket = cache_lpp_cores;
            }
            break;
        }
    }
    return cache_hierarchy;
}

} // namespace cpu
} // namespace ck
#endif
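// Not part of this commit: a sketch of feeding the cache query into an L1 blocking decision
// for the 6x16 fp32 ukernel ("loop over L1" in the commit title). pick_k_block_for_l1 is a
// hypothetical helper, not CK API; the half-of-L1d budget is an assumption.
inline uint32_t pick_k_block_for_l1(uint32_t mr = 6, uint32_t nr = 16)
{
    ck::cpu::cpuid_cache_hierarchy cache = ck::cpu::cpuid_query_cache();
    uint32_t l1d_bytes = cache.l1d.size != 0 ? cache.l1d.size : 32 * 1024; // fall back to 32 KiB
    // budget roughly half of L1d for the A (mr x K) and B (K x nr) fp32 panels:
    // (mr + nr) * K * sizeof(float) <= l1d_bytes / 2
    return static_cast<uint32_t>((l1d_bytes / 2) / ((mr + nr) * sizeof(float)));
}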
#include <iostream>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <string>
#include <sstream>
#include <tuple>
#include <memory>
#include <chrono>
#include <cassert> // assert() in the per-layout invoke_uk paths
#include <cstdio>  // printf
#include <cstring> // memset
#include <ctime>   // time() seed for rand()
#include "config.hpp"
#include "print.hpp"
#include "cpuid.hpp"
#include "threadwise_gemm_avx2.hpp"
#define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>,   \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 5, 16, TA, TB, NT>,   \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 4, 16, TA, TB, NT>,   \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 3, 16, TA, TB, NT>,   \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 2, 16, TA, TB, NT>,   \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 1, 16, TA, TB, NT>,   \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 8, TA, TB, NT>,    \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 5, 8, TA, TB, NT>,    \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 4, 8, TA, TB, NT>,    \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 3, 8, TA, TB, NT>,    \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 2, 8, TA, TB, NT>,    \
    ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 1, 8, TA, TB, NT>

// #define ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(FA, FB, FC, TA, TB, NT) \
//     ck::cpu::ThreadwiseGemmAvx2_MxN_6x16<FA, FB, FC, 6, 16, TA, TB, NT>
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using thread_gemm_avx2_mxn_6x16_instances = std::tuple<
    // clang-format off
    //                                         FloatA FloatB FloatC ALayout BLayout NTStore
    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row,    Row,    false),
    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row,    Col,    false),
    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col,    Row,    false),
    ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Col,    Col,    false)
    // ITERATE_THREAD_GEMM_AVX2_MXN_6X16_INSTANCE(float, float, float, Row, Col, false)
    // clang-format on
    >;
void dump_cache_hierarchy()
{
    auto dump_cache_type = [&](const ck::cpu::cpuid_cache_type& type) {
        if(type == ck::cpu::cpuid_cache_type_dcache)
            printf("data cache");
        else if(type == ck::cpu::cpuid_cache_type_icache)
            printf("inst cache");
        else if(type == ck::cpu::cpuid_cache_type_unified)
            printf("unif cache");
    };
    auto dump_cache_detail = [&](const ck::cpu::cpuid_cache_detail& detail) {
        dump_cache_type(static_cast<const ck::cpu::cpuid_cache_type>(detail.type));
        printf(" size:%u, cache_line:%u, associativity:%u, sets:%u, partitions:%u, shared by "
               "procs:%u(%u)\n",
               detail.size,
               detail.cache_line_size,
               detail.associativity,
               detail.sets,
               detail.partitions,
               detail.shared_by_procs,
               detail.cores_per_socket);
    };
    ck::cpu::cpuid_cache_hierarchy cache = ck::cpu::cpuid_query_cache();
    if(cache.l1d.size != 0)
    {
        printf("l1 ");
        dump_cache_detail(cache.l1d);
    }
    if(cache.l1i.size != 0)
    {
        printf("l1 ");
        dump_cache_detail(cache.l1i);
    }
    if(cache.l2.size != 0)
    {
        printf("l2 ");
        dump_cache_detail(cache.l2);
    }
    if(cache.l3.size != 0)
    {
        printf("l3 ");
        dump_cache_detail(cache.l3);
    }
    if(cache.l4.size != 0)
    {
        printf("l4 ");
        dump_cache_detail(cache.l4);
    }
}
void* __aligned_malloc(size_t required_bytes, size_t alignment)
{
    if(alignment == 0 || (alignment & (alignment - 1))) // alignment must be a power of 2
        return nullptr;
    void* p1;  // original block
    void** p2; // aligned block
    int offset = alignment - 1 + sizeof(void*);
    if((p1 = malloc(required_bytes + offset)) == nullptr)
    {
        return nullptr;
    }
    p2 = reinterpret_cast<void**>((reinterpret_cast<size_t>(p1) + offset) & ~(alignment - 1));
    // stash the original pointer in the slot just below the aligned block so
    // __aligned_free can recover it
    p2[-1] = p1;
    return p2;
}
void __aligned_free(void* p) { free((reinterpret_cast<void**>(p))[-1]); }
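// Not part of this commit: a quick smoke test (sketch) of the allocator above. Buffers are
// allocated 32-byte aligned for full-speed ymm access, though the movups change in this
// commit means alignment is now a performance hint rather than a hard requirement.
inline void aligned_malloc_smoke_test()
{
    void* p = __aligned_malloc(1024, 32);
    assert(p != nullptr && (reinterpret_cast<size_t>(p) & 31) == 0);
    __aligned_free(p);
}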
template <typename T>
void rand_vector(T* v, int elem)
{
    int i;
    static int flag = 0;
    if(!flag)
    {
        srand(time(nullptr));
        flag = 1;
    }
    for(i = 0; i < elem; i++)
    {
        v[i] = (static_cast<T>(rand() % 100)) / 100.0f;
    }
}
bool valid_vector(const float* ref, const float* rhs, uint32_t elem)
{
    float rtol   = 1e-5;
    float atol   = 1e-8;
    uint32_t err = 0;
    for(uint32_t i = 0; i < elem; i++)
    {
        float diff = std::abs(ref[i] - rhs[i]);
        // allclose-style criterion: fail when |ref - rhs| > atol + rtol * |ref|
        if(diff > atol + rtol * std::abs(ref[i]))
        {
            printf("diff at %u, ref:%f, rhs:%f\n", i, ref[i], rhs[i]);
            err++;
        }
    }
    return err == 0;
}
template <typename data_type, typename ALayout, typename BLayout>
void ref_cpu_gemm_uk(const data_type* a,
                     const data_type* b,
                     float* c,
                     float alpha,
                     uint32_t m,
                     uint32_t n,
                     uint32_t k)
{
    auto a_offset = [&](uint32_t im, uint32_t ik) {
        if constexpr(std::is_same<Row, ALayout>::value)
        {
            return im * k + ik;
        }
        else
        {
            return ik * m + im;
        }
    };
    auto b_offset = [&](uint32_t ik, uint32_t in) {
        if constexpr(std::is_same<Row, BLayout>::value)
        {
            return ik * n + in;
        }
        else
        {
            // packed n/8*k*n8 layout
            return (in / 8) * k * 8 + ik * 8 + in % 8;
        }
    };
    auto c_offset = [&](uint32_t im, uint32_t in) { return im * n + in; };

    for(uint32_t im = 0; im < m; im++)
    {
        for(uint32_t in = 0; in < n; in++)
        {
            float acc = .0f;
            for(uint32_t ik = 0; ik < k; ik++)
            {
                acc += a[a_offset(im, ik)] * b[b_offset(ik, in)];
            }
            acc *= alpha;
            c[c_offset(im, in)] = acc;
        }
    }
}
template <typename data_type, typename ALayout, typename BLayout, typename ukernel_t>
void test_ukernel(ukernel_t uk,
                  data_type* mat_a,
                  data_type* mat_b,
                  float* mat_c,
                  float alpha,
                  uint32_t m,
                  uint32_t n,
                  uint32_t k)
{
    ck::cpu::ThreadwiseGemmParam param;
    param.p_a   = mat_a;
    param.p_b   = mat_b;
    param.p_c   = mat_c;
    param.Kr    = k;
    param.lda   = (std::is_same<Row, ALayout>::value ? k : m) * sizeof(data_type);
    param.ldb   = (std::is_same<Row, BLayout>::value ? n : k * 8) * sizeof(data_type);
    param.ldc   = n * sizeof(float);
    param.alpha = alpha;

    auto invoke_uk = [&]() {
        if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Row, BLayout>::value)
        {
            assert(m % uk.Mr_ == 0 && n == uk.Nr_);
            data_type* p_a = mat_a;
            float* p_c     = mat_c;
            param.p_a      = p_a;
            param.p_c      = p_c;
            for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_)
            {
                uk.Run(&param);
                p_a += uk.Mr_ * k;
                p_c += uk.Mr_ * n;
                param.p_a = p_a;
                param.p_c = p_c;
            }
        }
        else if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Col, BLayout>::value)
        {
            assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
            data_type* p_a = mat_a;
            // data_type* p_b = mat_b;
            float* p_c = mat_c;
            param.p_a  = p_a;
            param.p_b  = mat_b;
            param.p_c  = p_c;
            for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_)
            {
                float* p_c_n     = p_c;
                data_type* p_b_n = mat_b;
                for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_)
                {
                    uk.Run(&param);
                    p_b_n += uk.Nr_ * k; // Nr_/8*k*8
                    p_c_n += uk.Nr_;
                    param.p_b = p_b_n;
                    param.p_c = p_c_n;
                }
                p_a += uk.Mr_ * k;
                p_c += uk.Mr_ * n;
                param.p_a = p_a;
                param.p_b = mat_b;
                param.p_c = p_c;
            }
        }
        else if constexpr(std::is_same<Col, ALayout>::value && std::is_same<Row, BLayout>::value)
        {
            assert(m == uk.Mr_ && n == uk.Nr_);
            uk.Run(&param);
        }
        else
        {
            assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
            data_type* p_b = mat_b;
            float* p_c     = mat_c;
            param.p_b      = p_b;
            param.p_c      = p_c;
            for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_)
            {
                uk.Run(&param);
                p_b += uk.Nr_ * k; // Nr_/8*k*8
                p_c += uk.Nr_;
                param.p_b = p_b;
                param.p_c = p_c;
            }
        }
    };

    printf("gemm_uk_%dx%d_%c%c: ", uk.Mr_, uk.Nr_, ALayout::name[0], BLayout::name[0]);
    fflush(stdout);
    // printf("%s: ", typeid(uk).name());fflush(stdout);
    memset(mat_c, 0, m * n * sizeof(float));

    int repeat = 7e10 / (2 * m * n * k); // target roughly 7e10 flops of total work

    for(int i = 0; i < (repeat / 5); i++) // warm up
    {
        invoke_uk();
    }

    auto t0 = std::chrono::high_resolution_clock::now();
    for(int i = 0; i < repeat; i++)
    {
        invoke_uk();
    }
    auto t1 = std::chrono::high_resolution_clock::now();

    double us = static_cast<double>(
                    std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count()) /
                repeat;
    double gflops = static_cast<double>(2 * m * n * k) * 1e-3 / us;

    memset(mat_c, 0, m * n * sizeof(float));
    invoke_uk();

    printf("m:%u, n:%u, k:%u, alpha:%f, cost:%lfus, GFLOPS:%lf, ", m, n, k, alpha, us, gflops);
    fflush(stdout);
}
// implement small ukernel on L1
template <typename data_type, typename ALayout, typename BLayout>
void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
{
    data_type* mat_a =
        reinterpret_cast<data_type*>(__aligned_malloc(m * k * sizeof(data_type), 32));
    data_type* mat_b =
        reinterpret_cast<data_type*>(__aligned_malloc(k * n * sizeof(data_type), 32));
    float* mat_c     = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32));
    float* mat_c_ref = reinterpret_cast<float*>(__aligned_malloc(m * n * sizeof(float), 32));
    memset(mat_c_ref, 0, m * n * sizeof(float));

    rand_vector(mat_a, m * k);
    rand_vector(mat_b, k * n);

    ref_cpu_gemm_uk<data_type, ALayout, BLayout>(mat_a, mat_b, mat_c_ref, alpha, m, n, k);

    ck::static_for<0, std::tuple_size_v<thread_gemm_avx2_mxn_6x16_instances>, 1>{}([&](auto i) {
        using uk_type = std::tuple_element_t<i, thread_gemm_avx2_mxn_6x16_instances>;
        if constexpr(!std::is_same<typename uk_type::ALayout_, ALayout>::value ||
                     !std::is_same<typename uk_type::BLayout_, BLayout>::value)
        {
            return;
        }
        if(m % uk_type::Mr_ != 0 || n % uk_type::Nr_ != 0)
            return;
        // only when k is the fast-changing dim of A/B can we loop over multiple m/n tiles
        if((m != uk_type::Mr_ && std::is_same<typename uk_type::ALayout_, Col>::value) ||
           (n != uk_type::Nr_ && std::is_same<typename uk_type::BLayout_, Row>::value))
            return;

        test_ukernel<data_type, ALayout, BLayout>(uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k);

        bool is_valid = valid_vector(mat_c_ref, mat_c, m * n);
        printf("valid:%s\n", is_valid ? "y" : "n");
        // return ;
    });

    __aligned_free(mat_a);
    __aligned_free(mat_b);
    __aligned_free(mat_c);
    __aligned_free(mat_c_ref);
}
int main(int argc, char** argv)
{
    int m       = 6;
    int n       = 16;
    int k       = 64;
    float alpha = 1.0f;
    if(argc > 3)
    {
        m = std::atoi(argv[1]);
        n = std::atoi(argv[2]);
        k = std::atoi(argv[3]);
    }
    if(argc > 4)
    {
        alpha = std::atof(argv[4]);
    }
    dump_cache_hierarchy();
    test_cpu_ukernel<float, Row, Row>(alpha, m, n, k);
    test_cpu_ukernel<float, Row, Col>(alpha, m, n, k);
    test_cpu_ukernel<float, Col, Row>(alpha, m, n, k);
    test_cpu_ukernel<float, Col, Col>(alpha, m, n, k);
}
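// Not part of this commit: one plausible way to build and run the test on an AVX2 machine.
// The source file name and include path are assumptions; the flags are standard GCC options:
//   g++ -O2 -std=c++17 -mavx2 -mfma -I<path-to-ck-headers> test_threadwise_gemm_avx2.cpp -o test_uk
//   ./test_uk 6 16 64 1.0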