Commit 269adde8 authored by root's avatar root
Browse files

fixed a bug

parent a46a17fb
...@@ -137,15 +137,11 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -137,15 +137,11 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<CYXPerThreadLoop>{}, Number<KPerThread>{})); make_tuple(Number<CYXPerThreadLoop>{}, Number<KPerThread>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
// make_tuple(Number<CYXPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<CYXPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
// Number<WPerThread>{}));
make_tuple(Number<CYXPerThreadLoop>{}, Number<1>{})); constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
// make_tuple(Number<KPerThread>{}, Number<1>{},
// Number<HPerThread>{}, Number<WPerThread>{}));
make_tuple(Number<KPerThread>{}, Number<1>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()]; FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
......
...@@ -95,11 +95,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -95,11 +95,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// divide block work by [M, N] // divide block work by [M, N]
#if 1 #if 1
const auto m_block_work_num = K / Number<KPerBlock>{}; const auto m_block_work_num = K / Number<KPerBlock>{};
const auto hw_block_work_num = (N * H * W) / (Number<HPerBlock>{} * Number<WPerBlock>{}); const auto h_block_work_num = H / Number<HPerBlock>{};
const auto w_block_work_num = W / Number<WPerBlock>{};
const auto hw_block_work_num = h_block_work_num * w_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hw_block_work_num; const index_t k_block_work_id = get_block_1d_id() / hw_block_work_num;
const index_t hw_block_work_id = get_block_1d_id() - k_block_work_id * hw_block_work_num; const index_t hw_block_work_id = get_block_1d_id() - k_block_work_id * hw_block_work_num;
const index_t h_block_work_id = hw_block_work_id / w_block_work_num;
const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;
constexpr auto h_num_threads = HPerBlock / HPerThread; constexpr auto h_num_threads = HPerBlock / HPerThread;
constexpr auto w_num_threads = WPerBlock / WPerThread; constexpr auto w_num_threads = WPerBlock / WPerThread;
...@@ -119,8 +124,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -119,8 +124,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
const index_t m_block_data_on_global = k_block_work_id * KPerBlock; const index_t m_block_data_on_global = k_block_work_id * KPerBlock;
const index_t h_block_data_on_global = hw_block_work_id * HPerBlock; const index_t h_block_data_on_global = h_block_work_id * HPerBlock;
const index_t w_block_data_on_global = hw_block_work_id * WPerBlock; const index_t w_block_data_on_global = w_block_work_id * WPerBlock;
// lds max alignment // lds max alignment
constexpr auto max_lds_align = constexpr auto max_lds_align =
...@@ -187,8 +192,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -187,8 +192,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
ThreadwiseTensorSliceTransferB b_threadwise_transfer( ThreadwiseTensorSliceTransferB b_threadwise_transfer(
b_cyx_n_h_w_global_desc, b_cyx_n_h_w_global_desc,
make_multi_index( make_multi_index(0,
0, 0, h_block_data_on_global + h_thread_id, w_block_data_on_global + w_thread_id)); 0,
h_block_data_on_global + h_thread_id * HPerThread,
w_block_data_on_global + w_thread_id * WPerThread));
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx // TODO:: more elegent way of defining c_thread_mtx
...@@ -426,7 +433,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -426,7 +433,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
Float, Float,
decltype(c_k_n_h_w_thread_desc), decltype(c_k_n_h_w_thread_desc),
decltype(c_k_n_h_w_global_desc), decltype(c_k_n_h_w_global_desc),
Sequence<KPerThread, 1, 1, 1>, Sequence<KPerThread, 1, HPerThread, WPerThread>,
Sequence<3, 2, 0, 1>, // CThreadTransferSrcDstAccessOrder Sequence<3, 2, 0, 1>, // CThreadTransferSrcDstAccessOrder
3, // CThreadTransferSrcDstVectorDim 3, // CThreadTransferSrcDstVectorDim
1, // CThreadTransferDstScalarPerVector, 1, // CThreadTransferDstScalarPerVector,
......
...@@ -47,20 +47,26 @@ struct ThreadwiseGemm_km_kn_mn_v3 ...@@ -47,20 +47,26 @@ struct ThreadwiseGemm_km_kn_mn_v3
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto M = CDesc{}.GetLength(I0); constexpr auto H = BDesc{}.GetLength(I2);
constexpr auto N = CDesc{}.GetLength(I1); constexpr auto W = BDesc{}.GetLength(I3);
constexpr auto K = ADesc{}.GetLength(I0);
static_for<0, K, 1>{}([&](auto k) { constexpr auto CYX = ADesc{}.GetLength(I0);
static_for<0, M, 1>{}([&](auto m) { constexpr auto K = ADesc{}.GetLength(I1);
static_for<0, N, 1>{}([&](auto n) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m)); static_for<0, CYX, 1>{}([&](auto e) {
constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(k, n)); static_for<0, K, 1>{}([&](auto k) {
constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(m, n)); static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k));
constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(e, 0, h, w));
constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(k, 0, h, w));
p_c[c_offset] += p_c[c_offset] +=
inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]); inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
});
}); });
}); });
}); });
......
...@@ -72,21 +72,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -72,21 +72,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HPerBlock = 8; constexpr index_t HPerBlock = 8;
constexpr index_t WPerBlock = 8; constexpr index_t WPerBlock = 16;
constexpr index_t CYXPerBlock = 4 * 3 * 3; constexpr index_t CYXPerBlock = 4;
constexpr index_t KPerThread = 16; constexpr index_t KPerThread = 16;
constexpr index_t HPerThread = 1; constexpr index_t HPerThread = 1;
constexpr index_t WPerThread = 1; constexpr index_t WPerThread = 2;
constexpr index_t CYXPerThread = 4 * 3 * 3; constexpr index_t CYXPerThread = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<9, 1>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<36, 1>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
......
...@@ -82,8 +82,8 @@ int main(int argc, char* argv[]) ...@@ -82,8 +82,8 @@ int main(int argc, char* argv[])
#elif 1 #elif 1
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 4;
constexpr index_t HI = 8; constexpr index_t HI = 1080;
constexpr index_t WI = 8; constexpr index_t WI = 1920;
constexpr index_t K = 16; constexpr index_t K = 16;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment