Commit 35f95fe9 authored by carlushuang
Browse files

movaps->movups, and support loop over L1

parent e72c0c43
...@@ -82,11 +82,11 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -82,11 +82,11 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".endif\n" ".endif\n"
".endm\n" ".endm\n"
".macro vmovaps_%= r_base, r_stride, i_scale, i_offset, ymm\n" ".macro vmovups_%= r_base, r_stride, i_scale, i_offset, ymm\n"
".if \\i_scale != 0\n" ".if \\i_scale != 0\n"
"vmovaps \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n" "vmovups \\i_offset(\\r_base, \\r_stride, \\i_scale), \\ymm\n"
".else\n" ".else\n"
"vmovaps \\i_offset(\\r_base), \\ymm\n" "vmovups \\i_offset(\\r_base), \\ymm\n"
".endif\n" ".endif\n"
".endm\n" ".endm\n"
...@@ -105,15 +105,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -105,15 +105,15 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, lda in rdx, i_n should be 0, 1 ".macro vload_b%= i_k, i_n, ymm\n" // B in rbx, lda in rdx, i_n should be 0, 1
".if m_BBytes == 4\n" ".if m_BBytes == 4\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vmovaps_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n" "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
".else\n" ".else\n"
"vmovaps_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n" "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
".endif\n" ".endif\n"
".else\n" ".else\n"
".if m_TransB == 0\n" ".if m_TransB == 0\n"
"vmovaps_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n" "vmovups_%= %%rbx, %%rdx, \\i_n, \\i_k*4*8, \\ymm\n"
".else\n" ".else\n"
"vmovaps_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n" "vmovups_%= %%rbx, 0, 0, (\\i_k*m_Nr + \\i_n*8)*4, \\ymm\n"
".endif\n" ".endif\n"
".endif\n" ".endif\n"
".endm\n" ".endm\n"
...@@ -265,18 +265,18 @@ struct ThreadwiseGemmAvx2_MxN_6x16 ...@@ -265,18 +265,18 @@ struct ThreadwiseGemmAvx2_MxN_6x16
".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n" ".if (m_Mr > 5) \n vaddps (%%r9), %%ymm10, %%ymm10\n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n" ".if (m_Mr > 5) && (m_Nr > 8)\n vaddps 32(%%r9), %%ymm11, %%ymm11\n .endif\n"
" vmovaps %%ymm0, (%%rax) \n" " vmovups %%ymm0, (%%rax) \n"
".if (m_Nr > 8)\n vmovaps %%ymm1, 32(%%rax)\n .endif\n" ".if (m_Nr > 8)\n vmovups %%ymm1, 32(%%rax)\n .endif\n"
".if (m_Mr > 1) \n vmovaps %%ymm2, (%%rbx) \n .endif\n" ".if (m_Mr > 1) \n vmovups %%ymm2, (%%rbx) \n .endif\n"
".if (m_Mr > 1) && (m_Nr > 8)\n vmovaps %%ymm3, 32(%%rbx)\n .endif\n" ".if (m_Mr > 1) && (m_Nr > 8)\n vmovups %%ymm3, 32(%%rbx)\n .endif\n"
".if (m_Mr > 2) \n vmovaps %%ymm4, (%%rcx) \n .endif\n" ".if (m_Mr > 2) \n vmovups %%ymm4, (%%rcx) \n .endif\n"
".if (m_Mr > 2) && (m_Nr > 8)\n vmovaps %%ymm5, 32(%%rcx)\n .endif\n" ".if (m_Mr > 2) && (m_Nr > 8)\n vmovups %%ymm5, 32(%%rcx)\n .endif\n"
".if (m_Mr > 3) \n vmovaps %%ymm6, (%%rdx) \n .endif\n" ".if (m_Mr > 3) \n vmovups %%ymm6, (%%rdx) \n .endif\n"
".if (m_Mr > 3) && (m_Nr > 8)\n vmovaps %%ymm7, 32(%%rdx)\n .endif\n" ".if (m_Mr > 3) && (m_Nr > 8)\n vmovups %%ymm7, 32(%%rdx)\n .endif\n"
".if (m_Mr > 4) \n vmovaps %%ymm8, (%%r8) \n .endif\n" ".if (m_Mr > 4) \n vmovups %%ymm8, (%%r8) \n .endif\n"
".if (m_Mr > 4) && (m_Nr > 8)\n vmovaps %%ymm9, 32(%%r8) \n .endif\n" ".if (m_Mr > 4) && (m_Nr > 8)\n vmovups %%ymm9, 32(%%r8) \n .endif\n"
".if (m_Mr > 5) \n vmovaps %%ymm10, (%%r9) \n .endif\n" ".if (m_Mr > 5) \n vmovups %%ymm10, (%%r9) \n .endif\n"
".if (m_Mr > 5) && (m_Nr > 8)\n vmovaps %%ymm11, 32(%%r9) \n .endif\n" ".if (m_Mr > 5) && (m_Nr > 8)\n vmovups %%ymm11, 32(%%r9) \n .endif\n"
"L_GemmAvx2_MxN_6x16_Exit%=:\n" "L_GemmAvx2_MxN_6x16_Exit%=:\n"
: :
: :
......
...@@ -218,6 +218,74 @@ void test_ukernel(ukenrel_t uk, ...@@ -218,6 +218,74 @@ void test_ukernel(ukenrel_t uk,
param.ldc = n * sizeof(float); param.ldc = n * sizeof(float);
param.alpha = alpha; param.alpha = alpha;
auto invoke_uk = [&]() {
if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Row, BLayout>::value)
{
assert(m % uk.Mr_ == 0 && n == uk.Nr_);
data_type* p_a = mat_a;
float* p_c = mat_c;
param.p_a = p_a;
param.p_c = p_c;
for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_)
{
uk.Run(&param);
p_a += uk.Mr_ * k;
p_c += uk.Mr_ * n;
param.p_a = p_a;
param.p_c = p_c;
}
}
else if constexpr(std::is_same<Row, ALayout>::value && std::is_same<Col, BLayout>::value)
{
assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
data_type* p_a = mat_a;
// data_type* p_b = mat_b;
float* p_c = mat_c;
param.p_a = p_a;
param.p_b = mat_b;
param.p_c = p_c;
for(uint32_t i_m = 0; i_m < m; i_m += uk.Mr_)
{
float* p_c_n = p_c;
float* p_b_n = mat_b;
for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_)
{
uk.Run(&param);
p_b_n += uk.Nr_ * k; // Nr_/8*k*8
p_c_n += uk.Nr_;
param.p_b = p_b_n;
param.p_c = p_c_n;
}
p_a += uk.Mr_ * k;
p_c += uk.Mr_ * n;
param.p_a = p_a;
param.p_b = mat_b;
param.p_c = p_c;
}
}
else if constexpr(std::is_same<Col, ALayout>::value && std::is_same<Row, BLayout>::value)
{
assert(m == uk.Mr_ && n == uk.Nr_);
uk.Run(&param);
}
else
{
assert(m % uk.Mr_ == 0 && n % uk.Nr_ == 0);
data_type* p_b = mat_b;
float* p_c = mat_c;
param.p_b = p_b;
param.p_c = p_c;
for(uint32_t i_n = 0; i_n < n; i_n += uk.Nr_)
{
uk.Run(&param);
p_b += uk.Nr_ * k; // Nr_/8*k*8
p_c += uk.Nr_;
param.p_b = p_b;
param.p_c = p_c;
}
}
};
printf("gemm_uk_%dx%d_%c%c: ", uk.Mr_, uk.Nr_, ALayout::name[0], BLayout::name[0]); printf("gemm_uk_%dx%d_%c%c: ", uk.Mr_, uk.Nr_, ALayout::name[0], BLayout::name[0]);
fflush(stdout); fflush(stdout);
// printf("%s: ", typeid(uk).name());fflush(stdout); // printf("%s: ", typeid(uk).name());fflush(stdout);
...@@ -227,13 +295,13 @@ void test_ukernel(ukenrel_t uk, ...@@ -227,13 +295,13 @@ void test_ukernel(ukenrel_t uk,
for(int i = 0; i < (repeat / 5); i++) for(int i = 0; i < (repeat / 5); i++)
{ {
uk.Run(&param); invoke_uk();
} }
auto t0 = std::chrono::high_resolution_clock::now(); auto t0 = std::chrono::high_resolution_clock::now();
for(int i = 0; i < repeat; i++) for(int i = 0; i < repeat; i++)
{ {
uk.Run(&param); invoke_uk();
} }
auto t1 = std::chrono::high_resolution_clock::now(); auto t1 = std::chrono::high_resolution_clock::now();
...@@ -243,7 +311,7 @@ void test_ukernel(ukenrel_t uk, ...@@ -243,7 +311,7 @@ void test_ukernel(ukenrel_t uk,
double gflops = static_cast<double>(2 * m * n * k) * 1e-3 / us; double gflops = static_cast<double>(2 * m * n * k) * 1e-3 / us;
memset(mat_c, 0, m * n * sizeof(float)); memset(mat_c, 0, m * n * sizeof(float));
uk.Run(&param); invoke_uk();
printf("m:%u, n:%u, k:%u, alpha:%f, cost:%lfus, GFLOPS:%lf, ", m, n, k, alpha, us, gflops); printf("m:%u, n:%u, k:%u, alpha:%f, cost:%lfus, GFLOPS:%lf, ", m, n, k, alpha, us, gflops);
fflush(stdout); fflush(stdout);
...@@ -274,7 +342,11 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k) ...@@ -274,7 +342,11 @@ void test_cpu_ukernel(float alpha, uint32_t m, uint32_t n, uint32_t k)
{ {
return; return;
} }
if(uk_type::Mr_ != m || uk_type::Nr_ != n) if(m % uk_type::Mr_ != 0 || n % uk_type::Nr_ != 0)
return;
if((m != uk_type::Mr_ && std::is_same<typename uk_type::ALayout_, Col>::value) ||
(n != uk_type::Nr_ && std::is_same<typename uk_type::BLayout_, Row>::value))
// only when k is the fastest-changing dim of A/B can we loop over multiple tiles in m, n
return; return;
test_ukernel<data_type, ALayout, BLayout>(uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k); test_ukernel<data_type, ALayout, BLayout>(uk_type{}, mat_a, mat_b, mat_c, alpha, m, n, k);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment