Commit d3df5eb1 authored by Jing Zhang
Browse files

tweak

parent 55c280e4
......@@ -133,10 +133,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
constexpr auto KPerThreadSubC = 4;
constexpr auto HPerThreadSubC = 2;
constexpr auto WPerThreadSubC = 2;
static_assert(KPerThread % KPerThreadSubC == 0, "");
static_assert(HPerThread % 2 == 0, "");
static_assert(WPerThread % 2 == 0, "");
static_assert(HPerThread % HPerThreadSubC == 0, "");
static_assert(WPerThread % WPerThreadSubC == 0, "");
// thread A, B for GEMM
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
......@@ -158,7 +160,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
decltype(c_thread_mtx),
HPerThreadSubC,
WPerThreadSubC>{};
// loop over k
#pragma unroll
for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
......@@ -171,10 +175,11 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
mMyThreadOffsetA,
p_a_thread);
for(index_t h_begin = 0; h_begin < HPerThread; h_begin += 2)
#pragma unroll
for(index_t h_begin = 0; h_begin < HPerThread; h_begin += HPerThreadSubC)
{
for(index_t w_begin = 0; w_begin < WPerThread; w_begin += 2)
#pragma unroll
for(index_t w_begin = 0; w_begin < WPerThread; w_begin += WPerThreadSubC)
{
threadwise_gemm.Run(p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(
......
......@@ -37,6 +37,8 @@ __device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread
template <typename ADesc,
typename BDesc,
typename CDesc,
index_t H,
index_t W,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
......@@ -54,11 +56,6 @@ struct ThreadwiseGemm_km_kn_mn_v3
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
// constexpr auto H = BDesc{}.GetLength(I2);
// constexpr auto W = BDesc{}.GetLength(I3);
constexpr auto H = 2;
constexpr auto W = 2;
constexpr auto E = ADesc{}.GetLength(I0);
constexpr auto K = ADesc{}.GetLength(I1);
......
......@@ -111,11 +111,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32;
constexpr index_t WoPerBlock = 64;
constexpr index_t EPerBlock = 1;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock;
......
......@@ -64,6 +64,20 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 2048;
constexpr index_t WI = 2048;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
constexpr index_t N = 1;
constexpr index_t C = 16;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment