Unverified commit 7337ec25 authored by Chao Liu, committed by GitHub

Refactor 1010 (#14)

* refactor

* refactor

* change load_tile, update block gemm

* debug

* clean

* clean

* experiment lod

* workaround spilling issue

* clean
parent 7b1a0b7f
......@@ -29,7 +29,7 @@ int main(int argc, char* argv[])
using ODataType = ck::half_t;
ck::index_t Batch = 16;
ck::index_t M0 = 4096;
ck::index_t M0 = 3328;
ck::index_t N0 = 4096;
ck::index_t K0 = 128;
ck::index_t N1 = 128;
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/static_tile_distribution_encoding_helper.hpp"
#include "ck/tile_program/tile/static_tile_distribution_helper.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
......@@ -78,6 +78,11 @@ struct BlockGemmARegBSmemCRegV1
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id() % NWarp;
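A hedged worked example of the iteration math above, on the host, with hypothetical tile sizes (none of these numbers come from this commit): each warp covers MIterPerWarp x NIterPerWarp x KIterPerWarp warp-GEMM tiles, and the *PerBlockPerIter values are the strides later fed to move_tile_window.

#include <cstdio>

int main()
{
    // hypothetical: 128x128x32 block tile, 2x2 warp grid, 32x32x8 warp GEMM
    constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32;
    constexpr int MWarp = 2, NWarp = 2;
    constexpr int WG_kM = 32, WG_kN = 32, WG_kK = 8;

    constexpr int MIterPerWarp = MPerBlock / (MWarp * WG_kM); // 2
    constexpr int NIterPerWarp = NPerBlock / (NWarp * WG_kN); // 2
    constexpr int KIterPerWarp = KPerBlock / WG_kK;           // 4

    // per-iteration window strides consumed by move_tile_window
    constexpr int NPerBlockPerIter = NPerBlock / NIterPerWarp; // 64
    constexpr int KPerBlockPerIter = KPerBlock / KIterPerWarp; // 8

    std::printf("iters M=%d N=%d K=%d, strides N=%d K=%d\n",
                MIterPerWarp, NIterPerWarp, KIterPerWarp,
                NPerBlockPerIter, KPerBlockPerIter);
}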
constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<NWarp>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<KIterPerWarp>>,
......@@ -86,14 +91,6 @@ struct BlockGemmARegBSmemCRegV1
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<MWarp>,
Tuple<Sequence<NIterPerWarp, NWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<0, 1>>,
Tuple<Sequence<0, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<NIterPerWarp, NWarp>>,
......@@ -105,51 +102,66 @@ struct BlockGemmARegBSmemCRegV1
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
static_assert(is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::GetTileDistribution()
.GetStaticTileDistributionEncoding())>>,
"wrong!");
#if 0
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent distribution
static_assert(
is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
remove_cvref_t<decltype(
ABlockTensorTmp::GetBlockDistribution().GetStaticTensorDistributionEncoding())>>,
"wrong!");
#endif
// construct A-block-tensor from A-Block-tensor-tmp
// construct A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.GetThreadBuffer() = a_block_tensor_tmp.GetThreadBuffer();
// construct B-block-window from B-block-distribution
auto b_block_window = make_tile_window(b_block_window_tmp.GetBottomTensorView(),
b_block_window_tmp.GetWindowLengths(),
b_block_window_tmp.GetWindowOrigin(),
b_block_dstr);
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.GetBottomTensorView(),
make_tuple(Number<WG::kN>{}, Number<WG::kK>{}),
b_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using Array will cause register spill
Array<Array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
StaticallyIndexedArray<StaticallyIndexedArray<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
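This is the spill workaround the commit message refers to: the disabled #if 0 branch stores the windows in a runtime-indexed Array, which the compiler cannot keep in registers, while the live branch uses StaticallyIndexedArray plus static_for so every (nIter, kIter) index is a compile-time constant. A minimal standalone sketch of the pattern, with a hand-rolled static_for and a plain struct standing in for CK's utilities:

#include <cstdio>
#include <type_traits>

// stand-in for CK's static_for: expands f(integral_constant<0>), f(integral_constant<1>), ...
template <int I, int N, typename F>
constexpr void static_for_impl(F&& f)
{
    if constexpr(I < N)
    {
        f(std::integral_constant<int, I>{});
        static_for_impl<I + 1, N>(f);
    }
}

template <int N, typename F>
constexpr void static_for(F&& f)
{
    static_for_impl<0, N>(f);
}

struct Window // stand-in for a tile window: just an origin
{
    int row, col;
};

int main()
{
    constexpr int NIterPerWarp = 2, KIterPerWarp = 2;
    constexpr int NPerBlockPerIter = 64, KPerBlockPerIter = 8;

    Window b_warp_windows[NIterPerWarp][KIterPerWarp];

    // every (nIter, kIter) is a compile-time constant, so each window can stay
    // in its own registers; runtime indices would force a scratch spill
    static_for<NIterPerWarp>([&](auto nIter) {
        static_for<KIterPerWarp>([&](auto kIter) {
            b_warp_windows[nIter][kIter] =
                Window{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter};
        });
    });

    static_for<NIterPerWarp>([&](auto nIter) {
        static_for<KIterPerWarp>([&](auto kIter) {
            std::printf("window(%d,%d) origin = {%d, %d}\n", int(nIter), int(kIter),
                        b_warp_windows[nIter][kIter].row, b_warp_windows[nIter][kIter].col);
        });
    });
}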
// check C-block-distribution
static_assert(is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::GetTileDistribution()
.GetStaticTileDistributionEncoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
......@@ -161,24 +173,18 @@ struct BlockGemmARegBSmemCRegV1
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.GetThreadBuffer() = a_block_tensor.GetSlicedThreadData(
a_warp_tensor.GetThreadBuffer() = a_block_tensor.GetYSlicedThreadData(
merge_sequences(Sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
BWarpTensor b_warp_tensor;
b_warp_tensor.GetThreadBuffer() =
detail::load_sliced_thread_data_from_tile_window(
b_block_window,
MultiIndex<2 + BWarpDstr::NDimY>{nIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, b_warp_y_lengths));
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetSlicedThreadData(
c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetYSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths));
......@@ -186,7 +192,7 @@ struct BlockGemmARegBSmemCRegV1
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.SetSlicedThreadData(
c_block_tensor.SetYSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.GetThreadBuffer());
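The loop nest above is, in effect, a k-outer block GEMM over per-warp sub-tiles. A host-side analogue with hypothetical iteration counts, where the WG{}(c, a, b) warp GEMM is replaced by a scalar multiply-accumulate:

#include <cstdio>

int main()
{
    constexpr int MIter = 2, NIter = 2, KIter = 4; // hypothetical counts
    float A[MIter][KIter], B[NIter][KIter], C[MIter][NIter] = {};

    for(int m = 0; m < MIter; ++m)
        for(int k = 0; k < KIter; ++k)
            A[m][k] = 1.0f;
    for(int n = 0; n < NIter; ++n)
        for(int k = 0; k < KIter; ++k)
            B[n][k] = 2.0f;

    // mirrors static_for<KIter>(static_for<MIter>(static_for<NIter>(WG{}(c, a, b))))
    for(int k = 0; k < KIter; ++k)
        for(int m = 0; m < MIter; ++m)
            for(int n = 0; n < NIter; ++n)
                C[m][n] += A[m][k] * B[n][k]; // per-warp GEMM stand-in

    std::printf("C[0][0] = %f\n", C[0][0]); // 8.0
}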
......@@ -223,6 +229,11 @@ struct BlockGemmARegBSmemCRegV1
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = get_warp_id() % NWarp;
constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<NWarp>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<KIterPerWarp>>,
......@@ -231,14 +242,6 @@ struct BlockGemmARegBSmemCRegV1
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<MWarp>,
Tuple<Sequence<NIterPerWarp, NWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<0, 1>>,
Tuple<Sequence<0, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<NIterPerWarp, NWarp>>,
......@@ -250,50 +253,64 @@ struct BlockGemmARegBSmemCRegV1
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
#if 0
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent distribution
static_assert(
is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
remove_cvref_t<decltype(
ABlockTensorTmp::GetBlockDistribution().GetStaticTensorDistributionEncoding())>>,
"wrong!");
#endif
// construct A-block-tensor from A-Block-tensor-tmp
// construct A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.GetThreadBuffer() = a_block_tensor_tmp.GetThreadBuffer();
// construct B-block-window from B-block-distribution
auto b_block_window = make_tile_window(b_block_window_tmp.GetBottomTensorView(),
b_block_window_tmp.GetWindowLengths(),
b_block_window_tmp.GetWindowOrigin(),
b_block_dstr);
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.GetBottomTensorView(),
make_tuple(Number<WG::kN>{}, Number<WG::kK>{}),
b_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using Array will cause register spill
Array<Array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
StaticallyIndexedArray<StaticallyIndexedArray<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// Construct C-Block-Tensor
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
......@@ -305,24 +322,18 @@ struct BlockGemmARegBSmemCRegV1
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.GetThreadBuffer() = a_block_tensor.GetSlicedThreadData(
a_warp_tensor.GetThreadBuffer() = a_block_tensor.GetYSlicedThreadData(
merge_sequences(Sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
BWarpTensor b_warp_tensor;
b_warp_tensor.GetThreadBuffer() =
detail::load_sliced_thread_data_from_tile_window(
b_block_window,
MultiIndex<2 + BWarpDstr::NDimY>{nIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, b_warp_y_lengths));
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetSlicedThreadData(
c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetYSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths));
......@@ -330,7 +341,7 @@ struct BlockGemmARegBSmemCRegV1
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.SetSlicedThreadData(
c_block_tensor.SetYSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.GetThreadBuffer());
......
......@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/static_tile_distribution_encoding_helper.hpp"
#include "ck/tile_program/tile/static_tile_distribution_helper.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/load_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
......@@ -79,68 +79,84 @@ struct BlockGemmASmemBSmemCRegV1
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<NWarp>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<1, 0>>,
Tuple<Sequence<1, 0>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
constexpr auto b_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<MWarp>,
Tuple<Sequence<NIterPerWarp, NWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<0, 1>>,
Tuple<Sequence<0, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<NIterPerWarp, NWarp>>,
Tuple<Sequence<1, 2>>,
Tuple<Sequence<1, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
// construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window_tmp.GetBottomTensorView(),
make_tuple(Number<WG::kM>{}, Number<WG::kK>{}),
a_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
#if 0 // FIXME: using Array will cause register spill
Array<Array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
{a_warp_window_tmp}};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
StaticallyIndexedArray<StaticallyIndexedArray<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.GetBottomTensorView(),
make_tuple(Number<WG::kN>{}, Number<WG::kK>{}),
b_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
static_assert(is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::GetTileDistribution()
.GetStaticTileDistributionEncoding())>>,
"wrong!");
#if 0 // FIXME: using Array will cause register spill
Array<Array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
// construct A/B-block-window from A/B-block-distribution
auto a_block_window = make_tile_window(a_block_window_tmp.GetBottomTensorView(),
a_block_window_tmp.GetWindowLengths(),
a_block_window_tmp.GetWindowOrigin(),
a_block_dstr);
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
StaticallyIndexedArray<StaticallyIndexedArray<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
auto b_block_window = make_tile_window(b_block_window_tmp.GetBottomTensorView(),
b_block_window_tmp.GetWindowLengths(),
b_block_window_tmp.GetWindowOrigin(),
b_block_dstr);
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
......@@ -148,27 +164,16 @@ struct BlockGemmASmemBSmemCRegV1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
AWarpTensor a_warp_tensor;
a_warp_tensor.GetThreadBuffer() = detail::load_sliced_thread_data_from_tile_window(
a_block_window,
MultiIndex<2 + AWarpDstr::NDimY>{mIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths));
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
BWarpTensor b_warp_tensor;
b_warp_tensor.GetThreadBuffer() =
detail::load_sliced_thread_data_from_tile_window(
b_block_window,
MultiIndex<2 + BWarpDstr::NDimY>{nIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, b_warp_y_lengths));
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetSlicedThreadData(
c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetYSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths));
......@@ -176,7 +181,7 @@ struct BlockGemmASmemBSmemCRegV1
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.SetSlicedThreadData(
c_block_tensor.SetYSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.GetThreadBuffer());
......@@ -213,22 +218,84 @@ struct BlockGemmASmemBSmemCRegV1
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto a_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<NWarp>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<1, 0>>,
Tuple<Sequence<1, 0>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
constexpr auto b_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<MWarp>,
Tuple<Sequence<NIterPerWarp, NWarp>, Sequence<KIterPerWarp>>,
Tuple<Sequence<0, 1>>,
Tuple<Sequence<0, 1>>,
Sequence<1, 2>,
Sequence<0, 0>>{};
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
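iMWarp and iNWarp decompose the flat warp id into a row-major MWarp x NWarp grid, and the warp windows constructed next start at row iMWarp * WG::kM (for A) and iNWarp * WG::kN (for B). A small host sketch with hypothetical sizes:

#include <cstdio>

int main()
{
    constexpr int MWarp = 2, NWarp = 2;   // hypothetical warp grid
    constexpr int WG_kM = 32, WG_kN = 32; // hypothetical warp-GEMM tile

    for(int warp_id = 0; warp_id < MWarp * NWarp; ++warp_id)
    {
        const int iMWarp = warp_id / NWarp;
        const int iNWarp = warp_id % NWarp;
        std::printf("warp %d -> A window row %d, B window row %d\n",
                    warp_id, iMWarp * WG_kM, iNWarp * WG_kN);
    }
}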
// construct A-warp-window
auto a_warp_window_tmp = make_tile_window(
a_block_window_tmp.GetBottomTensorView(),
make_tuple(Number<WG::kM>{}, Number<WG::kK>{}),
a_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
#if 0 // FIXME: using Array will cause register spill
Array<Array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
{a_warp_window_tmp}};
for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
StaticallyIndexedArray<StaticallyIndexedArray<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.GetBottomTensorView(),
make_tuple(Number<WG::kN>{}, Number<WG::kK>{}),
b_block_window_tmp.GetWindowOrigin() + MultiIndex<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using Array will cause register spill
Array<Array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
StaticallyIndexedArray<StaticallyIndexedArray<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
static_assert(is_same_v<CDataType, typename WG::CDataType>, "wrong!");
// Construct C-Block-Tensor
constexpr auto c_block_outer_dstr_encoding = StaticTileDistributionEncoding<
Sequence<>,
Tuple<Sequence<MIterPerWarp, MWarp>, Sequence<NIterPerWarp, NWarp>>,
......@@ -237,69 +304,28 @@ struct BlockGemmASmemBSmemCRegV1
Sequence<1, 2>,
Sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
// construct A/B-block-window from A/B-block-distribution
auto a_block_window = make_tile_window(a_block_window_tmp.GetBottomTensorView(),
a_block_window_tmp.GetWindowLengths(),
a_block_window_tmp.GetWindowOrigin(),
a_block_dstr);
auto b_block_window = make_tile_window(b_block_window_tmp.GetBottomTensorView(),
b_block_window_tmp.GetWindowLengths(),
b_block_window_tmp.GetWindowOrigin(),
b_block_dstr);
static_assert(is_same_v<CDataType, typename WG::CDataType>, "wrong!");
// Construct C-Block-Tensor
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.GetYs2DDescriptor().GetLengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
AWarpTensor a_warp_tensor;
a_warp_tensor.GetThreadBuffer() = detail::load_sliced_thread_data_from_tile_window(
a_block_window,
MultiIndex<2 + AWarpDstr::NDimY>{mIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, a_warp_y_lengths));
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
BWarpTensor b_warp_tensor;
b_warp_tensor.GetThreadBuffer() =
detail::load_sliced_thread_data_from_tile_window(
b_block_window,
MultiIndex<2 + BWarpDstr::NDimY>{nIter, kIter, 0},
merge_sequences(Sequence<1, 1>{}, b_warp_y_lengths));
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
......@@ -313,7 +339,7 @@ struct BlockGemmASmemBSmemCRegV1
else
{
// c += a * b
c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetSlicedThreadData(
c_warp_tensor.GetThreadBuffer() = c_block_tensor.GetYSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths));
......@@ -321,7 +347,7 @@ struct BlockGemmASmemBSmemCRegV1
}
// write C warp tensor into C block tensor
c_block_tensor.SetSlicedThreadData(
c_block_tensor.SetYSlicedThreadData(
merge_sequences(Sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(Sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.GetThreadBuffer());
......
......@@ -6,7 +6,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tile_program/tile/static_distributed_tensor.hpp"
#include "ck/tile_program/tile/static_tile_distribution_encoding_helper.hpp"
#include "ck/tile_program/tile/static_tile_distribution_helper.hpp"
#include "ck/tile_program/tile/distributed_tile_sweep.hpp"
namespace ck {
......
......@@ -19,6 +19,7 @@ namespace tile_program {
// detail used by tile-programming APIs, not supposed to be used directly
namespace detail {
// TODO: deprecate
// "Y dimension": Y dimensions inside TileWindowWithStaticDistribution
// input:
// y_slice_origin: starting slice origin of Y dimension
......@@ -177,14 +178,144 @@ load_tile(TileWindowWithStaticDistribution<BottomTensorView_, WindowLengths_, Ti
constexpr auto tile_dstr = TileDstr{};
constexpr index_t NDimY = tile_dstr.GetYs2DDescriptor().GetNumOfDimension();
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto dstr_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths());
dstr_tensor.GetThreadBuffer() = detail::load_sliced_thread_data_from_tile_window(
tile_window, MultiIndex<NDimY>{0}, to_sequence(tile_dstr.GetYs2DDescriptor().GetLengths()));
constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
static_assert(TileWindow::HasStaticTileDistribution(),
"wrong! assume static tile distribution");
constexpr auto tmp = [&thread_tensor_lengths_ys]() {
const auto [ys_vector_lengths, ys_vector_strides] =
TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides();
index_t VectorDimY = 0;
index_t ScalarPerVector = 1;
for(index_t i = 0; i < NDimY; ++i)
{
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
{
ScalarPerVector = ys_vector_lengths[i];
VectorDimY = i;
}
}
return make_tuple(VectorDimY, ScalarPerVector);
}();
constexpr index_t VectorDimY = tmp.template At<0>();
constexpr index_t ScalarPerVector = tmp.template At<1>();
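The immediately-invoked lambda above picks the vector dimension: among the Y dims it selects the unit-stride dim with the largest safe vector length. A standalone sketch of the same selection over hypothetical length/stride arrays:

#include <cstdio>

int main()
{
    // hypothetical stand-ins for TileWindow::GetWindowAdaptorYsSafeVectorLengthStrides()
    constexpr int NDimY = 3;
    const int ys_vector_lengths[NDimY] = {4, 2, 8};
    const int ys_vector_strides[NDimY] = {16, 8, 1};

    int VectorDimY = 0, ScalarPerVector = 1;
    for(int i = 0; i < NDimY; ++i)
    {
        if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector)
        {
            ScalarPerVector = ys_vector_lengths[i];
            VectorDimY      = i;
        }
    }

    std::printf("VectorDimY = %d, ScalarPerVector = %d\n", VectorDimY, ScalarPerVector); // 2, 8
}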
// FIXME
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, Number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
using vector_t = typename vector_type_t::type;
using SFC_Ys = SpaceFillingCurve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Ys::GetNumOfAccess();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
#if 1 // debug
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// read from bottom tensor
const vector_t vec_value =
tile_window.GetBottomTensorView().template GetVectorizedElements<vector_t>(
tile_window.GetBottomTensorThreadCoordinate());
const vector_type_t vec{vec_value};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// write into distributed tensor
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
dst_tensor.GetThreadBuffer().template At<d>() = vec.template AsType<DataType>()[j];
});
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
// move thread coordinate back to origin
{
constexpr auto idx_diff_ys = SFC_Ys::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
#else
auto tile_window_tmp = tile_window;
// loop over thread tensor space [y0, y1, ...]
static_for<0, num_access, 1>{}([&](auto iAccess) {
// read from bottom tensor
const vector_t vec_value =
tile_window.GetBottomTensorView().template GetVectorizedElements<vector_t>(
tile_window_tmp.GetBottomTensorThreadCoordinate());
const vector_type_t vec{vec_value};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::GetIndex(iAccess);
// write into distributed tensor
static_for<0, ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
[&](auto jj) {
return jj == VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
Number<NDimY>{});
constexpr index_t d = tile_dstr.GetYs2DDescriptor().CalculateOffset(idx_ys);
dst_tensor.GetThreadBuffer().template At<d>() = vec.template AsType<DataType>()[j];
});
// move thread coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto idx_diff_ys = SFC_Ys::GetForwardStep(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(Array<index_t, NDimP>{0}, idx_diff_ys);
tile_window_tmp.MoveWindowAdaptorAndBottomTensorThreadCoordinate(idx_diff_ps_ys);
}
});
#endif
return dstr_tensor;
return dst_tensor;
}
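For orientation: the access loop in load_tile walks the thread tensor along a space-filling curve, issuing one vectorized read per access, stepping the thread coordinate forward between accesses, and finally applying one closing step that returns the coordinate to the origin (SFC_Ys::GetStepBetween(num_access - 1, 0)). A host sketch of that bookkeeping for a plain row-major curve with one scalar per access:

#include <cstdio>

int main()
{
    constexpr int L0 = 2, L1 = 4;       // thread tensor lengths [y0, y1]
    constexpr int num_access = L0 * L1; // one scalar per access in this sketch

    int coord[2] = {0, 0}; // stand-in for the bottom-tensor thread coordinate

    for(int i = 0; i < num_access; ++i)
    {
        std::printf("access %d at (y0=%d, y1=%d)\n", i, coord[0], coord[1]);

        if(i != num_access - 1) // forward step, as in SFC_Ys::GetForwardStep(iAccess)
        {
            if(++coord[1] == L1)
            {
                coord[1] = 0;
                ++coord[0];
            }
        }
    }

    // one closing step back to the origin, as in SFC_Ys::GetStepBetween(num_access - 1, 0)
    const int back_step[2] = {-coord[0], -coord[1]};
    coord[0] += back_step[0];
    coord[1] += back_step[1];
    std::printf("back at (y0=%d, y1=%d)\n", coord[0], coord[1]); // (0, 0)
}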
} // namespace tile_program
......
......@@ -10,271 +10,31 @@
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/static_tile_distribution_helper.hpp"
#include "ck/tile_program/tile/tile_window.hpp"
#include "ck/tile_program/tile/static_distributed_tensor.hpp"
namespace ck {
namespace tile_program {
namespace detail {
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
template <index_t x,
index_t... xs,
index_t m,
index_t... ms,
index_t id,
index_t... ids,
index_t SliceSize>
struct reverse_slice_sequence_impl<Sequence<x, xs...>,
Sequence<m, ms...>,
Sequence<id, ids...>,
SliceSize>
{
using old_scan =
reverse_slice_sequence_impl<Sequence<xs...>, Sequence<ms...>, Sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value;
static constexpr auto slice_length =
std::conditional_t<m, Number<math::gcd(x, slice_size)>, Number<x>>::value;
using dim_lengths =
typename sequence_merge<Sequence<slice_length>, typename old_scan::dim_lengths>::type;
using dim_slices =
typename sequence_merge<Sequence<x / slice_length>, typename old_scan::dim_slices>::type;
using remaining_slice_sizes = typename sequence_merge<
std::conditional_t<m, Sequence<slice_size / slice_length>, Sequence<slice_size>>,
typename old_scan::remaining_slice_sizes>::type;
// the first idx where the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, Number<_flag>, Number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, Number<id>, Number<0>>::value;
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
static constexpr index_t split_idx = std::
conditional_t<old_scan::split_flag, Number<old_scan::split_idx>, Number<_split_idx>>::value;
};
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<Sequence<x>, Sequence<m>, Sequence<id>, SliceSize>
{
static constexpr auto slice_size = SliceSize;
static constexpr auto slice_length =
std::conditional_t<m, Number<math::gcd(x, slice_size)>, Number<x>>::value;
using dim_lengths = Sequence<slice_length>;
using dim_slices = Sequence<x / slice_length>;
using remaining_slice_sizes =
std::conditional_t<m, Sequence<slice_size / slice_length>, Sequence<slice_size>>;
// the first idx where the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, Number<_flag>, Number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, Number<id>, Number<0>>::value;
};
// clang-format off
// input: a sequence (with an optional mask) and SliceSize, the size per slice
// output: the per-slice lengths of the sequence, and the number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return Tuple<slice_lengths, slice_nums, slice_index>; slice_index is the index at which
// slices start to split (right -> left),
// i.e. the first index where the sliced length differs from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::Size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
Number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::Size(), 1>::type{})
{
static_assert(Seq::Size() == Mask::Size());
using sliced_type =
reverse_slice_sequence_impl<Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::Front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},
Number<sliced_type::split_idx>{});
}
//
// slice the tensor along x_dim, resulting in a split along y_dim, not p_dim.
// We don't support slicing across p_dim (aka slicing across different threads);
// also, the sliced y_dim must be the first dim of the current x_dim.
// Multiplying a Y dim before the sliced dim does not make sense.
//
// e.g.
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
// |--> slice along this Y dim, which is the first dim of X1; 4 slices in total
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
// |--> slice along this Y dim; the P dim to its left is 1, so this is OK
// 16 slices in total
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
// |--> slicing along this P dim would split threads; not supported
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
// |--> slice along this Y dim, but this Y dim needs to be split into 2
// subdims;
// the P dim to its left is 1, which means we are not actually crossing P
//
template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
__host__ __device__ constexpr auto slice_distribution_from_x(
Distribution, Sequence<XSliceBegins...> x_slice_begins, Sequence<XSliceEnds...> x_slice_ends)
{
// NOTE: this function needs to be called in a constexpr context,
// due to https://wg21.link/p2280r0 we have to use non-reference type for distribution
using Encoding = decltype(Distribution::GetStaticTileDistributionEncoding());
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
constexpr auto src_h_prefix_sum = Encoding::Detail::GetHDimLengthsPrefixSum();
constexpr auto src_y_info = Encoding::Detail::GetSortedYInfo();
constexpr auto src_y_dims = src_y_info[Number<0>{}];
constexpr auto src_y_maps = src_y_info[Number<1>{}];
constexpr auto src_y_prefix_sum = src_y_info[Number<2>{}];
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
{
auto y_slice_sorted_origins = make_zero_multi_index<Distribution::NDimY>();
auto y_slice_lengths =
to_array<index_t, Distribution::NDimY>(Distribution{}.GetYs2DDescriptor().GetLengths());
// This lambda will modify some value outside, so c++ will not treat return value as
// constexpr
// TODO: ugly
auto new_h_lengths = transform_tuples(
[&](auto h_len, auto id) {
constexpr auto sliced_h =
reverse_slice_sequence(h_len, Number<x_slice_lengths[id]>{});
constexpr auto sliced_h_lens = sliced_h[Number<0>{}];
constexpr auto sliced_h_index = sliced_h[Number<2>{}];
// update y_slice_lengths
constexpr auto uniformed_h_index = sliced_h_index + Number<src_h_prefix_sum[id]>{};
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.Size(),
"not sliced at y dim, please check");
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_slice_lengths(src_y_maps[found_y_index - i]) =
sliced_h_lens[sliced_h_index - i];
});
// TODO: add validations not across p dim
// NOTE: this y_origin is for all dims, not only current dim
// will later use pick to select target dim
constexpr auto y_origin = [&]() {
constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
h_trans.CalculateLowerIndex(h_origin_, Sequence<x_slice_begins[id].value>{});
auto y_origin_ = make_zero_multi_index<Distribution::NDimY>();
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
});
return y_origin_;
}();
constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
src_y_prefix_sum[id + 1],
1>::type{};
set_container_subset(
y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
return sliced_h_lens;
},
typename Encoding::HsLengthss{},
typename arithmetic_sequence_gen<0, Encoding::HsLengthss::Size(), 1>::type{});
auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
}
();
return sliced_hlen_yidx_ylen;
}
} // namespace detail
template <typename StaticDistributedTensor_, index_t... SliceBegins, index_t... SliceEnds>
__host__ __device__ constexpr auto get_slice_tile(const StaticDistributedTensor_& tile,
Sequence<SliceBegins...> slice_begins,
Sequence<SliceEnds...> slice_ends)
{
using Distribution = decltype(StaticDistributedTensor_::GetTileDistribution());
using Encoding = decltype(Distribution::GetStaticTileDistributionEncoding());
using DataType = typename StaticDistributedTensor_::DataType;
constexpr auto sliced_hlen_yidx_ylen =
constexpr auto sliced_dstr_yidx_ylen =
detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[Number<0>{}];
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[Number<1>{}];
constexpr auto sliced_y_origins_size = sliced_y_origins_array.Size();
constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[Number<2>{}];
constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.Size();
constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
using SlicedEnc =
StaticTileDistributionEncoding<typename Encoding::RsLengths,
decltype(sliced_h_lengths), // only need to change the
// h_lengths type
typename Encoding::Ps2RHssMajor,
typename Encoding::Ps2RHssMinor,
typename Encoding::Ys2RHsMajor,
typename Encoding::Ys2RHsMinor>;
constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>();
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>();
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>();
auto sliced_tensor =
make_static_distributed_tensor<DataType>(make_static_tile_distribution(SlicedEnc{}));
auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);
sliced_tensor.GetThreadBuffer() = tile.GetSlicedThreadData(sliced_y_origins, sliced_y_lengths);
sliced_tensor.GetThreadBuffer() = tile.GetYSlicedThreadData(sliced_y_origins, sliced_y_lengths);
return sliced_tensor;
}
......@@ -290,17 +50,14 @@ __host__ __device__ constexpr auto set_slice_tile(DstStaticDistributedTensor_& d
{
using DstDistribution = decltype(DstStaticDistributedTensor_::GetTileDistribution());
constexpr auto sliced_hlen_yidx_ylen =
constexpr auto sliced_dstr_yidx_ylen =
detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[Number<0>{}];
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[Number<1>{}];
constexpr auto sliced_y_origins_size = sliced_y_origins_array.Size();
constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[Number<2>{}];
constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.Size();
constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template At<0>();
constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template At<1>();
constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template At<2>();
constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
static_assert(is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");
dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.GetThreadBuffer());
}
......
......@@ -58,7 +58,7 @@ struct StaticDistributedTensor
}
template <index_t... YSliceOrigins, index_t... YSliceLengths>
__host__ __device__ auto GetSlicedThreadData(Sequence<YSliceOrigins...>,
__host__ __device__ auto GetYSlicedThreadData(Sequence<YSliceOrigins...>,
Sequence<YSliceLengths...>) const
{
static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
......@@ -85,7 +85,7 @@ struct StaticDistributedTensor
}
template <index_t... YSliceOrigins, index_t... YSliceLengths, index_t NSlicedData>
__host__ __device__ void SetSlicedThreadData(
__host__ __device__ void SetYSlicedThreadData(
Sequence<YSliceOrigins...>,
Sequence<YSliceLengths...>,
const StaticBuffer<AddressSpaceEnum::Vgpr, DataType, NSlicedData, true>& sliced_thread_data)
......
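GetYSlicedThreadData / SetYSlicedThreadData (the rename makes explicit that the slice indices are Y-dimension indices) gather and scatter a box of a thread's register data. A simplified standalone analogue over a plain row-major buffer; CK's versions walk the Ys 2D descriptor with static indices instead:

#include <cstdio>

int main()
{
    constexpr int Y0 = 4, Y1 = 8; // hypothetical thread-buffer Y lengths
    int buf[Y0 * Y1];
    for(int i = 0; i < Y0 * Y1; ++i)
        buf[i] = i;

    constexpr int o0 = 1, o1 = 2; // slice origins
    constexpr int l0 = 2, l1 = 4; // slice lengths
    int slice[l0 * l1];

    // "GetYSlicedThreadData": gather the box into a dense slice buffer
    for(int i = 0; i < l0; ++i)
        for(int j = 0; j < l1; ++j)
            slice[i * l1 + j] = buf[(o0 + i) * Y1 + (o1 + j)];

    // ... a per-warp GEMM would run on the slice here ...

    // "SetYSlicedThreadData": scatter it back
    for(int i = 0; i < l0; ++i)
        for(int j = 0; j < l1; ++j)
            buf[(o0 + i) * Y1 + (o1 + j)] = slice[i * l1 + j];

    std::printf("slice[0] = %d\n", slice[0]); // buf[1*8 + 2] = 10
}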
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -70,6 +69,21 @@ struct StaticTileDistributionEncoding
static constexpr auto rhs_lengthss_ =
to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_));
// ys_lengths_
static constexpr auto ys_lengths_ = [] {
Array<index_t, NDimY> ys_lengths_tmp{-1};
for(index_t i = 0; i < NDimY; i++)
{
index_t rh_major = ys_to_rhs_major_[i];
index_t rh_minor = ys_to_rhs_minor_[i];
ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
}
return ys_lengths_tmp;
}();
// rhs_major_minor_to_ys_[ndim_rh_major_][max_ndim_rh_minor_]
static constexpr auto rhs_major_minor_to_ys_ = [] {
Array<Array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};
......@@ -317,6 +331,10 @@ struct StaticTileDistributionEncoding
print(rhs_lengthss_);
printf(", ");
//
printf("ys_lengths_: ");
print(ys_lengths_);
printf(", ");
//
printf("rhs_major_minor_to_ys_: ");
print(rhs_major_minor_to_ys_);
printf(", ");
......
......@@ -11,6 +11,22 @@ namespace ck {
namespace tile_program {
namespace detail {
template <typename Distribution>
__host__ __device__ auto get_partition_index(Distribution)
{
// only support warp-tile and block-tile
static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");
if constexpr(Distribution::NDimP == 1)
{
return Array<index_t, 1>{get_lane_id()};
}
else if constexpr(Distribution::NDimP == 2)
{
return Array<index_t, 2>{get_warp_id(), get_lane_id()};
}
}
template <typename OuterDstr, typename InnerDstr>
__host__ __device__ constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
{
......@@ -351,6 +367,246 @@ make_reduce_tile_distribution_encoding(InDstr, Sequence<InReduceDimXs...> reduce
remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
}
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
template <index_t x,
index_t... xs,
index_t m,
index_t... ms,
index_t id,
index_t... ids,
index_t SliceSize>
struct reverse_slice_sequence_impl<Sequence<x, xs...>,
Sequence<m, ms...>,
Sequence<id, ids...>,
SliceSize>
{
using old_scan =
reverse_slice_sequence_impl<Sequence<xs...>, Sequence<ms...>, Sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value;
static constexpr auto slice_length =
std::conditional_t<m, Number<math::gcd(x, slice_size)>, Number<x>>::value;
using dim_lengths =
typename sequence_merge<Sequence<slice_length>, typename old_scan::dim_lengths>::type;
using dim_slices =
typename sequence_merge<Sequence<x / slice_length>, typename old_scan::dim_slices>::type;
using remaining_slice_sizes = typename sequence_merge<
std::conditional_t<m, Sequence<slice_size / slice_length>, Sequence<slice_size>>,
typename old_scan::remaining_slice_sizes>::type;
// the first idx where the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, Number<_flag>, Number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, Number<id>, Number<0>>::value;
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
static constexpr index_t split_idx = std::
conditional_t<old_scan::split_flag, Number<old_scan::split_idx>, Number<_split_idx>>::value;
};
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<Sequence<x>, Sequence<m>, Sequence<id>, SliceSize>
{
static constexpr auto slice_size = SliceSize;
static constexpr auto slice_length =
std::conditional_t<m, Number<math::gcd(x, slice_size)>, Number<x>>::value;
using dim_lengths = Sequence<slice_length>;
using dim_slices = Sequence<x / slice_length>;
using remaining_slice_sizes =
std::conditional_t<m, Sequence<slice_size / slice_length>, Sequence<slice_size>>;
// the first idx where the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, Number<_flag>, Number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, Number<id>, Number<0>>::value;
};
// clang-format off
// input: a sequence (with an optional mask) and SliceSize, the size per slice
// output: the per-slice lengths of the sequence, and the number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return Tuple<slice_lengths, slice_nums, slice_index>; slice_index is the index at which
// slices start to split (right -> left),
// i.e. the first index where the sliced length differs from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::Size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
Number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::Size(), 1>::type{})
{
static_assert(Seq::Size() == Mask::Size());
using sliced_type =
reverse_slice_sequence_impl<Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::Front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},
Number<sliced_type::split_idx>{});
}
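The right-to-left scan that reverse_slice_sequence_impl performs can be mimicked at runtime: consume SliceSize from the rightmost dim leftward, taking gcd(length, remaining slice size) at each dim. A hedged re-implementation that reproduces the <4, 2, 8> / 8 row from the example table above:

#include <cstdio>
#include <numeric>

int main()
{
    constexpr int N = 3;
    const int lengths[N] = {4, 2, 8};
    int slice = 8; // SliceSize

    int slice_lengths[N], slice_nums[N];
    for(int i = N - 1; i >= 0; --i)
    {
        slice_lengths[i] = std::gcd(lengths[i], slice);
        slice_nums[i]    = lengths[i] / slice_lengths[i];
        slice /= slice_lengths[i]; // remaining slice size for the dims to the left
    }

    for(int i = 0; i < N; ++i)
        std::printf("dim %d: length %d, num %d\n", i, slice_lengths[i], slice_nums[i]);
    // prints lengths <1, 1, 8>, nums <4, 2, 1> -- 8 slices, matching the table above
}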
//
// slice the tensor along x_dim, resulting in a split along y_dim, not p_dim.
// We don't support slicing across p_dim (aka slicing across different threads);
// also, the sliced y_dim must be the first dim of the current x_dim.
// Multiplying a Y dim before the sliced dim does not make sense.
//
// e.g.
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
// |--> slice along this Y dim, which is the first dim of X1; 4 slices in total
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
// |--> slice along this Y dim; the P dim to its left is 1, so this is OK
// 16 slices in total
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
// |--> slicing along this P dim would split threads; not supported
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
// |--> slice along this Y dim, but this Y dim needs to be split into 2
// subdims;
// the P dim to its left is 1, which means we are not actually crossing P
//
template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
__host__ __device__ constexpr auto slice_distribution_from_x(
Distribution, Sequence<XSliceBegins...> x_slice_begins, Sequence<XSliceEnds...> x_slice_ends)
{
// NOTE: this function needs to be called in a constexpr context,
// due to https://wg21.link/p2280r0 we have to use non-reference type for distribution
using Encoding = decltype(Distribution::GetStaticTileDistributionEncoding());
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
constexpr auto src_h_prefix_sum = Encoding::Detail::GetHDimLengthsPrefixSum();
constexpr auto src_y_info = Encoding::Detail::GetSortedYInfo();
constexpr auto src_y_dims = src_y_info[Number<0>{}];
constexpr auto src_y_maps = src_y_info[Number<1>{}];
constexpr auto src_y_prefix_sum = src_y_info[Number<2>{}];
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
{
auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
auto y_slice_lengths = Encoding::Detail::ys_lengths_;
// This lambda modifies values declared outside of it, so C++ will not treat its return
// value as constexpr
// TODO: ugly
auto new_h_lengths = transform_tuples(
[&](auto h_len, auto id) {
constexpr auto sliced_h =
reverse_slice_sequence(h_len, Number<x_slice_lengths[id]>{});
constexpr auto sliced_h_lens = sliced_h[Number<0>{}];
constexpr auto sliced_h_index = sliced_h[Number<2>{}];
// update y_slice_lengths
constexpr auto uniformed_h_index = sliced_h_index + Number<src_h_prefix_sum[id]>{};
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.Size(),
"not sliced at y dim, please check");
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_slice_lengths(src_y_maps[found_y_index - i]) =
sliced_h_lens[sliced_h_index - i];
});
// TODO: add validations not across p dim
// NOTE: this y_origin is for all dims, not only current dim
// will later use pick to select target dim
constexpr auto y_origin = [&]() {
constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
h_trans.CalculateLowerIndex(h_origin_, Sequence<x_slice_begins[id].value>{});
auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
});
return y_origin_;
}();
constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
src_y_prefix_sum[id + 1],
1>::type{};
set_container_subset(
y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
return sliced_h_lens;
},
typename Encoding::HsLengthss{},
typename arithmetic_sequence_gen<0, Encoding::HsLengthss::Size(), 1>::type{});
auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
}
();
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[Number<0>{}];
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[Number<1>{}];
constexpr auto sliced_y_origins_size = sliced_y_origins_array.Size();
constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[Number<2>{}];
constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.Size();
constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
return make_tuple(
make_static_tile_distribution(
StaticTileDistributionEncoding<typename Encoding::RsLengths,
decltype(sliced_h_lengths), // only need to change the
// h_lengths type
typename Encoding::Ps2RHssMajor,
typename Encoding::Ps2RHssMinor,
typename Encoding::Ys2RHsMajor,
typename Encoding::Ys2RHsMinor>{}),
sliced_y_origins,
sliced_y_lengths);
}
} // namespace detail
} // namespace tile_program
} // namespace ck
......@@ -8,6 +8,7 @@
#include "ck/tensor_description/tensor_adaptor_coordinate.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/static_tile_distribution_helper.hpp"
namespace ck {
namespace tile_program {
......@@ -57,10 +58,48 @@ struct TileWindowWithStaticDistribution
window_origin_{window_origin},
bottom_tensor_thread_coord_{},
tile_dstr_{tile_distribution},
window_adaptor_thread_coord_{
window_adaptor_thread_coord_{}
{
#if 0 // debug
// only support warp-tile and block-tile
static_assert(TileDstr::NDimP == 1 or TileDstr::NDimP == 2, "wrong!");
if constexpr(TileDstr::NDimP == 1)
{
window_adaptor_thread_coord_ = make_tensor_adaptor_coordinate(
tile_distribution.GetPsYs2XsAdaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(TileDstr::NDimP == 2)
{
window_adaptor_thread_coord_ =
make_tensor_adaptor_coordinate(tile_distribution.GetPsYs2XsAdaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0})}
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#elif 0
// only support warp-tile and block-tile
static_assert(TileDstr::NDimP == 1 or TileDstr::NDimP == 2, "wrong!");
if constexpr(TileDstr::NDimP == 1)
{
window_adaptor_thread_coord_ = make_tensor_adaptor_coordinate(
tile_distribution.GetPsYs2XsAdaptor(),
container_concat(Array<index_t, 1>{get_lane_id()},
Array<index_t, TileDstr::NDimY>{0}));
}
else if constexpr(TileDstr::NDimP == 2)
{
window_adaptor_thread_coord_ = make_tensor_adaptor_coordinate(
tile_distribution.GetPsYs2XsAdaptor(),
container_concat(Array<index_t, 2>{get_warp_id(), get_lane_id()},
Array<index_t, TileDstr::NDimY>{0}));
}
#else
window_adaptor_thread_coord_ = make_tensor_adaptor_coordinate(
tile_distribution.GetPsYs2XsAdaptor(),
container_concat(detail::get_partition_index(tile_distribution),
Array<index_t, TileDstr::NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx;
for(index_t i = 0; i < NDimBottomTensor; ++i)
......
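In the final branch above, the adaptor coordinate is seeded with container_concat(get_partition_index(tile_distribution), zeros for the Y dims), which unifies the NDimP == 1 and NDimP == 2 cases that the disabled branches handle separately. A standalone sketch of that concatenation, with std::array standing in for CK's Array and made-up host-side warp/lane ids:

#include <array>
#include <cstdio>

template <std::size_t N, std::size_t M>
constexpr std::array<int, N + M> concat(const std::array<int, N>& a, const std::array<int, M>& b)
{
    std::array<int, N + M> r{};
    for(std::size_t i = 0; i < N; ++i)
        r[i] = a[i];
    for(std::size_t i = 0; i < M; ++i)
        r[N + i] = b[i];
    return r;
}

int main()
{
    constexpr int warp_id = 3, lane_id = 17; // made-up stand-ins for get_warp_id()/get_lane_id()
    constexpr int NDimY = 3;

    // partition index (NDimP == 2 case) ++ Y zeros -> adaptor top index
    constexpr auto top = concat(std::array<int, 2>{warp_id, lane_id}, std::array<int, NDimY>{});

    for(int v : top)
        std::printf("%d ", v); // 3 17 0 0 0
    std::printf("\n");
}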
......@@ -27,7 +27,10 @@ __device__ index_t get_block_1d_id() { return blockIdx.x; }
// Use these instead
__device__ index_t get_lane_id() { return __lane_id(); }
__device__ index_t get_warp_id() { return threadIdx.x / get_warp_size(); }
__device__ index_t get_warp_id()
{
return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size());
}
__device__ index_t get_thread_id() { return threadIdx.x; }
......
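The get_warp_id change deserves a note: threadIdx.x / get_warp_size() has the same value in every lane of a wave, but the compiler cannot prove that, so the quotient is kept in per-lane vector registers; __builtin_amdgcn_readfirstlane broadcasts lane 0's value, letting the result live in a scalar register and freeing vector registers. A hedged HIP sketch (the hard-coded 64-lane wave size is an assumption for gfx9-class GPUs):

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void show_warp_ids()
{
    // readfirstlane marks the value as wave-uniform for the compiler
    const int warp_id = __builtin_amdgcn_readfirstlane(threadIdx.x / 64);
    if(threadIdx.x % 64 == 0)
        printf("thread %d -> warp %d\n", static_cast<int>(threadIdx.x), warp_id);
}

int main()
{
    hipLaunchKernelGGL(show_warp_ids, dim3(1), dim3(256), 0, 0);
    (void)hipDeviceSynchronize();
}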
......@@ -256,6 +256,8 @@ struct Tuple<>
// FIXME: remove
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ void Print() const { printf("Tuple{size: 0, data: []}"); }
};
template <typename... Xs>
......