"...git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "2282531b581860e67f3a405417afa9aa70dcf38b"
Commit 269adde8 authored by root's avatar root
Browse files

fixed a bug

parent a46a17fb
......@@ -137,15 +137,11 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<CYXPerThreadLoop>{}, Number<KPerThread>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
// make_tuple(Number<CYXPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{},
// Number<WPerThread>{}));
make_tuple(Number<CYXPerThreadLoop>{}, Number<1>{}));
constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
// make_tuple(Number<KPerThread>{}, Number<1>{},
// Number<HPerThread>{}, Number<WPerThread>{}));
make_tuple(Number<KPerThread>{}, Number<1>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<CYXPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
......
......@@ -95,11 +95,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// divide block work by [M, N]
#if 1
const auto m_block_work_num = K / Number<KPerBlock>{};
const auto hw_block_work_num = (N * H * W) / (Number<HPerBlock>{} * Number<WPerBlock>{});
const auto h_block_work_num = H / Number<HPerBlock>{};
const auto w_block_work_num = W / Number<WPerBlock>{};
const auto hw_block_work_num = h_block_work_num * w_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hw_block_work_num;
const index_t hw_block_work_id = get_block_1d_id() - k_block_work_id * hw_block_work_num;
const index_t h_block_work_id = hw_block_work_id / w_block_work_num;
const index_t w_block_work_id = hw_block_work_id - h_block_work_id * w_block_work_num;
constexpr auto h_num_threads = HPerBlock / HPerThread;
constexpr auto w_num_threads = WPerBlock / WPerThread;
......@@ -119,8 +124,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
const index_t m_block_data_on_global = k_block_work_id * KPerBlock;
const index_t h_block_data_on_global = hw_block_work_id * HPerBlock;
const index_t w_block_data_on_global = hw_block_work_id * WPerBlock;
const index_t h_block_data_on_global = h_block_work_id * HPerBlock;
const index_t w_block_data_on_global = w_block_work_id * WPerBlock;
// lds max alignment
constexpr auto max_lds_align =
......@@ -187,8 +192,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
ThreadwiseTensorSliceTransferB b_threadwise_transfer(
b_cyx_n_h_w_global_desc,
make_multi_index(
0, 0, h_block_data_on_global + h_thread_id, w_block_data_on_global + w_thread_id));
make_multi_index(0,
0,
h_block_data_on_global + h_thread_id * HPerThread,
w_block_data_on_global + w_thread_id * WPerThread));
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
......@@ -426,7 +433,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
Float,
decltype(c_k_n_h_w_thread_desc),
decltype(c_k_n_h_w_global_desc),
Sequence<KPerThread, 1, 1, 1>,
Sequence<KPerThread, 1, HPerThread, WPerThread>,
Sequence<3, 2, 0, 1>, // CThreadTransferSrcDstAccessOrder
3, // CThreadTransferSrcDstVectorDim
1, // CThreadTransferDstScalarPerVector,
......
......@@ -47,20 +47,26 @@ struct ThreadwiseGemm_km_kn_mn_v3
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto M = CDesc{}.GetLength(I0);
constexpr auto N = CDesc{}.GetLength(I1);
constexpr auto K = ADesc{}.GetLength(I0);
constexpr auto H = BDesc{}.GetLength(I2);
constexpr auto W = BDesc{}.GetLength(I3);
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
static_for<0, N, 1>{}([&](auto n) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(k, n));
constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(m, n));
constexpr auto CYX = ADesc{}.GetLength(I0);
constexpr auto K = ADesc{}.GetLength(I1);
static_for<0, CYX, 1>{}([&](auto e) {
static_for<0, K, 1>{}([&](auto k) {
static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k));
constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(e, 0, h, w));
constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(k, 0, h, w));
p_c[c_offset] +=
inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
p_c[c_offset] +=
inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
});
});
});
});
......
......@@ -72,21 +72,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t KPerBlock = 16;
constexpr index_t HPerBlock = 8;
constexpr index_t WPerBlock = 8;
constexpr index_t CYXPerBlock = 4 * 3 * 3;
constexpr index_t WPerBlock = 16;
constexpr index_t CYXPerBlock = 4;
constexpr index_t KPerThread = 16;
constexpr index_t HPerThread = 1;
constexpr index_t WPerThread = 1;
constexpr index_t CYXPerThread = 4 * 3 * 3;
constexpr index_t WPerThread = 2;
constexpr index_t CYXPerThread = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<9, 1>;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<36, 1>;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
......
......@@ -82,8 +82,8 @@ int main(int argc, char* argv[])
#elif 1
constexpr index_t N = 1;
constexpr index_t C = 4;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t HI = 1080;
constexpr index_t WI = 1920;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment