Commit bf214665 authored by carlushuang
Browse files

add b_nr_kr_waveflatten pattern

parent 22ab193c
...@@ -35,6 +35,18 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t, ...@@ -35,6 +35,18 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
auto k = Kernel(a); auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k); float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time; return ave_time;
} }
} }
...@@ -66,6 +78,18 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t, ...@@ -66,6 +78,18 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
auto k = Kernel(a); auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k); float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time; return ave_time;
} }
} }
......
...@@ -32,10 +32,13 @@ struct to_warp_gemm<matrix_core_inst_enum::MFMA_16x16x16_F16> ...@@ -32,10 +32,13 @@ struct to_warp_gemm<matrix_core_inst_enum::MFMA_16x16x16_F16>
template <matrix_core_inst_enum Inst> template <matrix_core_inst_enum Inst>
using to_warp_gemm_t = typename detail::to_warp_gemm<Inst>::type; using to_warp_gemm_t = typename detail::to_warp_gemm<Inst>::type;
// TODO: in the permute patterns below, the last 3 dims are within a wave
enum class matrix_core_permute_style enum class matrix_core_permute_style
{ {
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6 permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6 permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
}; };
// assume this is B matrix, originally we have batch*n*k // assume this is B matrix, originally we have batch*n*k
...@@ -81,6 +84,9 @@ struct matrix_core_swizzle_kernel ...@@ -81,6 +84,9 @@ struct matrix_core_swizzle_kernel
using harg = matrix_core_swizzle_host_args; using harg = matrix_core_swizzle_host_args;
static constexpr int BLOCK_SIZE = BLOCK_SIZE_; static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
static constexpr int WavesPerBlock_N = 4;
static constexpr int WavesPerBlock_K = 1;
static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);
static constexpr int NPerBlock = NPerBlock_; static constexpr int NPerBlock = NPerBlock_;
static constexpr int KPerBlock = KPerBlock_; static constexpr int KPerBlock = KPerBlock_;
static constexpr matrix_core_permute_style pstyle = pstyle_; static constexpr matrix_core_permute_style pstyle = pstyle_;
...@@ -171,7 +177,7 @@ struct matrix_core_swizzle_kernel ...@@ -171,7 +177,7 @@ struct matrix_core_swizzle_kernel
sequence<0, 0, 0>>{}); sequence<0, 0, 0>>{});
// clang-format on // clang-format on
} }
else else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
{ {
// clang-format off // clang-format off
return make_static_tile_distribution( return make_static_tile_distribution(
...@@ -189,6 +195,39 @@ struct matrix_core_swizzle_kernel ...@@ -189,6 +195,39 @@ struct matrix_core_swizzle_kernel
sequence<0, 0, 0>>{}); sequence<0, 0, 0>>{});
// clang-format on // clang-format on
} }
else
{
// clang-format off
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
constexpr index_t Nr_p = WavesPerBlock_N;
constexpr index_t Kr_p = WavesPerBlock_K;
constexpr index_t Nr_y = Nr / Nr_p;
constexpr index_t Kr_y = Kr / Kr_p;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,// 0
// major 1 2 3
// minor 0 1 0 1 0 1 2
tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>,
// Nr_p, Kr_p Kw Nw
tuple<sequence<1 , 2>, sequence<3, 3>>,
tuple<sequence<1 , 1>, sequence<0, 1>>,
// Nr_y Kr_y Kv
sequence<1, 2, 3>,
sequence<0, 0, 2>>{});
// clang-format on
}
} }
__device__ void operator()(karg a_) __device__ void operator()(karg a_)
...@@ -242,7 +281,7 @@ struct matrix_core_swizzle_kernel ...@@ -242,7 +281,7 @@ struct matrix_core_swizzle_kernel
number<Alignment>{}); // control vector load number<Alignment>{}); // control vector load
return tmp; return tmp;
} }
else else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
{ {
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>( auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst, p_dst,
...@@ -250,6 +289,21 @@ struct matrix_core_swizzle_kernel ...@@ -250,6 +289,21 @@ struct matrix_core_swizzle_kernel
number<Alignment>{}); // control vector load number<Alignment>{}); // control vector load
return tmp; return tmp;
} }
else
{
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t waveflatten = kw * nw * kv;
const index_t kr = a_.k / (k1 * k2);
const index_t nr = a_.n / nw;
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(nr, kr, waveflatten),
number<Alignment>{}); // control vector load
return tmp;
}
}(); }();
auto dst_window = [&]() { auto dst_window = [&]() {
...@@ -265,7 +319,7 @@ struct matrix_core_swizzle_kernel ...@@ -265,7 +319,7 @@ struct matrix_core_swizzle_kernel
{i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0}, {i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0},
get_dst_dist()); get_dst_dist());
} }
else else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
{ {
return make_tile_window(dst_view, return make_tile_window(dst_view,
make_tuple(number<n0_tile>{}, make_tuple(number<n0_tile>{},
...@@ -277,6 +331,22 @@ struct matrix_core_swizzle_kernel ...@@ -277,6 +331,22 @@ struct matrix_core_swizzle_kernel
{i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0}, {i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0},
get_dst_dist()); get_dst_dist());
} }
else
{
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t waveflatten_tile = kw * nw * kv;
constexpr index_t nr_tile = NPerBlock / nw;
constexpr index_t kr_tile = KPerBlock / (kw * kv);
return make_tile_window(dst_view,
make_tuple(number<nr_tile>{},
number<kr_tile>{},
number<waveflatten_tile>{}),
{i_n * nr_tile, i_k * kr_tile, 0},
get_dst_dist());
}
}(); }();
// actual load store // actual load store
......
...@@ -258,8 +258,49 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -258,8 +258,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
}; };
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL #ifdef PERMUTE_USE_ALTERNATIVE_IMPL
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 // batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
if(rank == 7 && (arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") || if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6"))) arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") ||
arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")))
{
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
{
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
matrix_core_swizzle_traits t;
t.data_type = data_type;
t.permute = arg_parser.get_str("perm");
matrix_core_swizzle_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.batch = shape[0];
auto nr = shape[1];
auto nw = shape[2];
auto kr = shape[3];
auto kw = shape[4];
auto kv = shape[5];
a.n = nr * nw;
a.k = kr * kw * kv;
if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0)
{
t.inst = "16x16x16";
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0)
{
t.inst = "32x32x8";
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else
{
ave_time = run_permute();
}
}
else
{ {
matrix_core_swizzle_traits t; matrix_core_swizzle_traits t;
t.data_type = data_type; t.data_type = data_type;
...@@ -271,8 +312,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -271,8 +312,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
a.batch = shape[0]; a.batch = shape[0];
a.n = shape[1] * shape[2] * shape[3]; a.n = shape[1] * shape[2] * shape[3];
a.k = shape[4] * shape[5] * shape[6]; a.k = shape[4] * shape[5] * shape[6];
if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 && shape[4] % 8 == 0 && if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 &&
shape[1] % 2 == 0) shape[4] % 8 == 0 && shape[1] % 2 == 0)
{ {
// 32x32x8 inst // 32x32x8 inst
// perm=0,1,4,2,5,3,6 // perm=0,1,4,2,5,3,6
...@@ -301,6 +342,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -301,6 +342,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ave_time = run_permute(); ave_time = run_permute();
} }
} }
}
else else
#endif #endif
{ {
......
...@@ -15,6 +15,9 @@ $EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS ...@@ -15,6 +15,9 @@ $EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS $EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS $EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS $EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=2,8,16,8,4,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
$EXE -prec=fp16 -shape=1,24,32,16,2,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
echo "------------------------------------------------------------------" echo "------------------------------------------------------------------"
for prec in "fp8" "fp16" "fp32" ; do for prec in "fp8" "fp16" "fp32" ; do
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment