Unverified Commit 6491acda authored by carlushuang, committed by GitHub

add tensor slicing API (#7)



* add tensor slicing API

* remove redundant ck namespace

* better gemm_gemm interface

* modify gemm_gemm

* add slice_tile api

* fix merge bug

* update to 3d padding, since we no longer need that much LDS size

* clean

* clean

* clean

* clean

* clean

* clean

* clean

* clean

* clean

* clean

* clean

* clean

* clean

* clean

---------
Co-authored-by: Chao Liu <chao.liu2@amd.com>
parent 1cf54e86
......@@ -84,6 +84,7 @@ int main(int argc, char* argv[])
constexpr ck::index_t kN0PerBlock = 128;
constexpr ck::index_t kK0PerBlock = 32;
constexpr ck::index_t kN1PerBlock = 128;
constexpr ck::index_t kK1PerBlock = 32;
constexpr ck::index_t kBlockSize = 256;
ck::index_t kGridSize = (M0 / kM0PerBlock) * (N1 / kN1PerBlock);
......@@ -107,7 +108,8 @@ int main(int argc, char* argv[])
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock>{},
kN1PerBlock,
kK1PerBlock>{},
kGridSize,
kBlockSize,
0,
......
......@@ -11,6 +11,7 @@
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/slice_tile.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
......@@ -29,9 +30,18 @@ template <typename A0DataType,
ck::index_t kM0PerBlock,
ck::index_t kN0PerBlock,
ck::index_t kK0PerBlock,
ck::index_t kN1PerBlock>
ck::index_t kN1PerBlock,
ck::index_t kK1PerBlock>
struct GemmGemm
{
static constexpr auto I0 = ck::Number<0>{};
static constexpr auto BlockSize = ck::Number<kBlockSize>{};
static constexpr auto M0PerBlock = ck::Number<kM0PerBlock>{};
static constexpr auto N0PerBlock = ck::Number<kN0PerBlock>{};
static constexpr auto K0PerBlock = ck::Number<kK0PerBlock>{};
static constexpr auto N1PerBlock = ck::Number<kN1PerBlock>{};
static constexpr auto K1PerBlock = ck::Number<kK1PerBlock>{};
// block gemm0 pipeline
using BlockGemm0Pipeline = ck::tile_program::block::BlockGemmPipelineAGmemBGmemCRegV2<
ck::tile_program::block::BlockGemmPipelineProblem<
......@@ -49,7 +59,7 @@ struct GemmGemm
B1DataType,
Acc1DataType,
kBlockSize,
ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kN0PerBlock>>,
ck::tile_program::TileGemmShape<kM0PerBlock, kN1PerBlock, kK1PerBlock>>,
ck::tile_program::block::BlockGemmARegBSmemCRegV1DefaultPolicy>;
#if 0
......@@ -59,23 +69,51 @@ struct GemmGemm
using namespace ck;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{});
return b_lds_block_desc;
}
#elif 1
// 3d, with padding
__device__ static constexpr auto MakeB1LdsBlockDescriptor()
{
using namespace ck;
// using BDataType = B1DataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
constexpr index_t kPad = 1;
constexpr index_t kK1 = 8;
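// LDS layout: kKPerBlock / kK1 slabs, each slab holding kNPerBlock rows of kK1 contiguous
// elements plus kPad padding rows (slab stride (kNPerBlock + kPad) * kK1) to reduce LDS bank
// conflicts; the transform below merges (kKPerBlock / kK1, kK1) back into a single K dimension,
// so the descriptor is still indexed as (N, K).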
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(Number<kKPerBlock / kK1>{}, Number<kNPerBlock>{}, Number<kK1>{}),
make_tuple(Number<(kNPerBlock + kPad) * kK1>{}, Number<kK1>{}, Number<1>{}),
Number<kK1>{},
Number<1>{});
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_0,
make_tuple(make_pass_through_transform(kNPerBlock),
make_merge_transform(make_tuple(Number<kKPerBlock / kK1>{}, Number<kK1>{}))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return b_lds_block_desc;
}
#else
// fake XOR
__device__ static constexpr auto MakeB1LdsBlockDescriptor()
__host__ __device__ static constexpr auto MakeB1LdsBlockDescriptor()
{
using namespace ck;
using BDataType = B1DataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(kNPerBlock / 2, 2, kKPerBlock), Number<kKPerBlock>{});
......@@ -108,7 +146,7 @@ struct GemmGemm
using BDataType = B1DataType;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr index_t kKPerBlock = kK1PerBlock;
constexpr index_t K1 = 16 / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1;
......@@ -138,14 +176,14 @@ struct GemmGemm
const B0DataType* p_b0,
const B1DataType* p_b1,
C1DataType* p_c1,
ck::index_t M0,
ck::index_t N0,
ck::index_t K0,
ck::index_t N1,
ck::index_t Lda0,
ck::index_t Ldb0,
ck::index_t Ldb1,
ck::index_t Ldc1)
const ck::index_t M0,
const ck::index_t N0,
const ck::index_t K0,
const ck::index_t N1,
const ck::index_t Lda0,
const ck::index_t Ldb0,
const ck::index_t Ldb1,
const ck::index_t Ldc1)
{
using namespace ck;
using namespace ck::tile_program;
......@@ -177,38 +215,40 @@ struct GemmGemm
__shared__ char p_smem_char[GetStaticLdsSize()];
// A0 DRAM block window
auto a0_dram_block_window = make_tile_window(
a0_dram_grid, make_tuple(Number<kM0PerBlock>{}, Number<kK0PerBlock>{}), {iM0, 0});
auto a0_dram_block_window =
make_tile_window(a0_dram_grid, make_tuple(M0PerBlock, K0PerBlock), {iM0, 0});
// B0 DRAM block window
auto b0_dram_block_window = make_tile_window(
b0_dram_grid, make_tuple(Number<kN0PerBlock>{}, Number<kK0PerBlock>{}), {0, 0});
auto b0_dram_block_window =
make_tile_window(b0_dram_grid, make_tuple(N0PerBlock, K0PerBlock), {0, 0});
// Block GEMM0 pipeline
constexpr auto block_gemm0_pipeline = BlockGemm0Pipeline{};
// B1 DRAM window
auto b1_dram_block_window =
make_tile_window(b1_dram_grid,
make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}),
{iN1, 0},
MakeB1DramTileDistribution());
auto b1_dram_block_window = make_tile_window(b1_dram_grid,
make_tuple(N1PerBlock, K1PerBlock),
{iN1, 0},
MakeB1DramTileDistribution());
// B1 LDS tensor view: occupies the same LDS allocation as block_gemm0_pipeline
auto b1_lds_block = make_tensor_view<AddressSpaceEnum::Lds>(
reinterpret_cast<B1DataType*>(p_smem_char), MakeB1LdsBlockDescriptor());
auto b1_lds_block_window = make_tile_window(
b1_lds_block, make_tuple(Number<kN1PerBlock>{}, Number<kN0PerBlock>{}), {0, 0});
auto b1_lds_block_window =
make_tile_window(b1_lds_block, make_tuple(N1PerBlock, K1PerBlock), {0, 0});
// Block GEMM1
constexpr auto block_gemm1 = BlockGemm1{};
// Acc1 tile
auto acc1_block_tile = decltype(block_gemm1(
tile_elementwise_in(
type_convert<C0DataType, Acc0DataType>,
block_gemm0_pipeline(a0_dram_block_window, b0_dram_block_window, 0, nullptr)),
get_slice_tile(
tile_elementwise_in(
type_convert<C0DataType, Acc0DataType>,
block_gemm0_pipeline(a0_dram_block_window, b0_dram_block_window, 0, nullptr)),
Sequence<0, 0>{},
Sequence<kM0PerBlock, kK1PerBlock>{}),
b1_dram_block_window)){};
// init Acc1
......@@ -226,30 +266,44 @@ struct GemmGemm
const auto c0_block_tile =
tile_elementwise_in(type_convert<C0DataType, Acc0DataType>, acc0_block_tile);
// Block GEMM1: acc1 += c0 * b1
{
// load b1
const auto b1_block_tile = load_tile(b1_dram_block_window);
// wait for block gemm0 pipeline to finish
block_sync_lds();
// prefetch load b1
const auto b1_block_tile = load_tile(b1_dram_block_window);
move_tile_window(b1_dram_block_window, {0, kK1PerBlock});
store_tile(b1_lds_block_window, b1_block_tile);
block_sync_lds();
// wait for store_tile to finish
block_sync_lds();
store_tile(b1_lds_block_window, b1_block_tile);
// acc1 += c0 * b1
block_gemm1(acc1_block_tile, c0_block_tile, b1_lds_block_window);
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
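// GEMM1 consumes the [kM0PerBlock, kN0PerBlock] c0 register tile in k1_loops chunks of
// kK1PerBlock along its K dimension (the N0 dimension of GEMM0): iteration i multiplies
//   get_slice_tile(c0_block_tile,
//                  Sequence<0, i * kK1PerBlock>{},
//                  Sequence<kM0PerBlock, (i + 1) * kK1PerBlock>{})
// against the b1 tile currently resident in LDS while the next kK1PerBlock-wide b1 tile is
// prefetched from DRAM; the last chunk is handled by the tail block below.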
// wait for block gemm1 to finish
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i) {
// acc1 += c0 * b1
const auto b1_block_tile_1 = load_tile(b1_dram_block_window);
block_sync_lds();
block_gemm1(acc1_block_tile,
get_slice_tile(c0_block_tile,
Sequence<0, i * kK1PerBlock>{},
Sequence<kM0PerBlock, (i + 1) * kK1PerBlock>{}),
b1_lds_block_window);
block_sync_lds();
move_tile_window(b1_dram_block_window, {0, kK1PerBlock});
store_tile(b1_lds_block_window, b1_block_tile_1);
});
}
// tail
{
block_sync_lds();
block_gemm1(acc1_block_tile,
get_slice_tile(c0_block_tile,
Sequence<0, (k1_loops - 1) * kK1PerBlock>{},
Sequence<kM0PerBlock, kN0PerBlock>{}),
b1_lds_block_window);
}
// move tile windows
move_tile_window(b0_dram_block_window, {kN0PerBlock, 0});
move_tile_window(b1_dram_block_window, {0, kN0PerBlock});
block_sync_lds();
iN0 += kN0PerBlock;
} while(iN0 < N0);
......@@ -262,11 +316,10 @@ struct GemmGemm
auto c1_dram_grid = make_naive_tensor_view<AddressSpaceEnum::Global>(
p_c1, make_tuple(M0, N1), make_tuple(Ldc1, 1), Number<32>{}, Number<1>{});
auto c1_dram_window =
make_tile_window(c1_dram_grid,
make_tuple(Number<kM0PerBlock>{}, Number<kN1PerBlock>{}),
{iM0, iN1},
c1_block_tile.GetTileDistribution());
auto c1_dram_window = make_tile_window(c1_dram_grid,
make_tuple(M0PerBlock, N1PerBlock),
{iM0, iN1},
c1_block_tile.GetTileDistribution());
store_tile(c1_dram_window, c1_block_tile);
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_window.hpp"
#include "ck/tile_program/tile/static_distributed_tensor.hpp"
namespace ck {
namespace tile_program {
namespace detail {
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
template <index_t x,
index_t... xs,
index_t m,
index_t... ms,
index_t id,
index_t... ids,
index_t SliceSize>
struct reverse_slice_sequence_impl<Sequence<x, xs...>,
Sequence<m, ms...>,
Sequence<id, ids...>,
SliceSize>
{
using old_scan =
reverse_slice_sequence_impl<Sequence<xs...>, Sequence<ms...>, Sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::Front().value;
static constexpr auto slice_length =
std::conditional_t<m, Number<math::gcd(x, slice_size)>, Number<x>>::value;
using dim_lengths =
typename sequence_merge<Sequence<slice_length>, typename old_scan::dim_lengths>::type;
using dim_slices =
typename sequence_merge<Sequence<x / slice_length>, typename old_scan::dim_slices>::type;
using remaining_slice_sizes = typename sequence_merge<
std::conditional_t<m, Sequence<slice_size / slice_length>, Sequence<slice_size>>,
typename old_scan::remaining_slice_sizes>::type;
// the first idx at which the sliced length differs from the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, Number<_flag>, Number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, Number<id>, Number<0>>::value;
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
static constexpr index_t split_idx = std::
conditional_t<old_scan::split_flag, Number<old_scan::split_idx>, Number<_split_idx>>::value;
};
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<Sequence<x>, Sequence<m>, Sequence<id>, SliceSize>
{
static constexpr auto slice_size = SliceSize;
static constexpr auto slice_length =
std::conditional_t<m, Number<math::gcd(x, slice_size)>, Number<x>>::value;
using dim_lengths = Sequence<slice_length>;
using dim_slices = Sequence<x / slice_length>;
using remaining_slice_sizes =
std::conditional_t<m, Sequence<slice_size / slice_length>, Sequence<slice_size>>;
// the first idx at which the sliced length differs from the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.Front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, Number<_flag>, Number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, Number<id>, Number<0>>::value;
};
// clang-format off
// input: a sequence (with an optional mask) and SliceSize, the size of each slice
// output: the per-slice lengths of the sequence and the number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return Tuple<slice_lengths, slice_nums, slice_index>, where slice_index is the index
// (scanning right -> left) at which the slices start to split,
// i.e. the first index whose sliced length differs from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::Size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
Number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::Size(), 1>::type{})
{
static_assert(Seq::Size() == Mask::Size());
using sliced_type =
reverse_slice_sequence_impl<Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::Front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},
Number<sliced_type::split_idx>{});
}
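// Illustrative use of reverse_slice_sequence (mirrors the "<4, 2, 8>, 32" row in the table above):
//   constexpr auto sliced = reverse_slice_sequence(Sequence<4, 2, 8>{}, Number<32>{});
//   // sliced[Number<0>{}] : per-slice lengths    -> Sequence<2, 2, 8>
//   // sliced[Number<1>{}] : slices per dimension -> Sequence<2, 1, 1>  (2 slices in total)
//   // sliced[Number<2>{}] : split index          -> Number<0>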
//
// slice a tensor along its x_dims; the result is a split along y_dims, never p_dims.
// We don't support slicing across a p_dim (i.e. splitting across different threads).
// Also, the sliced y_dim must be the leading sub-dim of the current x_dim:
// having another Y dim multiplied in before the sliced dim does not make sense.
//
// e.g
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
// |--> slice along this Y dim; the P dim to its left is 1, so this is OK
// totally 16 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
// |--> slice along this P dim, will split threads, not supported
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
// |--> slice along this Y dim, but this Y dim needs to be split into 2
// sub-dims
// the P dim to its left is 1, so we are not actually crossing P
//
template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
__host__ __device__ constexpr auto slice_distribution_from_x(
Distribution, Sequence<XSliceBegins...> x_slice_begins, Sequence<XSliceEnds...> x_slice_ends)
{
// NOTE: this function needs to be called in a constexpr context;
// due to https://wg21.link/p2280r0 we have to use a non-reference type for the distribution
using Encoding = decltype(Distribution::GetStaticTileDistributionEncoding());
static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
constexpr auto src_h_prefix_sum = Encoding::Detail::GetHDimLengthsPrefixSum();
constexpr auto src_y_info = Encoding::Detail::GetSortedYInfo();
constexpr auto src_y_dims = src_y_info[Number<0>{}];
constexpr auto src_y_maps = src_y_info[Number<1>{}];
constexpr auto src_y_prefix_sum = src_y_info[Number<2>{}];
constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
{
auto y_slice_sorted_origins = make_zero_multi_index<Distribution::NDimY>();
auto y_slice_lengths =
to_array<index_t, Distribution::NDimY>(Distribution{}.GetYs2DDescriptor().GetLengths());
// This lambda modifies values outside of it, so C++ will not treat its return value as
// constexpr
// TODO: ugly
auto new_h_lengths = transform_tuples(
[&](auto h_len, auto id) {
constexpr auto sliced_h =
reverse_slice_sequence(h_len, Number<x_slice_lengths[id]>{});
constexpr auto sliced_h_lens = sliced_h[Number<0>{}];
constexpr auto sliced_h_index = sliced_h[Number<2>{}];
// update y_slice_lengths
constexpr auto uniformed_h_index = sliced_h_index + Number<src_h_prefix_sum[id]>{};
constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
static_assert(found_y_index >= 0 && found_y_index < src_y_dims.Size(),
"not sliced at y dim, please check");
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_slice_lengths(src_y_maps[found_y_index - i]) =
sliced_h_lens[sliced_h_index - i];
});
// TODO: add validations not across p dim
// NOTE: this y_origin covers all dims, not only the current dim;
// we later use y_picks to select the target dims
constexpr auto y_origin = [&]() {
constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
h_trans.CalculateLowerIndex(h_origin_, Sequence<x_slice_begins[id].value>{});
auto y_origin_ = make_zero_multi_index<Distribution::NDimY>();
static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
});
return y_origin_;
}();
constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
src_y_prefix_sum[id + 1],
1>::type{};
set_container_subset(
y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
return sliced_h_lens;
},
typename Encoding::HsLengthss{},
typename arithmetic_sequence_gen<0, Encoding::HsLengthss::Size(), 1>::type{});
auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
}
();
return sliced_hlen_yidx_ylen;
}
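// The returned tuple is consumed by get_slice_tile / set_slice_tile below:
//   [0] : sliced H lengths per X dim (used to build the sliced distribution encoding)
//   [1] : Y origins of the slice
//   [2] : Y lengths of the slice
// [1] and [2] are converted to Sequences via TO_SEQUENCE and passed to
// GetSlicedThreadData / SetSlicedThreadData.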
} // namespace detail
template <typename StaticDistributedTensor_, index_t... SliceBegins, index_t... SliceEnds>
__host__ __device__ constexpr auto get_slice_tile(const StaticDistributedTensor_& tile,
Sequence<SliceBegins...> slice_begins,
Sequence<SliceEnds...> slice_ends)
{
using Distribution = decltype(StaticDistributedTensor_::GetTileDistribution());
using Encoding = decltype(Distribution::GetStaticTileDistributionEncoding());
using DataType = typename StaticDistributedTensor_::DataType;
constexpr auto sliced_hlen_yidx_ylen =
detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[Number<0>{}];
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[Number<1>{}];
constexpr auto sliced_y_origins_size = sliced_y_origins_array.Size();
constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[Number<2>{}];
constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.Size();
constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
using SlicedEnc =
StaticTileDistributionEncoding<typename Encoding::RsLengths,
decltype(sliced_h_lengths), // only need to change the
// h_lengths type
typename Encoding::Ps2RHssMajor,
typename Encoding::Ps2RHssMinor,
typename Encoding::Ys2RHsMajor,
typename Encoding::Ys2RHsMinor>;
auto sliced_tensor =
make_static_distributed_tensor<DataType>(make_static_tile_distribution(SlicedEnc{}));
sliced_tensor.GetThreadBuffer() = tile.GetSlicedThreadData(sliced_y_origins, sliced_y_lengths);
return sliced_tensor;
}
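// Illustrative usage, matching the gemm+gemm example of this commit: take the first
// kK1PerBlock columns of a [kM0PerBlock, kN0PerBlock] distributed c0 tile (the tile name and
// sizes come from that example; they are not requirements of the API):
//   auto c0_slice = get_slice_tile(
//       c0_block_tile, Sequence<0, 0>{}, Sequence<kM0PerBlock, kK1PerBlock>{});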
template <typename DstStaticDistributedTensor_,
typename SrcStaticDistributedTensor_,
index_t... SliceBegins,
index_t... SliceEnds>
__host__ __device__ constexpr auto set_slice_tile(DstStaticDistributedTensor_& dst_tile,
const SrcStaticDistributedTensor_& src_tile,
Sequence<SliceBegins...> slice_begins,
Sequence<SliceEnds...> slice_ends)
{
using DstDistribution = decltype(DstStaticDistributedTensor_::GetTileDistribution());
constexpr auto sliced_hlen_yidx_ylen =
detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);
constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[Number<0>{}];
constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[Number<1>{}];
constexpr auto sliced_y_origins_size = sliced_y_origins_array.Size();
constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[Number<2>{}];
constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.Size();
constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.GetThreadBuffer());
}
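// Hypothetical counterpart sketch (tile names are placeholders): write src_tile back into the
// [0, 0] .. [kM0PerBlock, kK1PerBlock) sub-window of a larger distributed dst_tile:
//   set_slice_tile(dst_tile, src_tile, Sequence<0, 0>{}, Sequence<kM0PerBlock, kK1PerBlock>{});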
} // namespace tile_program
} // namespace ck
......@@ -41,6 +41,7 @@ struct StaticTileDistributionEncoding
static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
// redundant but useful info
// TODO: really bad code, should be over-hauled
struct Detail
{
// ndim_rh_major_, ndim_span_mainor_
......@@ -232,6 +233,62 @@ struct StaticTileDistributionEncoding
}
}();
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
__host__ __device__ static constexpr auto GetHDimLengthsPrefixSum()
{
// <len_d0, len_d1, ...>
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
[&](auto i) {
constexpr index_t size = HsLengthss{}[i].Size();
return Number<size>{};
},
Number<NDimX>{});
// <0, len_d0, len_d0+len_d1, ...>
// e.g. seq<3, 5> --> seq<0, 3, 8>
constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);
return h_dim_prefix_sum;
}
__host__ __device__ static constexpr auto GetUniformedIdxY2H()
{
constexpr auto all_ys_2_rhss = transform_sequences(
[](auto major, auto minor) constexpr {
// <0, 0, len_d0, len_d0+len_d1, ...>
constexpr auto x_dim_prefix_sum =
merge_sequences(Sequence<0>{} /*for R dims*/, GetHDimLengthsPrefixSum());
return x_dim_prefix_sum.At(major) + minor;
},
Ys2RHsMajor{},
Ys2RHsMinor{});
return all_ys_2_rhss;
}
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
template <typename IdxSeq, typename PrefixSumSeq>
__host__ __device__ static constexpr auto GetSortedInfo(IdxSeq, PrefixSumSeq)
{
using sorted_idx =
sequence_unique_sort<IdxSeq, math::less<index_t>, math::equal<index_t>>;
constexpr auto sorted_dims = typename sorted_idx::type{};
constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
constexpr auto sorted_histogram =
histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);
return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
}
__host__ __device__ static constexpr auto GetSortedYInfo()
{
return GetSortedInfo(GetUniformedIdxY2H(), GetHDimLengthsPrefixSum());
}
__host__ __device__ void Print() const
{
printf("StaticTileDistributionEncoding::Detail{");
......
......@@ -447,6 +447,20 @@ __host__ __device__ constexpr void set_container_subset(Y& y, Sequence<Is...> pi
}
}
// return the index of the first occurrence of value in the sequence,
// or seq.Size() if not found
template <index_t... Is>
constexpr index_t container_find(Sequence<Is...> seq, index_t value)
{
for(auto i = 0; i < seq.Size(); i++)
{
if(seq[i] == value)
return i;
}
return seq.Size();
}
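// e.g. container_find(Sequence<0, 3, 8>{}, 3) -> 1
//      container_find(Sequence<0, 3, 8>{}, 5) -> 3 (== Size(), i.e. not found)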
template <index_t... Is>
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_NUMBER_HPP
#define CK_NUMBER_HPP
#pragma once
#include "integral_constant.hpp"
#include "ck/utility/integral_constant.hpp"
namespace ck {
......@@ -15,4 +14,3 @@ template <index_t N>
using LongNumber = integral_constant<long_index_t, N>;
} // namespace ck
#endif
......@@ -809,6 +809,46 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
}
// e.g. Seq<2, 3, 4> --> Seq<0, 2, 5>, Init=0, Reduce=Add
// ResultSeq TargetSeq Reduce
template <typename, typename, typename>
struct sequence_exclusive_scan;
template <index_t... Xs, index_t Y, index_t... Ys, typename Reduce>
struct sequence_exclusive_scan<Sequence<Xs...>, Sequence<Y, Ys...>, Reduce>
{
using old_scan = typename sequence_merge<Sequence<Xs...>,
Sequence<Reduce{}(Y, Sequence<Xs...>{}.Back())>>::type;
using type = typename sequence_exclusive_scan<old_scan, Sequence<Ys...>, Reduce>::type;
};
template <index_t... Xs, index_t Y, typename Reduce>
struct sequence_exclusive_scan<Sequence<Xs...>, Sequence<Y>, Reduce>
{
using type = Sequence<Xs...>;
};
template <index_t... Xs, typename Reduce>
struct sequence_exclusive_scan<Sequence<Xs...>, Sequence<>, Reduce>
{
using type = Sequence<Xs...>;
};
template <typename Seq, typename Reduce, index_t Init>
constexpr auto exclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
// TODO: c++20 and later can pass in Reduce with a lambda expression
return typename sequence_exclusive_scan<Sequence<Init>, Seq, Reduce>::type{};
}
template <typename Seq>
constexpr auto prefix_sum_sequence(Seq)
{
return typename sequence_exclusive_scan<Sequence<0>,
typename sequence_merge<Seq, Sequence<0>>::type,
math::plus<index_t>>::type{};
}
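// e.g. exclusive_scan_sequence(Sequence<2, 3, 4>{}, math::plus<index_t>{}, Number<0>{}) -> Sequence<0, 2, 5>
//      prefix_sum_sequence(Sequence<2, 3, 4>{}) -> Sequence<0, 2, 5, 9> (the total is appended at the end)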
template <typename Seq, index_t... Is>
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
{
......
......@@ -3,7 +3,10 @@
#pragma once
#include "ck/utility/sequence.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/macro_func_array_to_sequence.hpp"
namespace ck {
......@@ -14,7 +17,7 @@ __host__ __device__ constexpr auto make_sequence(Number<Is>...)
}
// F() returns index_t
// F use default constructor
// F uses its default constructor, so F cannot be a lambda function
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence(F, Number<N>)
{
......@@ -22,6 +25,7 @@ __host__ __device__ constexpr auto generate_sequence(F, Number<N>)
}
// F() returns Number<>
// F can be a lambda function
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
{
......@@ -35,4 +39,55 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
return Sequence<Is...>{};
}
namespace detail {
template <index_t h_idx, typename SeqSortedSamples, typename SeqRange>
struct sorted_sequence_histogram;
template <index_t h_idx, index_t x, index_t... xs, index_t r, index_t... rs>
struct sorted_sequence_histogram<h_idx, Sequence<x, xs...>, Sequence<r, rs...>>
{
template <typename Histogram>
constexpr auto operator()(Histogram& h)
{
if constexpr(x < r)
{
h.template At<h_idx>() += 1;
sorted_sequence_histogram<h_idx, Sequence<xs...>, Sequence<r, rs...>>{}(h);
}
else
{
h.template At<h_idx + 1>() = 1;
sorted_sequence_histogram<h_idx + 1, Sequence<xs...>, Sequence<rs...>>{}(h);
}
}
};
template <index_t h_idx, index_t x, index_t r, index_t... rs>
struct sorted_sequence_histogram<h_idx, Sequence<x>, Sequence<r, rs...>>
{
template <typename Histogram>
constexpr auto operator()(Histogram& h)
{
if constexpr(x < r)
{
h.template At<h_idx>() += 1;
}
}
};
} // namespace detail
// SeqSortedSamples: <0, 2, 3, 5, 7>, SeqRange: <0, 3, 6, 9> -> SeqHistogram : <2, 2, 1>
template <typename SeqSortedSamples, index_t r, index_t... rs>
constexpr auto histogram_sorted_sequence(SeqSortedSamples, Sequence<r, rs...>)
{
constexpr auto bins = sizeof...(rs); // or categories
constexpr auto histogram = [&]() {
Array<index_t, bins> h{0}; // make sure this zero-initializes every element
detail::sorted_sequence_histogram<0, SeqSortedSamples, Sequence<rs...>>{}(h);
return h;
}();
return TO_SEQUENCE(histogram, bins);
}
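// Illustrative usage (values mirror the example above):
//   constexpr auto hist = histogram_sorted_sequence(Sequence<0, 2, 3, 5, 7>{}, Sequence<0, 3, 6, 9>{});
//   // hist == Sequence<2, 2, 1>
// In this commit it is used by StaticTileDistributionEncoding::Detail::GetSortedInfo to bucket
// the sorted Y->H indices into per-X-dim bins.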
} // namespace ck