Commit 33ceea62 authored by carlushuang's avatar carlushuang
Browse files

merge to convert address

parent 1ba8a08f
...@@ -7,6 +7,11 @@ ...@@ -7,6 +7,11 @@
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm.hpp"
// if set to 1, slightly more instructions generated to calculate address
#ifndef MERGE_2D_013425
#define MERGE_2D_013425 0
#endif
enum class matrix_core_inst_enum enum class matrix_core_inst_enum
{ {
MFMA_32x32x8_F16 = 0, MFMA_32x32x8_F16 = 0,
...@@ -213,19 +218,35 @@ struct matrix_core_swizzle_kernel ...@@ -213,19 +218,35 @@ struct matrix_core_swizzle_kernel
constexpr index_t Kr_y = Kr / Kr_p; constexpr index_t Kr_y = Kr / Kr_p;
return make_static_tile_distribution( return make_static_tile_distribution(
#if MERGE_2D_013425
tile_distribution_encoding< tile_distribution_encoding<
sequence<1>,// 0 sequence<1>,// 0 R
// major 1 2
// minor 0 1 2 0 1 2 3
tuple<sequence<Nr_y, Nr_p, Nw>, sequence<Kr_y, Kr_p, Kw, Kv>>, // H
// Nr_p, Kr_p Kw Nw
tuple<sequence<1 , 2>, sequence<2, 1>>, // p major
tuple<sequence<1 , 1>, sequence<2, 2>>, // p minor
// Nr_y Kr_y Kv
sequence<1, 2, 2>, // Y major
sequence<0, 0, 3>>{}); // y minor
#else
tile_distribution_encoding<
sequence<1>,// 0 R
// major 1 2 3 // major 1 2 3
// minor 0 1 0 1 0 1 2 // minor 0 1 0 1 0 1 2
tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>, tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>, // H
// Nr_p, Kr_p Kw Nw // Nr_p, Kr_p Kw Nw
tuple<sequence<1 , 2>, sequence<3, 3>>, tuple<sequence<1 , 2>, sequence<3, 3>>, // p major
tuple<sequence<1 , 1>, sequence<0, 1>>, tuple<sequence<1 , 1>, sequence<0, 1>>, // p minor
// Nr_y Kr_y Kv // Nr_y Kr_y Kv
sequence<1, 2, 3>, sequence<1, 2, 3>, // Y major
sequence<0, 0, 2>>{}); sequence<0, 0, 2>>{}); // y minor
#endif
// clang-format on // clang-format on
} }
} }
...@@ -291,18 +312,39 @@ struct matrix_core_swizzle_kernel ...@@ -291,18 +312,39 @@ struct matrix_core_swizzle_kernel
} }
else else
{ {
#if MERGE_2D_013425
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
// constexpr index_t waveflatten = kw*nw*kv;
const index_t kr = a_.k / (k1 * k2);
const index_t nr = a_.n / nw;
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(nr, kr, number<kw>{}, number<nw>{}, number<kv>{}),
number<Alignment>{}); // control vector load
auto tmp_1 = transform_tensor_view(
tmp,
make_tuple(
make_merge_transform(make_tuple(nr, number<nw>{})),
make_merge_transform(make_tuple(kr, number<kw>{}, number<kv>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return tmp_1;
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
constexpr index_t kv = Alignment; constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t waveflatten = kw * nw * kv; constexpr index_t waveflatten = kw * nw * kv;
const index_t kr = a_.k / (k1 * k2); const index_t kr = a_.k / (k1 * k2);
const index_t nr = a_.n / nw; const index_t nr = a_.n / nw;
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>( auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst, p_dst,
make_tuple(nr, kr, waveflatten), make_tuple(nr, kr, waveflatten),
number<Alignment>{}); // control vector load number<Alignment>{}); // control vector load
return tmp; return tmp;
#endif
} }
}(); }();
...@@ -333,19 +375,27 @@ struct matrix_core_swizzle_kernel ...@@ -333,19 +375,27 @@ struct matrix_core_swizzle_kernel
} }
else else
{ {
#if MERGE_2D_013425
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
return make_tile_window(dst_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{i_n * NPerBlock, i_k * KPerBlock},
get_dst_dist());
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
constexpr index_t kv = Alignment; constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t waveflatten_tile = kw * nw * kv; constexpr index_t waveflatten_tile = kw * nw * kv;
constexpr index_t nr_tile = NPerBlock / nw; constexpr index_t nr_tile = NPerBlock / nw;
constexpr index_t kr_tile = KPerBlock / (kw * kv); constexpr index_t kr_tile = KPerBlock / (kw * kv);
return make_tile_window(dst_view, return make_tile_window(dst_view,
make_tuple(number<nr_tile>{}, make_tuple(number<nr_tile>{},
number<kr_tile>{}, number<kr_tile>{},
number<waveflatten_tile>{}), number<waveflatten_tile>{}),
{i_n * nr_tile, i_k * kr_tile, 0}, {i_n * nr_tile, i_k * kr_tile, 0},
get_dst_dist()); get_dst_dist());
#endif
} }
}(); }();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment