Commit d3df5eb1 authored by Jing Zhang
Browse files

tweak

parent 55c280e4
...@@ -133,10 +133,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -133,10 +133,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr auto EPerBlock = a_block_mtx.GetLength(I0); constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
constexpr auto KPerThreadSubC = 4; constexpr auto KPerThreadSubC = 4;
constexpr auto HPerThreadSubC = 2;
constexpr auto WPerThreadSubC = 2;
static_assert(KPerThread % KPerThreadSubC == 0, ""); static_assert(KPerThread % KPerThreadSubC == 0, "");
static_assert(HPerThread % 2 == 0, ""); static_assert(HPerThread % HPerThreadSubC == 0, "");
static_assert(WPerThread % 2 == 0, ""); static_assert(WPerThread % WPerThreadSubC == 0, "");
// thread A, B for GEMM // thread A, B for GEMM
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
...@@ -158,7 +160,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -158,7 +160,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx), constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
decltype(b_thread_mtx), decltype(b_thread_mtx),
decltype(c_thread_mtx)>{}; decltype(c_thread_mtx),
HPerThreadSubC,
WPerThreadSubC>{};
// loop over k // loop over k
#pragma unroll #pragma unroll
for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop) for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
...@@ -171,10 +175,11 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -171,10 +175,11 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
mMyThreadOffsetA, mMyThreadOffsetA,
p_a_thread); p_a_thread);
for(index_t h_begin = 0; h_begin < HPerThread; h_begin += 2) #pragma unroll
for(index_t h_begin = 0; h_begin < HPerThread; h_begin += HPerThreadSubC)
{ {
#pragma unroll
for(index_t w_begin = 0; w_begin < WPerThread; w_begin += 2) for(index_t w_begin = 0; w_begin < WPerThread; w_begin += WPerThreadSubC)
{ {
threadwise_gemm.Run(p_a_thread, threadwise_gemm.Run(p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple( p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(
......
...@@ -37,6 +37,8 @@ __device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread ...@@ -37,6 +37,8 @@ __device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread
template <typename ADesc, template <typename ADesc,
typename BDesc, typename BDesc,
typename CDesc, typename CDesc,
index_t H,
index_t W,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
...@@ -54,11 +56,6 @@ struct ThreadwiseGemm_km_kn_mn_v3 ...@@ -54,11 +56,6 @@ struct ThreadwiseGemm_km_kn_mn_v3
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
// constexpr auto H = BDesc{}.GetLength(I2);
// constexpr auto W = BDesc{}.GetLength(I3);
constexpr auto H = 2;
constexpr auto W = 2;
constexpr auto E = ADesc{}.GetLength(I0); constexpr auto E = ADesc{}.GetLength(I0);
constexpr auto K = ADesc{}.GetLength(I1); constexpr auto K = ADesc{}.GetLength(I1);
......
...@@ -111,11 +111,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -111,11 +111,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8; constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32; constexpr index_t WoPerBlock = 64;
constexpr index_t EPerBlock = 1; constexpr index_t EPerBlock = 1;
constexpr index_t KPerThread = 16; constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2; constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock; constexpr index_t EPerThread = EPerBlock;
......
...@@ -64,6 +64,20 @@ int main(int argc, char* argv[]) ...@@ -64,6 +64,20 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 2048;
constexpr index_t WI = 2048;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1 #elif 1
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment